brevitas.core.bit_width package

Submodules

brevitas.core.bit_width.const module

class brevitas.core.bit_width.const.BitWidthConst(bit_width, dtype=None, device=None)

Bases: Module

ScriptModule that returns a constant bit-width wrapped in a float torch.tensor.

Parameters:

bit_width (int) – bit-width value.

Examples

>>> bit_width = BitWidthConst(8)
>>> bit_width()
tensor(8.)

Note

The bit-width is not part of the Module’s state, meaning that it won’t be saved as part of a checkpoint.

Note

Maps to bit_width_impl_type == BitWidthImplType.CONST == ‘CONST’ == ‘const’ in higher-level APIs.

forward()

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the forward pass is defined in this method, call the Module instance rather than forward() directly: calling the instance runs the registered hooks, whereas calling forward() silently skips them.
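The difference between calling the instance and calling forward() can be seen with a forward hook. The following is a minimal sketch in plain PyTorch; ConstantModule and the calls list are hypothetical names used only for illustration and are not part of Brevitas.

```python
import torch
from torch import nn


class ConstantModule(nn.Module):
    """Hypothetical stand-in module that returns a constant float tensor."""

    def forward(self):
        return torch.tensor(8.0)


calls = []  # records every time the forward hook fires
module = ConstantModule()
module.register_forward_hook(lambda mod, args, output: calls.append(output.item()))

module()  # instance call: goes through __call__, so the hook runs
hook_runs_after_instance_call = len(calls)

module.forward()  # direct call: bypasses __call__, so the hook is skipped
hook_runs_after_direct_forward = len(calls)

print(hook_runs_after_instance_call, hook_runs_after_direct_forward)
```

The hook counter only increases on the instance call, which is why the instance call is the recommended entry point.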

class brevitas.core.bit_width.const.BitWidthStatefulConst(bit_width, dtype=None, device=None)

Bases: Module

ScriptModule that returns a constant bit-width wrapped in a float torch.tensor but retains the bit-width as part of the module state.

Parameters:

bit_width (int) – bit-width value.

Examples

>>> bit_width = BitWidthStatefulConst(8)
>>> bit_width()
tensor(8.)

Note

BitWidthStatefulConst is the counterpart to BitWidthConst that retains the bit-width as part of the Module’s state, so the bit-width is saved as part of a checkpoint.

Note

Maps to bit_width_impl_type == BitWidthImplType.STATEFUL_CONST == ‘STATEFUL_CONST’ == ‘stateful_const’ in higher-level APIs.

forward()

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the forward pass is defined in this method, call the Module instance rather than forward() directly: calling the instance runs the registered hooks, whereas calling forward() silently skips them.

class brevitas.core.bit_width.const.MsbClampBitWidth(bit_width_to_remove_impl, min_overall_bit_width, max_overall_bit_width)

Bases: Module

forward(input_bit_width)

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the forward pass is defined in this method, call the Module instance rather than forward() directly: calling the instance runs the registered hooks, whereas calling forward() silently skips them.

brevitas.core.bit_width.parameter module

class brevitas.core.bit_width.parameter.BitWidthParameter(bit_width, min_bit_width=2, restrict_bit_width_impl=IntRestrictValue((float_to_int_impl): RoundSte()), override_pretrained_bit_width=False, dtype=None, device=None)

Bases: Module

ScriptModule that returns a learnable bit-width wrapped in a float torch.Tensor.

Parameters:
  • bit_width (int) – value to initialize the output learned bit-width.

  • min_bit_width (int) – lower bound for the output learned bit-width. Default: 2.

  • restrict_bit_width_impl (Module) – restrict the learned bit-width to a subset of values. Default: IntRestrictValue(RoundSte()).

  • override_pretrained_bit_width (bool) – ignore pretrained bit-width loaded from a state dict. Default: False.

Returns:

bit-width wrapped in a float torch.tensor and backed by a learnable torch.nn.Parameter.

Return type:

Tensor

Raises:

RuntimeError – if bit_width < min_bit_width.

Examples

>>> bit_width_parameter = BitWidthParameter(8)
>>> bit_width_parameter()
tensor(8., grad_fn=<RoundSteFnBackward>)

Note

Set the environment variable BREVITAS_IGNORE_MISSING_KEYS=1 to avoid errors when retraining from a floating-point state dict.

Note

Maps to bit_width_impl_type == BitWidthImplType.PARAMETER == ‘PARAMETER’ == ‘parameter’ in higher-level APIs.

forward()

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the forward pass is defined in this method, call the Module instance rather than forward() directly: calling the instance runs the registered hooks, whereas calling forward() silently skips them.

class brevitas.core.bit_width.parameter.RemoveBitwidthParameter(bit_width_to_remove, override_pretrained_bit_width=False, non_zero_epsilon=1e-06, remove_zero_bit_width=0.1, dtype=None, device=None)

Bases: Module

forward()

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the forward pass is defined in this method, call the Module instance rather than forward() directly: calling the instance runs the registered hooks, whereas calling forward() silently skips them.

Module contents