brevitas.core.function_wrapper package
Submodules
brevitas.core.function_wrapper.clamp module
ScriptModule wrappers for various variants of clamping.
- class brevitas.core.function_wrapper.clamp.ClampMin(min_val)
Bases: Module
ScriptModule wrapper for clamp_min().
Examples
>>> clamp_min = ClampMin(min_val=-2.0)
>>> clamp_min(torch.tensor(-3.0))
tensor(-2.)
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
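The wrapper pattern shared by ClampMin, ScalarClamp and TensorClamp can be sketched in plain PyTorch. This is a minimal illustration rather than the Brevitas source: ClampMinSketch is a hypothetical name, and the sketch assumes clamp_min() behaves like torch.clamp with only a lower bound. Keeping the bound as a plain float attribute lets the module compile with torch.jit.script.

```python
import torch
from torch import nn


class ClampMinSketch(nn.Module):
    """Hypothetical stand-in for ClampMin: clamps from below by a fixed scalar bound."""

    def __init__(self, min_val: float) -> None:
        super().__init__()
        # Plain float attribute; TorchScript infers its type when the module is scripted.
        self.min_val = min_val

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Elements below min_val are replaced by min_val; the rest pass through unchanged.
        return torch.clamp(x, min=self.min_val)


eager = ClampMinSketch(min_val=-2.0)
scripted = torch.jit.script(eager)  # compile the same module with TorchScript
x = torch.tensor([-3.0, -1.0, 4.0])
print(eager(x).tolist())     # [-2.0, -1.0, 4.0]
print(scripted(x).tolist())  # [-2.0, -1.0, 4.0]
```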
- class brevitas.core.function_wrapper.clamp.FloatClamp(tensor_clamp_impl, signed, inf_values=None, nan_values=None, max_available_float=None, saturating=True, device=None, dtype=None)
Bases: Module
ScriptModule for clamping minifloat formats to their inf/NaN implementations.
Currently, inf/NaN codes have to be encoded through the mantissa. For example, setting inf to 1101.111 (E4M3) is not a valid code.
- forward(x, exponent_bit_width, mantissa_bit_width, exponent_bias)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
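To see why the inf/NaN encoding affects clamping, the pure-Python sketch below computes the largest finite value of a minifloat format. It is not FloatClamp's implementation: max_finite_minifloat and its reserved_mantissa_codes argument are hypothetical, and the sketch assumes inf/NaN codes sit only at the all-ones exponent and are distinguished through the mantissa, matching the restriction stated above.

```python
def max_finite_minifloat(exponent_bit_width: int,
                         mantissa_bit_width: int,
                         exponent_bias: int,
                         reserved_mantissa_codes: set) -> float:
    """Largest finite value of a minifloat format (hypothetical helper).

    Assumes inf/NaN are encoded only at the all-ones exponent, through the
    mantissa codes listed in reserved_mantissa_codes.
    """
    max_exponent_code = 2 ** exponent_bit_width - 1
    mantissa_codes = range(2 ** mantissa_bit_width)
    free_codes = [m for m in mantissa_codes if m not in reserved_mantissa_codes]
    if not free_codes:
        # The whole top binade is reserved (IEEE-style), so use the next exponent down.
        max_exponent_code -= 1
        free_codes = list(mantissa_codes)
    significand = 1.0 + max(free_codes) / 2 ** mantissa_bit_width
    return significand * 2.0 ** (max_exponent_code - exponent_bias)


# OCP-style E4M3: only mantissa 111 at exponent 1111 is NaN -> 1.75 * 2**8
print(max_finite_minifloat(4, 3, 7, {0b111}))  # 448.0
# E5M2: every mantissa at exponent 11111 encodes inf/NaN -> 1.75 * 2**15
print(max_finite_minifloat(5, 2, 15, {0b00, 0b01, 0b10, 0b11}))  # 57344.0
```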
- class brevitas.core.function_wrapper.clamp.ScalarClamp(min_val, max_val)
Bases: Module
ScriptModule wrapper for clamp().
Examples
>>> scalar_clamp = ScalarClamp(min_val=-2.0, max_val=2.0)
>>> scalar_clamp(torch.tensor([-3.0, 3.0]))
tensor([-2., 2.])
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class brevitas.core.function_wrapper.clamp.TensorClamp
Bases: Module
ScriptModule wrapper for tensor_clamp().
Examples
>>> tensor_clamp = TensorClamp()
>>> min_val = torch.tensor(-2.0)
>>> max_val = torch.tensor(2.0)
>>> tensor_clamp(torch.tensor([-3.0, 3.0]), min_val, max_val)
tensor([-2., 2.])
- forward(x, min_val, max_val)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
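ScalarClamp fixes its bounds at construction time, while TensorClamp takes min_val and max_val as tensors on every call, which allows bounds that differ per row or channel. A minimal sketch of the latter in plain PyTorch, assuming tensor_clamp() is an elementwise clamp whose bounds broadcast against x; tensor_clamp_sketch is a hypothetical name.

```python
import torch


def tensor_clamp_sketch(x: torch.Tensor, min_val: torch.Tensor, max_val: torch.Tensor) -> torch.Tensor:
    # Elementwise clamp; min_val and max_val broadcast against x.
    return torch.minimum(torch.maximum(x, min_val), max_val)


x = torch.tensor([[-3.0, 0.5, 3.0],
                  [-3.0, 0.5, 3.0]])
# One (min, max) pair per row, shaped [2, 1] so each broadcasts across the columns.
min_val = torch.tensor([[-2.0], [-1.0]])
max_val = torch.tensor([[2.0], [0.25]])
print(tensor_clamp_sketch(x, min_val, max_val).tolist())
# [[-2.0, 0.5, 2.0], [-1.0, 0.25, 0.25]]
```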
brevitas.core.function_wrapper.misc module
A collection of miscellaneous ScriptModules used in various quantizers.
- class brevitas.core.function_wrapper.misc.Identity
Bases: Module
Identity ScriptModule.
Examples
>>> identity = Identity()
>>> x = torch.randn(size=[10,])
>>> y = identity(x)
>>> y is x
True
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Return type: Tensor
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class brevitas.core.function_wrapper.misc.InplaceLogTwo
Bases: Module
Module wrapper for log2_().
Examples
>>> inplace_log_two = InplaceLogTwo()
>>> x = torch.tensor(8.0)
>>> _ = inplace_log_two(x)
>>> x
tensor(3.)
Notes
In-place operations can be problematic in TorchScript, so compilation is disabled for this module.
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Return type: Tensor
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
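The difference between LogTwo and InplaceLogTwo is whether the input tensor is overwritten. The snippet below shows this with torch.log2 and Tensor.log2_, the operations the two wrappers are documented to wrap; it uses plain PyTorch rather than the Brevitas modules.

```python
import torch

x = torch.tensor([8.0, 2.0])
y = torch.log2(x)  # out-of-place (what LogTwo wraps): x keeps its values
print(x.tolist(), y.tolist())  # [8.0, 2.0] [3.0, 1.0]

x.log2_()  # in-place (what InplaceLogTwo wraps): x itself is overwritten
print(x.tolist())  # [3.0, 1.0]
```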
- class brevitas.core.function_wrapper.misc.LogTwo
Bases: Module
ScriptModule wrapper for log2().
Examples
>>> log_two = LogTwo()
>>> x = torch.tensor(8.0)
>>> log_two(x)
tensor(3.)
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Return type: Tensor
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class brevitas.core.function_wrapper.misc.PowerOfTwo
Bases: Module
ScriptModule implementation of 2.0 ** x.
Examples
>>> power_of_two = PowerOfTwo()
>>> x = torch.tensor(5.0)
>>> power_of_two(x)
tensor(32.)
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Return type: Tensor
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
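Because LogTwo and PowerOfTwo are inverses, a scale can be kept in the log2 domain and mapped back with 2.0 ** x. The sketch below uses plain torch.log2, torch.round and 2.0 ** x to snap illustrative scale values to the nearest power of two in the log domain; it shows the arithmetic only and does not describe how any particular Brevitas quantizer combines these modules.

```python
import torch

scale = torch.tensor([0.3, 5.0, 0.125])    # illustrative scale values
log_scale = torch.log2(scale)              # the LogTwo step
pot_scale = 2.0 ** torch.round(log_scale)  # round in the log2 domain, then the PowerOfTwo step
print(pot_scale.tolist())  # [0.25, 4.0, 0.125]

# Without rounding, the two steps undo each other (up to float error).
x = torch.tensor([0.5, 3.0, 10.0])
print(torch.allclose(2.0 ** torch.log2(x), x))  # True
```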
brevitas.core.function_wrapper.ops_ste module
ScriptModule wrappers of various functions defined in ops_ste.
- class brevitas.core.function_wrapper.ops_ste.CeilSte
Bases: Module
ScriptModule wrapper for ceil_ste().
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
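A straight-through estimator (STE) applies a non-differentiable function such as ceil in the forward pass and treats it as the identity in the backward pass, so the input still receives a gradient (torch.ceil on its own has a zero gradient everywhere). The sketch below builds this with torch.autograd.Function; CeilSteSketch is a hypothetical name and the code is not the Brevitas implementation of ceil_ste().

```python
import torch


class CeilSteSketch(torch.autograd.Function):
    """Hypothetical ceil with a straight-through gradient."""

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        return torch.ceil(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        # Straight-through: hand the incoming gradient back unchanged.
        return grad_output


x = torch.tensor([0.2, 1.5, -1.3], requires_grad=True)
y = CeilSteSketch.apply(x)
y.sum().backward()
print(y.tolist())       # [1.0, 2.0, -1.0]
print(x.grad.tolist())  # [1.0, 1.0, 1.0]
```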
- class brevitas.core.function_wrapper.ops_ste.DPURoundSte
Bases: Module
ScriptModule wrapper for dpu_round_ste().
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class brevitas.core.function_wrapper.ops_ste.FloorSte
Bases: Module
ScriptModule wrapper for floor_ste().
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
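A second common way to write an STE, shown here for floor, is the detach trick: x + (floor(x) - x).detach() has the value floor(x) in the forward pass, but only the x term is differentiated. This is a generic PyTorch idiom offered as a sketch; floor_ste_sketch is a hypothetical name, and floor_ste() may be implemented differently.

```python
import torch


def floor_ste_sketch(x: torch.Tensor) -> torch.Tensor:
    # Forward value: floor(x). The detached term contributes no gradient,
    # so d(output)/dx is 1 everywhere.
    return x + (torch.floor(x) - x).detach()


x = torch.tensor([0.75, 2.5, -1.25], requires_grad=True)
y = floor_ste_sketch(x)
y.sum().backward()
print(y.tolist())       # [0.0, 2.0, -2.0]
print(x.grad.tolist())  # [1.0, 1.0, 1.0]
```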
- class brevitas.core.function_wrapper.ops_ste.InplaceTensorClampSte
Bases: Module
ScriptModule wrapper for tensor_clamp_ste_().
- forward(x, min_val, max_val)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class brevitas.core.function_wrapper.ops_ste.RoundSte
Bases: Module
ScriptModule wrapper for round_ste().
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
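When reproducing a round STE, the tie-breaking rule matters: torch.round rounds halves to the nearest even integer. The sketch below combines torch.round with the detach trick; whether round_ste() uses exactly this rounding is an assumption, and round_ste_sketch is a hypothetical name.

```python
import torch


def round_ste_sketch(x: torch.Tensor) -> torch.Tensor:
    # Forward: torch.round (ties go to the nearest even integer). Backward: identity.
    return x + (torch.round(x) - x).detach()


x = torch.tensor([0.5, 1.5, 2.5, 3.5], requires_grad=True)
y = round_ste_sketch(x)
y.sum().backward()
print(y.tolist())       # [0.0, 2.0, 2.0, 4.0]
print(x.grad.tolist())  # [1.0, 1.0, 1.0, 1.0]
```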
- class brevitas.core.function_wrapper.ops_ste.RoundToZeroSte
Bases: Module
ScriptModule wrapper for round_to_zero_ste().
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
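Rounding toward zero is truncation: positive inputs are floored and negative inputs are ceiled, so -1.75 becomes -1.0 rather than the -2.0 that floor would give. Assuming round_to_zero_ste() pairs this rounding with a straight-through gradient, the sketch below reproduces it with torch.trunc; round_to_zero_ste_sketch is a hypothetical name.

```python
import torch


def round_to_zero_ste_sketch(x: torch.Tensor) -> torch.Tensor:
    # torch.trunc drops the fractional part, i.e. rounds toward zero; gradient passes straight through.
    return x + (torch.trunc(x) - x).detach()


x = torch.tensor([1.75, -1.75, 0.5], requires_grad=True)
y = round_to_zero_ste_sketch(x)
y.sum().backward()
print(y.tolist())       # [1.0, -1.0, 0.0]
print(x.grad.tolist())  # [1.0, 1.0, 1.0]
```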
- class brevitas.core.function_wrapper.ops_ste.ScalarClampMinSte(min_val)
Bases: Module
ScriptModule wrapper for scalar_clamp_min_ste().
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class brevitas.core.function_wrapper.ops_ste.TensorClampSte
Bases: Module
ScriptModule wrapper for tensor_clamp_ste().
- forward(x, min_val, max_val)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
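With a straight-through clamp, elements outside [min_val, max_val] still receive a gradient, whereas torch.clamp alone gives clipped elements a zero gradient. The sketch below contrasts the two, assuming tensor_clamp_ste() passes the gradient through unchanged as its STE naming suggests; tensor_clamp_ste_sketch is a hypothetical name.

```python
import torch


def tensor_clamp_ste_sketch(x: torch.Tensor, min_val: torch.Tensor, max_val: torch.Tensor) -> torch.Tensor:
    # Forward: elementwise clamp. Backward: identity, because the correction term is detached.
    clamped = torch.minimum(torch.maximum(x, min_val), max_val)
    return x + (clamped - x).detach()


min_val = torch.tensor(-1.0)
max_val = torch.tensor(1.0)

x_ste = torch.tensor([-3.0, 0.5, 3.0], requires_grad=True)
out = tensor_clamp_ste_sketch(x_ste, min_val, max_val)
out.sum().backward()

x_plain = torch.tensor([-3.0, 0.5, 3.0], requires_grad=True)
torch.clamp(x_plain, -1.0, 1.0).sum().backward()

print(out.tolist())           # [-1.0, 0.5, 1.0]
print(x_ste.grad.tolist())    # [1.0, 1.0, 1.0]
print(x_plain.grad.tolist())  # [0.0, 1.0, 0.0]
```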
brevitas.core.function_wrapper.shape module
ScriptModule classes to compute the view of a tensor according to different criteria.
- class brevitas.core.function_wrapper.shape.DynamicOverSubChannelBlockView(group_size, group_dim)
Bases: Module
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class brevitas.core.function_wrapper.shape.OverBatchOverOutputChannelView(permute_dims=None)
Bases: Module
ScriptModule to compute the over_batch_over_output_channels() view of an input tensor.
Examples
>>> view_module = OverBatchOverOutputChannelView()
>>> y = view_module(torch.empty(size=[8, 10, 5, 5]))
>>> y.shape
torch.Size([8, 10, 25])
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
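Judging from the example above, this view keeps the batch and channel dimensions and flattens the rest. The plain-PyTorch sketch below reproduces that shape with reshape (leaving permute_dims aside) and then reduces over the flattened axis to get one statistic per sample and channel; the abs-max reduction is an illustrative choice, not necessarily what a Brevitas quantizer computes.

```python
import torch

x = torch.randn(8, 10, 5, 5)                  # [batch, channels, H, W]
view = x.reshape(x.shape[0], x.shape[1], -1)  # same shape as the documented output
print(view.shape)  # torch.Size([8, 10, 25])

stats = view.abs().amax(dim=-1)  # one value per (sample, channel) pair
print(stats.shape)  # torch.Size([8, 10])
```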
- class brevitas.core.function_wrapper.shape.OverBatchOverTensorView(permute_dims=None)
Bases: Module
ScriptModule to compute the over_batch_over_tensor() view of an input tensor.
Examples
>>> view_module = OverBatchOverTensorView()
>>> y = view_module(torch.empty(size=[8, 10, 5, 5]))
>>> y.shape
torch.Size([8, 250])
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
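The over-batch-over-tensor view keeps only the batch dimension, so a single reduction yields one statistic per sample. A plain-PyTorch sketch matching the documented shapes, with abs-max as an illustrative statistic:

```python
import torch

x = torch.randn(8, 10, 5, 5)
view = x.reshape(x.shape[0], -1)  # batch dimension kept, everything else flattened
print(view.shape)  # torch.Size([8, 250])

per_sample_absmax = view.abs().amax(dim=1)  # one value per sample
print(per_sample_absmax.shape)  # torch.Size([8])
```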
- class brevitas.core.function_wrapper.shape.OverOutputChannelView(permute_dims=None)
Bases: Module
ScriptModule to compute the over_output_channels() view of an input tensor.
Examples
>>> view_module = OverOutputChannelView(permute_dims=None)
>>> y = view_module(torch.empty(size=[16, 8, 5, 5]))
>>> y.shape
torch.Size([16, 200])
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
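For a convolution weight laid out as [out_channels, in_channels, kH, kW], this view puts each output channel on its own row, so one reduction gives one statistic per output channel. A plain-PyTorch sketch reproducing the documented shapes, assuming permute_dims=None; abs-max is an illustrative statistic.

```python
import torch

weight = torch.randn(16, 8, 5, 5)            # [out_channels, in_channels, kH, kW]
view = weight.reshape(weight.shape[0], -1)   # [16, 200]: one row per output channel
per_channel_absmax = view.abs().amax(dim=1)  # one statistic per output channel
print(view.shape, per_channel_absmax.shape)  # torch.Size([16, 200]) torch.Size([16])
```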
- class brevitas.core.function_wrapper.shape.OverOutputFeaturesView(permute_dims=None)
Bases: Module
ScriptModule to compute the over_output_features() view of an input tensor.
Examples
>>> view_module = OverOutputFeaturesView()
>>> y = view_module(torch.empty(size=[8, 10, 25]))
>>> y.shape
torch.Size([80, 25])
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
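The documented example maps [8, 10, 25] to [80, 25]: the leading dimensions are merged and the last (feature) dimension is kept. For that case, reshape(-1, x.shape[-1]) gives the same shape; the sketch below uses it and then reduces over dim 0 to get one illustrative statistic per output feature. It reproduces the documented shapes and is not the module's source.

```python
import torch

x = torch.randn(8, 10, 25)                   # e.g. [batch, tokens, features]
view = x.reshape(-1, x.shape[-1])            # leading dims merged, features kept
print(view.shape)  # torch.Size([80, 25])

per_feature_absmax = view.abs().amax(dim=0)  # one value per output feature
print(per_feature_absmax.shape)  # torch.Size([25])
```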
- class brevitas.core.function_wrapper.shape.OverSubChannelBlockView(expanded_groupwise_shape, group_size, group_dim)
Bases: Module
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class brevitas.core.function_wrapper.shape.OverTensorView
Bases: Module
ScriptModule to compute the over_tensor() view of an input tensor.
Examples
>>> view_module = OverTensorView()
>>> y = view_module(torch.empty(size=[16, 6, 5, 5]))
>>> y.shape
torch.Size([2400])
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
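The over-tensor view flattens every dimension, so a reduction yields a single per-tensor statistic. A plain-PyTorch sketch matching the documented shapes, with abs-max as an illustrative statistic:

```python
import torch

x = torch.randn(16, 6, 5, 5)
view = x.reshape(-1)  # a single flat vector
print(view.shape)  # torch.Size([2400])

per_tensor_absmax = view.abs().max()  # one statistic for the whole tensor
print(per_tensor_absmax.shape)  # torch.Size([])
```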
- class brevitas.core.function_wrapper.shape.PermuteDims(permute_dims)
Bases: Module
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
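PermuteDims(permute_dims) is listed without a description. Assuming it applies Tensor.permute with the given dimension order (an assumption based on its name and argument), one reason to permute before taking a view is a layout where output channels are not on dim 0: torch.nn.ConvTranspose2d stores its weight as [in_channels, out_channels / groups, kH, kW]. The sketch swaps the first two dimensions so a per-output-channel view can then be taken; the permute_dims value is a hypothetical choice.

```python
import torch
from torch import nn

deconv = nn.ConvTranspose2d(in_channels=4, out_channels=6, kernel_size=3)
w = deconv.weight.detach()
print(w.shape)  # torch.Size([4, 6, 3, 3])

permute_dims = (1, 0, 2, 3)  # hypothetical: bring out_channels to dim 0
w_perm = w.permute(permute_dims)
# reshape (unlike view) also works on the non-contiguous permuted tensor
per_out_channel = w_perm.reshape(w_perm.shape[0], -1)
print(per_out_channel.shape)  # torch.Size([6, 36])
```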
- class brevitas.core.function_wrapper.shape.StatsInputViewShapeImpl
Bases: object
Enum-like object that collects references to the ScriptModule variants that compute a view of a tensor. All variants adhere to the same interface.
- DYNAMIC_OVER_SUBCHANNEL_BLOCK
alias of DynamicOverSubChannelBlockView
- OVER_BATCH_OVER_OUTPUT_CHANNELS
alias of OverBatchOverOutputChannelView
- OVER_BATCH_OVER_TENSOR
alias of OverBatchOverTensorView
- OVER_OUTPUT_CHANNELS
alias of OverOutputChannelView
- OVER_OUTPUT_FEATURES
alias of OverOutputFeaturesView
- OVER_SUBCHANNEL_BLOCK
alias of OverSubChannelBlockView
- OVER_TENSOR
alias of OverTensorView