brevitas.core.scaling package#

Submodules#

brevitas.core.scaling.int_scaling module#

class brevitas.core.scaling.int_scaling.IntScaling(narrow_range, signed=None)[source]#

Bases: Module

forward(bit_width, signed=None)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
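
The snippet below is an illustrative sketch rather than part of the original docstring: it assumes that bit_width is passed as a float tensor, as in the other examples on this page, and shows how the returned integer range could be used to turn a threshold into a scale factor.

>>> import torch
>>> from brevitas.core.scaling.int_scaling import IntScaling
>>> int_scaling_impl = IntScaling(narrow_range=False, signed=True)
>>> int_range = int_scaling_impl(torch.tensor(8.))  # integer range of an 8-bit signed quantizer
>>> scale = torch.tensor(1.0) / int_range  # scale derived from a unit threshold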

class brevitas.core.scaling.int_scaling.PowerOfTwoIntScaling(signed=None)[source]#

Bases: Module

forward(bit_width, signed=None)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
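
As above, a hedged sketch that is not taken from the library docstring: PowerOfTwoIntScaling returns a power-of-two integer range, so that dividing a power-of-two threshold by it keeps the resulting scale a power of two.

>>> import torch
>>> from brevitas.core.scaling.int_scaling import PowerOfTwoIntScaling
>>> int_scaling_impl = PowerOfTwoIntScaling(signed=True)
>>> int_range = int_scaling_impl(torch.tensor(8.))  # power-of-two integer range for 8 bits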

brevitas.core.scaling.runtime module#

class brevitas.core.scaling.runtime.RuntimeDynamicGroupStatsScaling(group_size, group_dim, input_view_impl, scaling_stats_impl, scaling_min_val, restrict_scaling_impl=FloatRestrictValue(), restrict_threshold_impl=None, restrict_scale_threshold_impl=None, is_scale_unsigned=True)[source]#

Bases: Module

forward(stats_input, threshold=None)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
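
No example ships with this class, so the following is a hedged sketch of how it might be wired up. The Unflatten view and the AbsMax reduction over the group dimension are stand-ins for whatever input_view_impl and scaling_stats_impl the surrounding quantizer configuration would normally inject; both are assumptions, not library defaults.

>>> import torch
>>> from brevitas.core.scaling.runtime import RuntimeDynamicGroupStatsScaling
>>> from brevitas.core.stats import AbsMax
>>> # Stand-in view: reshape a (2, 8) input into (2, 2, 4), i.e. 2 groups of 4 per row.
>>> group_view = torch.nn.Unflatten(1, (2, 4))
>>> scaling_impl = RuntimeDynamicGroupStatsScaling(
...     group_size=4,
...     group_dim=-1,
...     input_view_impl=group_view,
...     scaling_stats_impl=AbsMax(stats_reduce_dim=-1),
...     scaling_min_val=None)
>>> scale = scaling_impl(torch.randn(2, 8))  # one scale per group of 4 elements, recomputed per call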

class brevitas.core.scaling.runtime.RuntimeStatsScaling(scaling_stats_impl, scaling_stats_input_view_shape_impl, scaling_shape, is_scale_unsigned=True, scaling_affine_rescaling_init=None, scaling_affine_shifting_init=None, restrict_scaling_impl=FloatRestrictValue(), restrict_threshold_impl=None, restrict_scale_threshold_impl=None, scaling_stats_momentum=0.1, scaling_min_val=None, dtype=None, device=None)[source]#

Bases: Module

forward(x, threshold=None)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
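
A hedged usage sketch, not part of the original documentation; AbsMax statistics and an over-tensor view are assumed choices, imported here from brevitas.core.stats and brevitas.core.function_wrapper. In training mode the returned scale is meant to track per-batch statistics while a momentum-based running average is accumulated for use in inference mode.

>>> import torch
>>> from brevitas.core.scaling.runtime import RuntimeStatsScaling
>>> from brevitas.core.stats import AbsMax
>>> from brevitas.core.function_wrapper import OverTensorView
>>> scaling_impl = RuntimeStatsScaling(
...     scaling_stats_impl=AbsMax(),
...     scaling_stats_input_view_shape_impl=OverTensorView(),
...     scaling_shape=())
>>> scale = scaling_impl(torch.randn(4, 16))  # scalar scale from the absolute max of the batch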

class brevitas.core.scaling.runtime.StatsFromParameterScaling(scaling_stats_impl, scaling_stats_input_view_shape_impl, scaling_stats_input_concat_dim, tracked_parameter_list, scaling_shape, force_parameter=False, restrict_scaling_impl=FloatRestrictValue(), restrict_threshold_impl=None, restrict_scale_threshold_impl=None, scaling_affine_rescaling_init=None, scaling_affine_shifting_init=None, scaling_min_val=None, dtype=None, device=None)[source]#

Bases: Module

forward(x, threshold=None)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
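
A hedged sketch mirroring the ParameterFromStatsFromParameterScaling example further down this page; here the scale is recomputed from the tracked parameters on every call rather than being learned. AbsMax and Identity are assumed choices for the statistics and view implementations.

>>> import torch
>>> from brevitas.core.scaling.runtime import StatsFromParameterScaling
>>> from brevitas.core.stats import AbsMax
>>> from brevitas.core.function_wrapper import Identity
>>> weight = torch.nn.Parameter(torch.randn(8, 4))
>>> scaling_impl = StatsFromParameterScaling(
...     scaling_stats_impl=AbsMax(),
...     scaling_stats_input_view_shape_impl=Identity(),
...     scaling_stats_input_concat_dim=0,
...     tracked_parameter_list=[weight],
...     scaling_shape=())
>>> scale = scaling_impl(weight)  # scale derived from the absolute max of the tracked parameter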

brevitas.core.scaling.standalone module#

class brevitas.core.scaling.standalone.ConstScaling(scaling_init, is_scale_unsigned=True, restrict_scaling_impl=FloatRestrictValue(), restrict_threshold_impl=None, restrict_scale_threshold_impl=None, scaling_min_val=None, dtype=None, device=None)[source]#

Bases: Module

ScriptModule implementation of a constant scale factor.

Parameters:
  • scaling_init (Union[float, Tensor]) – value to initialize the constant scale factor.

  • is_scale_unsigned (bool) – Whether the scale is unsigned. Default: True.

  • restrict_scaling_impl (Module) – restrict the scale factor according to some criteria. Default: FloatRestrictValue().

  • restrict_threshold_impl (Optional[Module]) – restrict the threshold according to some criteria. Default: None.

  • restrict_scale_threshold_impl (Optional[Module]) – restrict the value of the scale/threshold according to some criteria. Default: None.

  • scaling_min_val (Optional[float]) – force a lower-bound on the scale factor. Default: None.

  • dtype (Optional[torch.dtype]) – data type of the scale factor. Default: None.

  • device (Optional[torch.device]) – device of the scale factor. Default: None.

Returns:

scale factor wrapped in a float torch.tensor.

Return type:

Tensor

Examples

>>> scaling_impl = ConstScaling(1.0)
>>> scaling_impl(torch.empty(1))
tensor(1.)
>>> scaling_impl = ConstScaling(1.0, scaling_min_val=3.0)
>>> scaling_impl(torch.empty(1))
tensor(3.)
>>> scaling_impl = ConstScaling(3.0, restrict_scaling_impl=PowerOfTwoRestrictValue())
>>> scaling_impl(torch.empty(1))
tensor(4.)

Note

The forward method accepts a single placeholder argument. This is required by (early versions of) TorchScript so that the signature stays consistent across the different scaling implementations.

Note

Maps to scaling_impl_type == ScalingImplType.CONST == ‘CONST’ == ‘const’ in higher-level APIs.

forward(placeholder, threshold=None)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class brevitas.core.scaling.standalone.ParameterFromRuntimeStatsScaling(collect_stats_steps, scaling_stats_impl, is_scale_unsigned=True, scaling_stats_input_view_shape_impl=OverBatchOverTensorView(), scaling_shape=(), restrict_scaling_impl=FloatRestrictValue(), restrict_threshold_impl=None, restrict_scale_threshold_impl=None, scaling_stats_momentum=0.1, scaling_min_val=None, dtype=None, device=None)[source]#

Bases: Module

ScriptModule implementation of a learned scale factor initialized from runtime statistics. The implementation works in two phases. During the first phase, statistics are collected in the same fashion as batchnorm: while the module is in training mode, per-batch statistics are computed and returned, and in the background a running average of them is retained and returned in inference mode. During the second phase, the average accumulated during the first phase is used to initialize a learned torch.nn.Parameter, and from then on the behaviour is the same as ParameterScaling.

Parameters:
  • collect_stats_steps (int) – Number of calls to the forward method in training mode to collect statistics for.

  • scaling_stats_impl (Module) – Implementation of the statistics computed during the collection phase.

  • is_scale_unsigned (bool, optional) – Whether the scale is unsigned. Default: True.

  • scaling_stats_input_view_shape_impl (Module, optional) – Implementation of the view applied to the runtime input during the statistics collection phase. Default: OverBatchOverTensorView().

  • scaling_shape (Tuple[int, ...], optional) – Shape of the torch.nn.Parameter used in the second phase. Default: SCALAR_SHAPE.

  • restrict_scaling_impl (Module, optional) – Restrict the learned scale factor according to some criteria. Default: FloatRestrictValue().

  • restrict_threshold_impl (Optional[Module], optional) – Restrict the threshold according to some criteria. Default: None.

  • restrict_scale_threshold_impl (Optional[Module]) – Restrict the value of the scale/threshold according to some criteria. Default: None.

  • scaling_stats_momentum (Optional[float], optional) – Momentum for the statistics moving average. Default: DEFAULT_MOMENTUM.

  • scaling_min_val (Optional[float], optional) – Force a lower-bound on the learned scale factor. Default: None.

  • dtype (Optional[torch.dtype], optional) – Data type of the scale factor. Default: None.

  • device (Optional[torch.device], optional) – Device of the scale factor. Default: None.

Returns:

learned scale factor wrapped in a float torch.tensor.

Return type:

Tensor

Raises:

RuntimeError – if collect_stats_steps <= 0.

Examples

>>> scaling_impl = ParameterFromRuntimeStatsScaling(
...     collect_stats_steps=1,
...     scaling_stats_impl=AbsMax())
>>> scaling_impl.training
True
>>> x = torch.arange(-3, 2, 0.1)
>>> scaling_impl(x)
tensor(3.)
>>> scaling_impl(torch.randn_like(x))
tensor(3., grad_fn=<AbsBinarySignGradFnBackward>)

Note

Set env variable BREVITAS_IGNORE_MISSING_KEYS=1 to avoid errors when retraining from a floating point state dict.

Note

Maps to scaling_impl_type == ScalingImplType.PARAMETER_FROM_STATS == ‘PARAMETER_FROM_STATS’ == ‘parameter_from_stats’ when applied to runtime values (inputs/outputs/activations) in higher-level APIs.

forward(stats_input, threshold=None)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

init_scale()[source]#

state_dict(destination=None, prefix='', keep_vars=False)[source]#

Return a dictionary containing references to the whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.

Note

The returned object is a shallow copy. It contains references to the module’s parameters and buffers.

Warning

Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.

Warning

Please avoid the use of argument destination as it is not designed for end-users.

Parameters:
  • destination (dict, optional) – If provided, the state of the module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.

  • prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.

  • keep_vars (bool, optional) – by default the Tensor s returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.

Returns:

a dictionary containing the whole state of the module

Return type:

dict

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']

training_forward(stats_input, threshold)[source]#

Return type:

Tensor

class brevitas.core.scaling.standalone.ParameterFromStatsFromParameterScaling(scaling_stats_impl, scaling_stats_input_view_shape_impl, scaling_stats_input_concat_dim, tracked_parameter_list, scaling_shape, is_scale_unsigned=True, force_parameter=False, restrict_scaling_impl=FloatRestrictValue(), restrict_threshold_impl=None, restrict_scale_threshold_impl=None, scaling_affine_rescaling_init=None, scaling_affine_shifting_init=None, scaling_min_val=None, dtype=None, device=None)[source]#

Bases: Module

ScriptModule implementation of a learned scale factor initialized from statistics computed over a list of parameters.

Parameters:
  • scaling_stats_impl (Module) – Implementation of the statistics computed over the parameter list.

  • scaling_stats_input_view_shape_impl (Module) – Implementation of the view applied to the input before statistics computation.

  • scaling_stats_input_concat_dim (int) – Dimension along which to concatenate parameter tensors for statistics computation.

  • tracked_parameter_list (List[torch.nn.Parameter]) – List of parameters to track and compute statistics over.

  • scaling_shape (Tuple[int, ...]) – Shape of the learned scale factor.

  • is_scale_unsigned (bool) – Whether the scale is unsigned. Default: True.

  • force_parameter (bool) – If True, always use a tracked_parameter_list for statistics, even if only one is tracked. Default: False.

  • restrict_scaling_impl (Module) – Restrict the scale factor according to some criteria. Default: FloatRestrictValue().

  • restrict_threshold_impl (Optional[Module]) – Restrict the threshold according to some criteria. Default: None.

  • restrict_scale_threshold_impl (Optional[Module]) – Restrict the value of the scale/threshold according to some criteria. Default: None.

  • scaling_affine_rescaling_init (Optional[float]) – Initial value for affine rescaling. Default: None.

  • scaling_affine_shifting_init (Optional[float]) – Initial value for affine shifting. Default: None.

  • scaling_min_val (Optional[float]) – Force a lower-bound on the scale factor. Default: None.

  • dtype (Optional[torch.dtype]) – Data type of the scale factor. Default: None.

  • device (Optional[torch.device]) – Device of the scale factor. Default: None.

Returns:

learned scale factor wrapped in a float torch.tensor.

Return type:

Tensor

Note

Set env variable BREVITAS_IGNORE_MISSING_KEYS=1 to avoid errors when retraining from a floating point state dict.

Note

Maps to scaling_impl_type == ScalingImplType.PARAMETER_FROM_STATS == ‘PARAMETER_FROM_STATS’ == ‘parameter_from_stats’ in higher-level APIs.

Example

>>> scaling_impl = ParameterFromStatsFromParameterScaling(
...     scaling_stats_impl=AbsMax(),
...     scaling_stats_input_view_shape_impl=Identity(),
...     scaling_stats_input_concat_dim=0,
...     tracked_parameter_list=[torch.nn.Parameter(torch.ones(3))],
...     scaling_shape=(3,))
>>> x = torch.randn(3)
>>> scaling_impl(x)
tensor([...], grad_fn=<...>)

forward(x, threshold=None)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

state_dict(destination=None, prefix='', keep_vars=False)[source]#

Return a dictionary containing references to the whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.

Note

The returned object is a shallow copy. It contains references to the module’s parameters and buffers.

Warning

Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.

Warning

Please avoid the use of argument destination as it is not designed for end-users.

Parameters:
  • destination (dict, optional) – If provided, the state of the module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.

  • prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.

  • keep_vars (bool, optional) – by default the Tensor s returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.

Returns:

a dictionary containing the whole state of the module

Return type:

dict

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']

class brevitas.core.scaling.standalone.ParameterScaling(scaling_init, is_scale_unsigned=True, scaling_shape=None, restrict_scaling_impl=FloatRestrictValue(), restrict_threshold_impl=None, restrict_scale_threshold_impl=None, scaling_min_val=None, dtype=None, device=None)[source]#

Bases: Module

ScriptModule implementation of a learned scale factor.

Parameters:
  • scaling_init (Union[float, Tensor]) – Value to initialize the learned scale factor.

  • is_scale_unsigned (bool) – Whether the scale is unsigned. Default: True.

  • scaling_shape (Optional[Tuple[int, ...]]) – Shape of the learned scale factor. Default: None.

  • restrict_scaling_impl (Module) – Restrict the scale factor according to some criteria. Default: FloatRestrictValue().

  • restrict_threshold_impl (Optional[Module]) – Restrict the threshold according to some criteria. Default: None.

  • restrict_scale_threshold_impl (Optional[Module]) – Restrict the value of the scale/threshold according to some criteria. Default: None.

  • scaling_min_val (Optional[float]) – Force a lower-bound on the scale factor. Default: None.

  • dtype (Optional[torch.dtype]) – Data type of the scale factor. Default: None.

  • device (Optional[torch.device]) – Device of the scale factor. Default: None.

Returns:

learned scale factor wrapped in a float torch.tensor.

Return type:

Tensor

Raises:

RuntimeError – if scaling_init is a non-scalar tensor and scaling_shape does not match scaling_init.shape.

Examples

>>> scaling_impl = ParameterScaling(6.0)
>>> scaling_impl(torch.empty(1))
tensor(6., grad_fn=<AbsBinarySignGradFnBackward>)
>>> scaling_impl = ParameterScaling(6.0, scaling_shape=(3,))
>>> scaling_impl(torch.empty(1))
tensor([6., 6., 6.], grad_fn=<AbsBinarySignGradFnBackward>)
>>> scaling_impl = ParameterScaling(6.0, scaling_shape=(3,), restrict_scaling_impl=PowerOfTwoRestrictValue())
>>> scaling_impl(torch.empty(1))
tensor([8., 8., 8.], grad_fn=<PowBackward1>)

Note

Set env variable BREVITAS_IGNORE_MISSING_KEYS=1 to avoid errors when retraining from a floating point state dict.

Note

The forward method accepts a single placeholder argument. This is required by (early versions of) TorchScript so that the signature stays consistent across the different scaling implementations.

Note

Maps to scaling_impl_type == ScalingImplType.PARAMETER == ‘PARAMETER’ == ‘parameter’ in higher-level APIs.

forward(placeholder, threshold=None)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class brevitas.core.scaling.standalone.TruncMsbScaling[source]#

Bases: Module

ScriptModule implementation of an integer scaling that calculates the scale factor required to keep the most significant bits of the input. Interface compatible with TruncIntQuant’s trunc_scaling_impl member.

Returns:

truncation scale factor wrapped in a float torch.tensor.

Return type:

Tensor

Examples

>>> from brevitas.core.scaling import TruncMsbScaling
>>> trunc_scaling_impl = TruncMsbScaling()
>>> input_bit_width, output_bit_width, signed = torch.tensor(8.), torch.tensor(4.), torch.tensor(True)
>>> scaling_input = torch.Tensor([0.04, -0.05, 0.31, -0.44])
>>> trunc_scale = trunc_scaling_impl(scaling_input, input_bit_width, output_bit_width, signed)
>>> trunc_scale
tensor(16.)

Note

The forward method accepts multiple placeholder arguments, scaling_input and signed, to match the calling convention of other trunc_scaling_impl modules. This is required by (early versions of) TorchScript so that the signature stays consistent across the different scaling implementations.

Note

Maps to trunc_scaling_impl == TruncScalingImplType.MSB == ‘MSB’ == ‘msb’ in higher-level APIs.

forward(scaling_input, input_bit_width, output_bit_width, signed)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class brevitas.core.scaling.standalone.TruncScalingWrapper(trunc_int_scaling_impl, scaling_impl, tensor_clamp_impl=TensorClamp())[source]#

Bases: Module

ScriptModule wrapper that maps the interface requirements of TruncIntQuant’s trunc_scaling_impl onto standard scaling implementations via scaling_impl.

Parameters:
  • trunc_int_scaling_impl (Module) – Module that takes in a bit-width and returns an integer scale factor, here interpreted as a threshold on the integer range of quantization.

  • scaling_impl (Module) – Module that takes in the input to quantize and returns a scale factor, here interpreted as a threshold on the floating-point range of quantization.

  • tensor_clamp_impl (Module) – Module that performs clamping. Default: TensorClamp().

Returns:

truncation scale factor wrapped in a float torch.tensor.

Return type:

Tensor

Examples

>>> from brevitas.core.scaling import TruncScalingWrapper
>>> from brevitas.core.scaling import ConstScaling
>>> from brevitas.core.scaling import PowerOfTwoIntScaling
>>> trunc_scaling_impl = TruncScalingWrapper(PowerOfTwoIntScaling(), ConstScaling(1.))
>>> input_bit_width, output_bit_width, signed = torch.tensor(8.), torch.tensor(4.), torch.tensor(True)
>>> scaling_input = torch.Tensor([0.04, -0.05, 0.31, -0.44])
>>> trunc_scale = trunc_scaling_impl(scaling_input, input_bit_width, output_bit_width, signed)
>>> trunc_scale
tensor(1.)

Note

Maps to trunc_scaling_impl == TruncScalingImplType.WRAPPER == ‘WRAPPER’ == ‘wrapper’ in higher-level APIs.

forward(scaling_input, input_bit_width, output_bit_width, signed)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Module contents#