brevitas.core.bit_width package
Submodules
brevitas.core.bit_width.const module
- class brevitas.core.bit_width.const.BitWidthConst(bit_width, dtype=None, device=None)
Bases: Module
ScriptModule that returns a constant bit-width wrapped in a float torch.Tensor.
- Parameters:
bit_width (int) – bit-width value.
Examples
>>> bit_width = BitWidthConst(8)
>>> bit_width()
tensor(8.)
Note
The bit-width is not part of the Module’s state, meaning that it won’t be saved as part of a checkpoint.
Note
Maps to bit_width_impl_type == BitWidthImplType.CONST == ‘CONST’ == ‘const’ in higher-level APIs.
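The checkpoint behavior described in the notes above can be sketched in plain PyTorch. `ConstBitWidthSketch` is a hypothetical stand-in, not the Brevitas class, and the non-persistent buffer is only an assumed way to get the same effect: the module returns a fixed float bit-width while its `state_dict()` stays empty.

```python
import torch
from torch import nn


class ConstBitWidthSketch(nn.Module):
    """Hypothetical stand-in for BitWidthConst (not Brevitas code)."""

    def __init__(self, bit_width: int):
        super().__init__()
        # persistent=False keeps the buffer out of state_dict(), so the
        # bit-width is neither written to nor expected from a checkpoint.
        self.register_buffer(
            "bit_width", torch.tensor(float(bit_width)), persistent=False
        )

    def forward(self) -> torch.Tensor:
        # Always the value passed to the constructor, as a float tensor.
        return self.bit_width


bw = ConstBitWidthSketch(8)
print(bw())               # tensor(8.)
print(len(bw.state_dict()))  # 0: nothing to checkpoint
```

Because nothing is saved, loading a checkpoint into this sketch cannot change its bit-width; it is always the constructor argument.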
- forward()
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tensor
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
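The difference between calling the module instance and calling `forward()` directly is easy to observe in plain PyTorch with a forward hook. `FixedBitWidth` below is a hypothetical zero-argument module shaped like the bit-width modules in this package, not a Brevitas class; the hook fires only when the instance itself is called.

```python
import torch
from torch import nn

calls = []


class FixedBitWidth(nn.Module):
    """Hypothetical zero-argument module returning a float bit-width."""

    def forward(self) -> torch.Tensor:
        return torch.tensor(8.0)


m = FixedBitWidth()
# The hook receives (module, inputs, output); returning None leaves output unchanged.
m.register_forward_hook(lambda module, inputs, output: calls.append("hook"))

m()          # __call__ runs forward() and then the registered hook
m.forward()  # calling forward() directly skips the hook
print(calls)  # ['hook']
```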
- class brevitas.core.bit_width.const.BitWidthStatefulConst(bit_width, dtype=None, device=None)
Bases: Module
ScriptModule that returns a constant bit-width wrapped in a float torch.Tensor, but retains the bit-width as part of the module state.
- Parameters:
bit_width (int) – bit-width value.
Examples
>>> bit_width = BitWidthStatefulConst(8)
>>> bit_width()
tensor(8.)
Note
BitWidthStatefulConst is the counterpart of BitWidthConst that retains the bit-width as part of the Module’s state, so the bit-width is saved as part of a checkpoint.
Note
Maps to bit_width_impl_type == BitWidthImplType.STATEFUL_CONST == ‘STATEFUL_CONST’ == ‘stateful_const’ in higher-level APIs.
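A plain-PyTorch sketch of the stateful variant, assuming a regular (persistent) buffer as the storage: the bit-width appears in `state_dict()`, so a checkpoint carries it and `load_state_dict` restores it. `StatefulConstBitWidthSketch` is hypothetical, not the Brevitas class, and the buffer name `bit_width` is an illustrative choice.

```python
import torch
from torch import nn


class StatefulConstBitWidthSketch(nn.Module):
    """Hypothetical stand-in for BitWidthStatefulConst (not Brevitas code)."""

    def __init__(self, bit_width: int):
        super().__init__()
        # A persistent buffer is part of state_dict(), so it is saved in
        # checkpoints; unlike a Parameter it receives no gradients.
        self.register_buffer("bit_width", torch.tensor(float(bit_width)))

    def forward(self) -> torch.Tensor:
        return self.bit_width


src = StatefulConstBitWidthSketch(4)
dst = StatefulConstBitWidthSketch(8)
dst.load_state_dict(src.state_dict())  # the saved bit-width overwrites 8
print(dst())  # tensor(4.)
```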
- forward()
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tensor
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class brevitas.core.bit_width.const.MsbClampBitWidth(bit_width_to_remove_impl, min_overall_bit_width, max_overall_bit_width)
Bases: Module
- forward(input_bit_width)
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tensor
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
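MsbClampBitWidth carries no description here. Going only by its constructor arguments, one plausible reading is: subtract the bit-width produced by `bit_width_to_remove_impl` (the most significant bits to drop) from `input_bit_width`, then clamp the result to [`min_overall_bit_width`, `max_overall_bit_width`]. The function below is a hypothetical sketch of that assumption, not the Brevitas implementation; it uses plain `torch.clamp`, which passes no gradient for values outside the range.

```python
import torch


def msb_clamp_bit_width_sketch(
    input_bit_width: torch.Tensor,
    bit_width_to_remove: torch.Tensor,
    min_overall_bit_width: float,
    max_overall_bit_width: float,
) -> torch.Tensor:
    # Assumed semantics: drop `bit_width_to_remove` most significant bits...
    output_bit_width = input_bit_width - bit_width_to_remove
    # ...then keep the result inside the allowed overall range.
    return torch.clamp(output_bit_width, min_overall_bit_width, max_overall_bit_width)


# 16-bit input, remove 3 MSBs -> 13, clamped to the maximum of 8
print(msb_clamp_bit_width_sketch(torch.tensor(16.0), torch.tensor(3.0), 2.0, 8.0))  # tensor(8.)
```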
brevitas.core.bit_width.parameter module
- class brevitas.core.bit_width.parameter.BitWidthParameter(bit_width, min_bit_width=2, restrict_bit_width_impl=IntRestrictValue( (float_to_int_impl): RoundSte() ), override_pretrained_bit_width=False, dtype=None, device=None)
Bases: Module
ScriptModule that returns a learnable bit-width wrapped in a float torch.Tensor.
- Parameters:
bit_width (int) – value to initialize the output learned bit-width.
min_bit_width (int) – lower bound for the output learned bit-width. Default: 2.
restrict_bit_width_impl (Module) – restrict the learned bit-width to a subset of values. Default: IntRestrictValue(RoundSte()).
override_pretrained_bit_width (bool) – ignore pretrained bit-width loaded from a state dict. Default: False.
- Returns:
bit-width wrapped in a float torch.Tensor and backed by a learnable torch.nn.Parameter.
- Return type:
Tensor
- Raises:
RuntimeError – if bit_width < min_bit_width.
Examples
>>> bit_width_parameter = BitWidthParameter(8)
>>> bit_width_parameter()
tensor(8., grad_fn=<RoundSteFnBackward>)
Note
Set the environment variable BREVITAS_IGNORE_MISSING_KEYS=1 to avoid missing-key errors when retraining from a floating-point state dict.
Note
Maps to bit_width_impl_type == BitWidthImplType.PARAMETER == ‘PARAMETER’ == ‘parameter’ in higher-level APIs.
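A learnable bit-width can be sketched in plain PyTorch under stated assumptions: the hypothetical `LearnedBitWidthSketch` stores a learnable offset above `min_bit_width` (an assumed parameterization, not necessarily Brevitas's), and a locally defined round-with-straight-through-estimator function, `RoundSteFn`, stands in for `IntRestrictValue(RoundSte())`. The sketch reproduces the documented behavior: a `RuntimeError` when `bit_width < min_bit_width`, an integer-valued float output, and a gradient that reaches the underlying `nn.Parameter` despite the rounding.

```python
import torch
from torch import nn


class RoundSteFn(torch.autograd.Function):
    """Round in the forward pass; pass the gradient through unchanged (STE)."""

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class LearnedBitWidthSketch(nn.Module):
    """Hypothetical stand-in for BitWidthParameter (not Brevitas code)."""

    def __init__(self, bit_width: int, min_bit_width: int = 2):
        super().__init__()
        if bit_width < min_bit_width:
            raise RuntimeError(f"bit_width={bit_width} < min_bit_width={min_bit_width}")
        self.min_bit_width = float(min_bit_width)
        # Assumption: learn the offset above the lower bound; abs() keeps the
        # output >= min_bit_width whatever sign the optimizer gives the offset.
        self.bit_width_offset = nn.Parameter(torch.tensor(float(bit_width - min_bit_width)))

    def forward(self) -> torch.Tensor:
        bit_width = torch.abs(self.bit_width_offset) + self.min_bit_width
        return RoundSteFn.apply(bit_width)


bw = LearnedBitWidthSketch(8)
out = bw()
out.backward()
print(bw.bit_width_offset.grad)  # tensor(1.)
```

Without the straight-through backward, `torch.round` would contribute a zero gradient and the bit-width could not be learned.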
- forward()
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tensor
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class brevitas.core.bit_width.parameter.RemoveBitwidthParameter(bit_width_to_remove, override_pretrained_bit_width=False, non_zero_epsilon=1e-06, remove_zero_bit_width=0.1, dtype=None, device=None)
Bases: Module
- forward()
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tensor
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
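RemoveBitwidthParameter carries no description here. Its arguments suggest a learnable amount of bit-width to remove, with `non_zero_epsilon` guarding a division and `remove_zero_bit_width` substituting for an initial value of 0. The hypothetical `RemoveBitWidthSketch` below illustrates one parameterization consistent with those names, learning the reciprocal of the amount to remove. This is an assumption, not the Brevitas implementation.

```python
import torch
from torch import nn


class RemoveBitWidthSketch(nn.Module):
    """Hypothetical sketch, not Brevitas code: a learnable bit-width to remove."""

    def __init__(self, bit_width_to_remove: float, non_zero_epsilon: float = 1e-6,
                 remove_zero_bit_width: float = 0.1):
        super().__init__()
        if bit_width_to_remove < 0:
            raise RuntimeError("bit_width_to_remove must be non-negative")
        # Assumption: learn the reciprocal of the amount to remove; an initial
        # amount of 0 is replaced by a small positive value to avoid 1 / 0.
        init = remove_zero_bit_width if bit_width_to_remove == 0 else bit_width_to_remove
        self.coeff = nn.Parameter(torch.tensor(1.0 / init))
        self.non_zero_epsilon = non_zero_epsilon

    def forward(self) -> torch.Tensor:
        # epsilon keeps the denominator strictly positive
        return 1.0 / (self.non_zero_epsilon + torch.abs(self.coeff))


print(RemoveBitWidthSketch(2.0)())  # approximately 2
print(RemoveBitWidthSketch(0.0)())  # approximately remove_zero_bit_width (0.1)
```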