brevitas.core.function_wrapper package

Submodules

brevitas.core.function_wrapper.clamp module

ScriptModule wrappers for several variants of clamping.

class brevitas.core.function_wrapper.clamp.ClampMin(min_val)

Bases: Module

ScriptModule wrapper for clamp_min().

Examples

>>> clamp_min = ClampMin(min_val=-2.0)
>>> clamp_min(torch.tensor(-3.0))
tensor(-2.)

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class brevitas.core.function_wrapper.clamp.FloatClamp(tensor_clamp_impl, signed, inf_values=None, nan_values=None, max_available_float=None, saturating=True, device=None, dtype=None)

Bases: Module

ScriptModule for clamping minifloat formats according to their inf/NaN implementations.

Currently, inf/NaN codes have to be encoded through the mantissa, with the exponent field set to all ones. For example, setting inf to 1101.111 (E4M3) is not a valid code.

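The arithmetic behind this restriction can be sketched in plain Python. The helper names `max_minifloat_value` and `saturating_clamp_values` are hypothetical and not part of the brevitas API. The sketch assumes that reserved inf/NaN codes are given as mantissa bit strings at the all-ones exponent field, so the largest finite value uses the largest unreserved mantissa at that exponent or, if every mantissa code there is reserved, the all-ones mantissa one exponent lower.

```python
def max_minifloat_value(exponent_bit_width, mantissa_bit_width, exponent_bias,
                        reserved_mantissa_codes=()):
    """Largest finite magnitude of a minifloat format (hypothetical helper).

    reserved_mantissa_codes holds the mantissa bit strings that encode inf/NaN
    at the all-ones exponent field; they cannot be used for finite values.
    """
    max_exponent_field = 2 ** exponent_bit_width - 1
    # Try mantissa codes at the top exponent, from largest to smallest.
    for code in range(2 ** mantissa_bit_width - 1, -1, -1):
        bits = format(code, f"0{mantissa_bit_width}b")
        if bits not in reserved_mantissa_codes:
            mantissa = 1 + code / 2 ** mantissa_bit_width
            return mantissa * 2.0 ** (max_exponent_field - exponent_bias)
    # Every code at the top exponent is reserved: fall back to the next
    # exponent down with an all-ones mantissa.
    mantissa = 2 - 2.0 ** -mantissa_bit_width
    return mantissa * 2.0 ** (max_exponent_field - 1 - exponent_bias)


def saturating_clamp_values(values, max_value):
    """Saturate out-of-range magnitudes to +/- max_value (signed format)."""
    return [min(max(v, -max_value), max_value) for v in values]


# OCP-style E4M3: NaN at mantissa 111, no inf -> 1.75 * 2**8
e4m3_max = max_minifloat_value(4, 3, 7, reserved_mantissa_codes=("111",))
# OCP-style E5M2: inf at 00, NaN at 01/10/11 -> 1.75 * 2**15
e5m2_max = max_minifloat_value(5, 2, 15, reserved_mantissa_codes=("00", "01", "10", "11"))
print(e4m3_max, e5m2_max)                                        # 448.0 57344.0
print(saturating_clamp_values([500.0, -1000.0, 3.0], e4m3_max))  # [448.0, -448.0, 3.0]
```

A code such as 1101.111 never enters this search, because only mantissa codes at the all-ones exponent are considered reserved.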
forward(x, exponent_bit_width, mantissa_bit_width, exponent_bias)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

inf_nan_clamp(x, inf_mask, p_max_val_mask, n_max_val_mask)

saturating_clamp(x, max_value, min_value)

class brevitas.core.function_wrapper.clamp.ScalarClamp(min_val, max_val)

Bases: Module

ScriptModule wrapper for clamp().

Examples

>>> scalar_clamp = ScalarClamp(min_val=-2.0, max_val=2.0)
>>> scalar_clamp(torch.tensor([-3.0, 3.0]))
tensor([-2.,  2.])

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class brevitas.core.function_wrapper.clamp.TensorClamp

Bases: Module

ScriptModule wrapper for tensor_clamp().

Examples

>>> tensor_clamp = TensorClamp()
>>> min_val = torch.tensor(-2.0)
>>> max_val = torch.tensor(2.0)
>>> tensor_clamp(torch.tensor([-3.0, 3.0]), min_val, max_val)
tensor([-2.,  2.])

forward(x, min_val, max_val)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

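Unlike ScalarClamp, whose bounds are fixed at construction, TensorClamp receives its bounds as tensors at call time, which makes it possible, for example, to use a different range per channel as long as the bounds broadcast against the input. A minimal plain-PyTorch sketch of that behaviour, assuming simple broadcasting semantics (the function name `tensor_clamp_sketch` is hypothetical):

```python
import torch


def tensor_clamp_sketch(x: torch.Tensor, min_val: torch.Tensor, max_val: torch.Tensor) -> torch.Tensor:
    # Bounds are tensors, so they broadcast against x (here: one bound per row).
    return torch.minimum(torch.maximum(x, min_val), max_val)


x = torch.tensor([[-3.0, 0.5, 3.0],
                  [-3.0, 0.5, 3.0]])
# One (min, max) pair per row, shaped [2, 1] so it broadcasts across columns.
min_val = torch.tensor([[-2.0], [-1.0]])
max_val = torch.tensor([[2.0], [0.25]])

y = tensor_clamp_sketch(x, min_val, max_val)
print(y.tolist())  # [[-2.0, 0.5, 2.0], [-1.0, 0.25, 0.25]]
```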
brevitas.core.function_wrapper.misc module

A collection of miscellaneous ScriptModules used in various quantizers.

class brevitas.core.function_wrapper.misc.Identity

Bases: Module

Identity ScriptModule.

Examples

>>> identity = Identity()
>>> x = torch.randn(size=[10,])
>>> y = identity(x)
>>> y is x
True

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type: Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class brevitas.core.function_wrapper.misc.InplaceLogTwo

Bases: Module

Module wrapper for log2_().

Examples

>>> inplace_log_two = InplaceLogTwo()
>>> x = torch.tensor(8.0)
>>> _ = inplace_log_two(x)
>>> x
tensor(3.)

Notes

In-place operations can be problematic in TorchScript, so compilation is disabled for this module.

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type: Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

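The difference between InplaceLogTwo and LogTwo mirrors the difference between `Tensor.log2_()` and `torch.log2()` in plain PyTorch. This sketch uses only those two PyTorch calls, not the brevitas modules:

```python
import torch

x = torch.tensor([8.0, 4.0])

# Out-of-place (LogTwo-style): returns a new tensor and leaves x untouched.
y = torch.log2(x)
print(x, y)

# In-place (InplaceLogTwo-style): overwrites x and returns x itself.
z = x.log2_()
print(x, z is x)  # z is x -> True
```

Because the in-place version mutates its argument, any other reference to the original tensor also sees the log2 values afterwards.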
class brevitas.core.function_wrapper.misc.LogTwo

Bases: Module

ScriptModule wrapper for log2().

Examples

>>> log_two = LogTwo()
>>> x = torch.tensor(8.0)
>>> log_two(x)
tensor(3.)

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type: Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class brevitas.core.function_wrapper.misc.PowerOfTwo

Bases: Module

ScriptModule implementation of 2.0 ** x.

Examples

>>> power_of_two = PowerOfTwo()
>>> x = torch.tensor(5.0)
>>> power_of_two(x)
tensor(32.)

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type: Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

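LogTwo and PowerOfTwo compute inverse functions. One place where such a pair is useful is restricting a scale factor to a power of two: take log2, round the exponent to an integer, then raise 2 to it. The plain-PyTorch sketch below illustrates that idea only; it is not a description of how brevitas quantizers wire these modules together, and the helper name `restrict_to_power_of_two` is hypothetical.

```python
import torch


def restrict_to_power_of_two(scale: torch.Tensor) -> torch.Tensor:
    log2_scale = torch.log2(scale)      # what LogTwo computes
    exponent = torch.round(log2_scale)  # integer exponent (torch.round: ties to even)
    return 2.0 ** exponent              # what PowerOfTwo computes


scale = torch.tensor([0.3, 5.0, 8.0])
print(restrict_to_power_of_two(scale))  # approximately [0.25, 4.0, 8.0]
```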
brevitas.core.function_wrapper.ops_ste module

ScriptModule wrappers of various functions defined in ops_ste.

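The functions wrapped here apply a rounding or clamping operation in the forward pass and use a straight-through estimator (STE) for the gradient, so the backward pass treats the operation as the identity. A common way to get this behaviour in plain PyTorch is the detach trick below. It is an illustrative sketch with a hypothetical function name; brevitas' own round_ste() may be implemented differently, for example with a custom autograd function.

```python
import torch


def round_ste_sketch(x: torch.Tensor) -> torch.Tensor:
    # Forward: x + (round(x) - x) == round(x).
    # Backward: the detached term contributes no gradient, so d(out)/dx == 1.
    return x + (torch.round(x) - x).detach()


x = torch.tensor([0.2, 1.7, -2.6], requires_grad=True)
y = round_ste_sketch(x)
y.sum().backward()

print(y.detach())  # the rounded values 0., 2., -3.
print(x.grad)      # tensor([1., 1., 1.]): the gradient passes straight through
```

Without the STE, torch.round alone would give a zero gradient almost everywhere, which is why quantization-aware training relies on this kind of estimator.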
class brevitas.core.function_wrapper.ops_ste.CeilSte

Bases: Module

ScriptModule wrapper for ceil_ste().

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class brevitas.core.function_wrapper.ops_ste.DPURoundSte

Bases: Module

ScriptModule wrapper for dpu_round_ste().

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class brevitas.core.function_wrapper.ops_ste.FloorSte

Bases: Module

ScriptModule wrapper for floor_ste().

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class brevitas.core.function_wrapper.ops_ste.InplaceTensorClampSte

Bases: Module

ScriptModule wrapper for tensor_clamp_ste_().

forward(x, min_val, max_val)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class brevitas.core.function_wrapper.ops_ste.RoundSte

Bases: Module

ScriptModule wrapper for round_ste().

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class brevitas.core.function_wrapper.ops_ste.RoundToZeroSte

Bases: Module

ScriptModule wrapper for round_to_zero_ste().

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

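Assuming the forward passes of FloorSte, CeilSte, RoundSte and RoundToZeroSte match torch.floor, torch.ceil, torch.round and torch.trunc respectively, the four rounding modes differ on negative values and on ties as shown below. Note that torch.round rounds halves to the nearest even integer. Only forward values are computed here; the STE wrappers differ from these plain functions in the backward pass.

```python
import torch

x = torch.tensor([-1.5, -0.5, 0.5, 1.5, 2.7])

# Forward-pass values of the plain PyTorch rounding functions.
forward_values = {
    "FloorSte (floor)": torch.floor(x),
    "CeilSte (ceil)": torch.ceil(x),
    "RoundSte (round, ties to even)": torch.round(x),
    "RoundToZeroSte (trunc)": torch.trunc(x),
}
for name, values in forward_values.items():
    print(f"{name:32s} {values.tolist()}")
```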
class brevitas.core.function_wrapper.ops_ste.ScalarClampMinSte(min_val)

Bases: Module

ScriptModule wrapper for scalar_clamp_min_ste().

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class brevitas.core.function_wrapper.ops_ste.TensorClampSte

Bases: Module

ScriptModule wrapper for tensor_clamp_ste().

forward(x, min_val, max_val)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

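For the clamping variants such as TensorClampSte, the straight-through gradient matters most where values are actually clamped: a regular clamp passes zero gradient for out-of-range inputs, whereas a clamp with an identity backward passes the incoming gradient unchanged. The plain-PyTorch sketch below contrasts the two, assuming an identity backward for the STE version; `tensor_clamp_ste_sketch` is a hypothetical name built with the detach trick.

```python
import torch


def tensor_clamp_ste_sketch(x, min_val, max_val):
    # Forward: clamp to [min_val, max_val]. Backward: identity, because the
    # correction term (clamped - x) is detached from the autograd graph.
    clamped = torch.minimum(torch.maximum(x, min_val), max_val)
    return x + (clamped - x).detach()


min_val, max_val = torch.tensor(-1.0), torch.tensor(1.0)

x_plain = torch.tensor([-3.0, 0.5, 3.0], requires_grad=True)
torch.clamp(x_plain, -1.0, 1.0).sum().backward()

x_ste = torch.tensor([-3.0, 0.5, 3.0], requires_grad=True)
out_ste = tensor_clamp_ste_sketch(x_ste, min_val, max_val)
out_ste.sum().backward()

print(x_plain.grad)  # tensor([0., 1., 0.]): no gradient where values were clamped
print(x_ste.grad)    # tensor([1., 1., 1.]): gradient passed straight through
```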
brevitas.core.function_wrapper.shape module

ScriptModule classes that compute a view of a tensor according to different criteria.

class brevitas.core.function_wrapper.shape.DynamicOverSubChannelBlockView(group_size, group_dim)

Bases: Module

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

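Neither DynamicOverSubChannelBlockView nor OverSubChannelBlockView carries a docstring; judging by their group_size and group_dim parameters, they compute group-wise (sub-channel block) views. The following is a hypothetical sketch of such a view, not the brevitas implementation: the function `subchannel_block_view` splits dimension `group_dim` into blocks of `group_size` elements so that a statistic can be computed per block.

```python
import torch


def subchannel_block_view(x: torch.Tensor, group_size: int, group_dim: int) -> torch.Tensor:
    """Hypothetical sketch: split dimension group_dim into blocks of group_size."""
    group_dim = group_dim % x.dim()            # allow negative dims
    shape = list(x.shape)
    if shape[group_dim] % group_size != 0:
        raise ValueError("size of group_dim must be divisible by group_size")
    shape[group_dim] //= group_size            # number of blocks
    shape.insert(group_dim + 1, group_size)    # elements per block
    return x.view(shape)


weight = torch.randn(16, 64)                   # [output channels, input channels]
blocks = subchannel_block_view(weight, group_size=32, group_dim=1)
absmax_per_block = blocks.abs().amax(dim=-1)   # one statistic per block
print(blocks.shape, absmax_per_block.shape)    # torch.Size([16, 2, 32]) torch.Size([16, 2])
```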
class brevitas.core.function_wrapper.shape.OverBatchOverOutputChannelView(permute_dims=None)

Bases: Module

ScriptModule to compute the over_batch_over_output_channels() view of an input tensor.

Examples

>>> view_module = OverBatchOverOutputChannelView()
>>> y = view_module(torch.empty(size=[8, 10, 5, 5]))
>>> y.shape
torch.Size([8, 10, 25])

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class brevitas.core.function_wrapper.shape.OverBatchOverTensorView(permute_dims=None)

Bases: Module

ScriptModule to compute the over_batch_over_tensor() view of an input tensor.

Examples

>>> view_module = OverBatchOverTensorView()
>>> y = view_module(torch.empty(size=[8, 10, 5, 5]))
>>> y.shape
torch.Size([8, 250])

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class brevitas.core.function_wrapper.shape.OverOutputChannelView(permute_dims=None)

Bases: Module

ScriptModule to compute the over_output_channels() view of an input tensor.

Examples

>>> view_module = OverOutputChannelView(permute_dims=None)
>>> y = view_module(torch.empty(size=[16, 8, 5, 5]))
>>> y.shape
torch.Size([16, 200])

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

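The name StatsInputViewShapeImpl suggests that these views shape the input of a statistic such as an absolute maximum. Assuming that use, the plain-PyTorch sketch below shows how an over-output-channels view yields one statistic per channel while an over-tensor view yields a single one. It reshapes the tensor directly instead of calling the brevitas modules.

```python
import torch

# A tiny conv-style weight of shape [2, 1, 2, 2]: 2 output channels.
weight = torch.tensor([[[[1.0, -5.0], [2.0, 3.0]]],
                       [[[0.5, -0.25], [4.0, -7.0]]]])

# Over output channels: keep dim 0, flatten the rest -> shape [2, 4].
per_channel = weight.view(weight.shape[0], -1)
per_channel_absmax = per_channel.abs().amax(dim=1)  # one value per channel

# Over tensor: flatten everything -> shape [8].
per_tensor = weight.view(-1)
per_tensor_absmax = per_tensor.abs().amax()         # a single value

print(per_channel_absmax.tolist(), per_tensor_absmax.item())  # [5.0, 7.0] 7.0
```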
class brevitas.core.function_wrapper.shape.OverOutputFeaturesView(permute_dims=None)

Bases: Module

ScriptModule to compute the over_output_features() view of an input tensor.

Examples

>>> view_module = OverOutputFeaturesView()
>>> y = view_module(torch.empty(size=[8, 10, 25]))
>>> y.shape
torch.Size([80, 25])

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class brevitas.core.function_wrapper.shape.OverSubChannelBlockView(expanded_groupwise_shape, group_size, group_dim)

Bases: Module

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class brevitas.core.function_wrapper.shape.OverTensorView

Bases: Module

ScriptModule to compute the over_tensor() view of an input tensor.

Examples

>>> view_module = OverTensorView()
>>> y = view_module(torch.empty(size=[16, 6, 5, 5]))
>>> y.shape
torch.Size([2400])

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class brevitas.core.function_wrapper.shape.PermuteDims(permute_dims)

Bases: Module

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

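PermuteDims and the view classes that accept permute_dims presumably exist for tensors whose output channels are not on dimension 0. One such case in PyTorch is a transposed-convolution weight, laid out as [in_channels, out_channels // groups, kH, kW]. The plain-PyTorch sketch below permutes such a weight before taking a per-output-channel view; the permutation order (1, 0, 2, 3) is an assumption chosen for this layout, not a value taken from brevitas.

```python
import torch

# ConvTranspose2d-style weight: [in_channels=8, out_channels=16, kH=3, kW=3]
weight = torch.randn(8, 16, 3, 3)

permuted = weight.permute(1, 0, 2, 3)          # output channels moved to dim 0
try:
    permuted.view(permuted.shape[0], -1)       # fails: permute returns a non-contiguous tensor
except RuntimeError:
    print("view() needs a compatible memory layout; use reshape() or contiguous()")

per_channel = permuted.reshape(permuted.shape[0], -1)
print(per_channel.shape)                       # torch.Size([16, 72])
```

Using reshape() (or contiguous() followed by view()) is needed here because the permuted tensor's strides cannot be merged into a single flattened dimension.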
class brevitas.core.function_wrapper.shape.StatsInputViewShapeImpl

Bases: object

Enum-like class that collects references to the ScriptModule variants that compute a view of a tensor. All variants adhere to the same interface.

DYNAMIC_OVER_SUBCHANNEL_BLOCK

alias of DynamicOverSubChannelBlockView

OVER_BATCH_OVER_OUTPUT_CHANNELS

alias of OverBatchOverOutputChannelView

OVER_BATCH_OVER_TENSOR

alias of OverBatchOverTensorView

OVER_OUTPUT_CHANNELS

alias of OverOutputChannelView

OVER_OUTPUT_FEATURES

alias of OverOutputFeaturesView

OVER_SUBCHANNEL_BLOCK

alias of OverSubChannelBlockView

OVER_TENSOR

alias of OverTensorView

Module contents