brevitas.core.scaling package#
Submodules#
brevitas.core.scaling.int_scaling module#
- class brevitas.core.scaling.int_scaling.IntScaling(signed, narrow_range)[source]#
Bases:
Module
- forward(bit_width)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tensor
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class brevitas.core.scaling.int_scaling.PowerOfTwoIntScaling(signed)[source]#
Bases:
Module
- forward(bit_width)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tensor
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
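Neither class documents the value it computes, so as a rough guide: below is a minimal pure-Python sketch of the integer ranges these modules are commonly understood to return, assuming IntScaling yields the magnitude of the representable integer range for the given signed/narrow_range combination, and PowerOfTwoIntScaling rounds that range up to a power of two. These formulas are an assumption inferred from the class names, not taken from this page.

```python
def int_scaling(signed: bool, narrow_range: bool, bit_width: int) -> int:
    # Assumed behavior: magnitude of the representable integer range.
    if signed:
        # e.g. 8-bit: 127 with narrow_range (symmetric [-127, 127]), else 128
        return 2 ** (bit_width - 1) - 1 if narrow_range else 2 ** (bit_width - 1)
    # e.g. 8-bit unsigned: 255, or 254 with narrow_range
    return 2 ** bit_width - 2 if narrow_range else 2 ** bit_width - 1

def power_of_two_int_scaling(signed: bool, bit_width: int) -> int:
    # Assumed behavior: a power-of-two range, convenient for shift-based rescaling.
    return 2 ** (bit_width - 1) if signed else 2 ** bit_width
```

A power-of-two range trades a small amount of representable range for scale factors that divide exactly in fixed-point arithmetic.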
brevitas.core.scaling.runtime module#
- class brevitas.core.scaling.runtime.RuntimeDynamicGroupStatsScaling(group_size, group_dim, input_view_impl, scaling_stats_impl, scaling_min_val, restrict_scaling_impl=FloatRestrictValue())[source]#
Bases:
Module
- forward(stats_input, threshold=None)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tensor
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
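The group-wise behavior can be pictured with a small sketch: the input is viewed as contiguous groups of group_size elements and one scale is computed per group from a statistic such as abs-max. group_absmax_scales below is a hypothetical helper written for illustration only, not part of brevitas.

```python
def group_absmax_scales(values, group_size):
    """Compute one abs-max scale per contiguous group of `group_size` elements."""
    scales = []
    for start in range(0, len(values), group_size):
        group = values[start:start + group_size]
        scales.append(max(abs(v) for v in group))  # per-group statistic
    return scales
```

Because the statistic is recomputed from each input at runtime, no state needs to be stored between calls, which is what makes this scaling "dynamic".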
- class brevitas.core.scaling.runtime.RuntimeStatsScaling(scaling_stats_impl, scaling_stats_input_view_shape_impl, scaling_shape, affine_rescaling=False, affine_shift_scale=False, restrict_scaling_impl=FloatRestrictValue(), scaling_stats_momentum=0.1, scaling_min_val=None, dtype=None, device=None)[source]#
Bases:
Module
- forward(x, threshold=None)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tensor
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class brevitas.core.scaling.runtime.StatsFromParameterScaling(scaling_stats_impl, scaling_stats_input_view_shape_impl, scaling_stats_input_concat_dim, tracked_parameter_list, scaling_shape, restrict_scaling_impl=FloatRestrictValue(), affine_rescaling=False, affine_shift_scale=False, scaling_min_val=None, dtype=None, device=None)[source]#
Bases:
Module
- forward(ignored, threshold=None)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tensor
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
brevitas.core.scaling.standalone module#
- class brevitas.core.scaling.standalone.ConstScaling(scaling_init, restrict_scaling_impl=FloatRestrictValue(), scaling_min_val=None, dtype=None, device=None)[source]#
Bases:
Module
ScriptModule implementation of a constant scale factor.
- Parameters:
scaling_init (Union[float, Tensor]) – value to use as constant scale factor.
restrict_scaling_impl (Module) – restrict scaling_init according to some criteria. Default: None
scaling_min_val (float) – force a lower-bound on scaling_init. Default: None
- Returns:
scale factor wrapped in a float torch.tensor.
- Return type:
Tensor
Examples
>>> scaling_impl = ConstScaling(1.0)
>>> scaling_impl(torch.empty(1))
tensor(1.)
>>> scaling_impl = ConstScaling(1.0, scaling_min_val=3.0)
>>> scaling_impl(torch.empty(1))
tensor(3.)
>>> scaling_impl = ConstScaling(3.0, restrict_scaling_impl=PowerOfTwoRestrictValue())
>>> scaling_impl(torch.empty(1))
tensor(4.)
Note
The forward method accepts a single placeholder argument. This is required by (early versions of) TorchScript to be consistent across different scaling implementations.
Note
Maps to scaling_impl_type == ScalingImplType.CONST == ‘CONST’ == ‘const’ in higher-level APIs.
- forward(placeholder, threshold=None)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tensor
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
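The doctest above can be reproduced with plain Python: the constant scale is clamped from below by scaling_min_val and, under a power-of-two restriction, snapped to 2 ** round(log2(x)). This is a sketch inferred from the example outputs, not the brevitas implementation.

```python
import math

def const_scale(scaling_init, scaling_min_val=None, power_of_two=False):
    value = float(scaling_init)
    if scaling_min_val is not None:
        value = max(value, scaling_min_val)  # lower-bound the scale
    if power_of_two:
        value = 2.0 ** round(math.log2(value))  # snap to the nearest power of two
    return value
```

For example, const_scale(3.0, power_of_two=True) gives 4.0, matching the last doctest line: round(log2(3)) is 2, and 2 ** 2 == 4.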
- class brevitas.core.scaling.standalone.ParameterFromRuntimeStatsScaling(collect_stats_steps, scaling_stats_impl, scaling_stats_input_view_shape_impl=OverBatchOverTensorView(), scaling_shape=(), restrict_scaling_impl=FloatRestrictValue(), scaling_stats_momentum=0.1, scaling_min_val=None, dtype=None, device=None)[source]#
Bases:
Module
ScriptModule implementation of a learned scale factor initialized from runtime statistics. The implementation works in two phases. During the first phase, statistics are collected in the same fashion as batch norm: while the module is in training mode, a set of per-batch statistics is computed and returned, while in the background a running average of them is retained and returned in inference mode. During the second phase, the average accumulated during the first phase is used to initialize a learned torch.nn.Parameter, and from then on the behaviour is the same as ParameterScaling.
- Parameters:
collect_stats_steps (int) – Number of calls to the forward method in training mode to collect statistics for.
scaling_stats_impl (Module) – Implementation of the statistics computed during the collection phase.
scaling_stats_input_view_shape_impl (Module) – Implementation of the view applied to the runtime input during the statistics collection phase. Default: OverBatchOverTensorView().
scaling_shape (Tuple[int, ...]) – shape of the torch.nn.Parameter used in the second phase. Default: SCALAR_SHAPE.
restrict_scaling_impl (Module) – restrict the learned scale factor according to some criteria. Default: None
scaling_stats_permute_dims (Tuple[int, ...]) – dims used to permute the runtime input before going into scaling_stats_input_view_shape_impl. Default: None
scaling_stats_momentum (Optional[float]) – momentum for the statistics moving average. Default: DEFAULT_MOMENTUM.
scaling_min_val (float) – force a lower-bound on the learned scale factor. Default: None.
- Returns:
learned scale factor wrapped in a float torch.tensor.
- Return type:
Tensor
- Raises:
RuntimeError – if scaling_shape != SCALAR_SHAPE and scaling_stats_permute_dims is None
Examples
>>> scaling_impl = ParameterFromRuntimeStatsScaling(collect_stats_steps=1, scaling_stats_impl=AbsMax())
>>> scaling_impl.training
True
>>> x = torch.arange(-3, 2, 0.1)
>>> scaling_impl(x)
tensor(3.)
>>> scaling_impl(torch.randn_like(x))
tensor(3., grad_fn=<AbsBinarySignGradFnBackward>)
Note
Set env variable BREVITAS_IGNORE_MISSING_KEYS=1 to avoid errors when retraining from a floating point state dict.
Note
Maps to scaling_impl_type == ScalingImplType.PARAMETER_FROM_STATS == ‘PARAMETER_FROM_STATS’ == ‘parameter_from_stats’ when applied to runtime values (inputs/outputs/activations) in higher-level APIs.
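The two-phase scheme described above can be sketched in plain Python: during the first collect_stats_steps training calls the per-batch statistic is folded into an exponential moving average, after which that average seeds a learned value. The class below is illustrative only (the learned_value attribute stands in for the torch.nn.Parameter); it is not the brevitas code.

```python
class TwoPhaseScale:
    """Illustrative sketch of collect-then-learn scale initialization."""

    def __init__(self, collect_stats_steps, momentum=0.1):
        self.collect_stats_steps = collect_stats_steps
        self.momentum = momentum
        self.counter = 0
        self.buffer = 0.0
        self.learned_value = None  # stands in for the torch.nn.Parameter

    def __call__(self, batch_stat):
        if self.counter < self.collect_stats_steps:
            # Phase 1: exponential moving average of per-batch statistics,
            # while each training call returns the current batch statistic.
            if self.counter == 0:
                self.buffer = batch_stat
            else:
                self.buffer = (1 - self.momentum) * self.buffer + self.momentum * batch_stat
            self.counter += 1
            return batch_stat
        if self.learned_value is None:
            # Phase 2: the accumulated average initializes the learned value.
            self.learned_value = self.buffer
        return self.learned_value
```

With collect_stats_steps=1, the first call returns the batch statistic and every later call returns the (now learnable) initialized value, mirroring the doctest above.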
- forward(stats_input, threshold=None)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tensor
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- state_dict(destination=None, prefix='', keep_vars=False)[source]#
Return a dictionary containing references to the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are the corresponding parameter and buffer names. Parameters and buffers set to None are not included.
Note
The returned object is a shallow copy. It contains references to the module’s parameters and buffers.
Warning
Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.
Warning
Please avoid the use of the argument destination as it is not designed for end-users.
- Parameters:
destination (dict, optional) – If provided, the state of the module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.
prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in the state dict. Default: ''.
keep_vars (bool, optional) – by default the Tensors returned in the state dict are detached from autograd. If set to True, detaching will not be performed. Default: False.
- Returns:
a dictionary containing the whole state of the module
- Return type:
dict
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
- class brevitas.core.scaling.standalone.ParameterFromStatsFromParameterScaling(scaling_stats_impl, scaling_stats_input_view_shape_impl, scaling_stats_input_concat_dim, tracked_parameter_list, scaling_shape, restrict_scaling_impl=FloatRestrictValue(), scaling_min_val=None, dtype=None, device=None)[source]#
Bases:
Module
ScriptModule implementation of a learned scale factor initialized from statistics of a parameter, e.g. weights MSE or AbsMax.
- forward(ignored, threshold=None)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tensor
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
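As a rough illustration of "initialized from statistics of a parameter": the tracked parameters are reduced by a statistic such as abs-max, and the result seeds the learned scale. scale_init_from_params is a hypothetical helper written for this page, not part of brevitas.

```python
def scale_init_from_params(tracked_parameter_list):
    """Abs-max over all tracked parameter values as an initial scale (illustrative)."""
    return max(abs(v) for param in tracked_parameter_list for v in param)
```

Unlike ParameterFromRuntimeStatsScaling, no collection phase over inputs is needed here: the parameters are available up front, so the statistic can be computed once at construction time.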
- state_dict(destination=None, prefix='', keep_vars=False)[source]#
Return a dictionary containing references to the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are the corresponding parameter and buffer names. Parameters and buffers set to None are not included.
Note
The returned object is a shallow copy. It contains references to the module’s parameters and buffers.
Warning
Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.
Warning
Please avoid the use of the argument destination as it is not designed for end-users.
- Parameters:
destination (dict, optional) – If provided, the state of the module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.
prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in the state dict. Default: ''.
keep_vars (bool, optional) – by default the Tensors returned in the state dict are detached from autograd. If set to True, detaching will not be performed. Default: False.
- Returns:
a dictionary containing the whole state of the module
- Return type:
dict
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
- class brevitas.core.scaling.standalone.ParameterScaling(scaling_init, scaling_shape=None, restrict_scaling_impl=FloatRestrictValue(), scaling_min_val=None, dtype=None, device=None)[source]#
Bases:
Module
ScriptModule implementation of a learned scale factor.
- Parameters:
scaling_init (Union[float, Tensor]) – value to initialize the learned scale factor.
scaling_shape (Tuple[int, ...]) – shape to extend a scalar float or tensor scaling_init. Default: None
restrict_scaling_impl (Module) – restrict the learned scale factor according to some criteria. Default: None
scaling_min_val (float) – force a lower-bound on the learned scale factor. Default: None
- Returns:
learned scale factor wrapped in a float torch.tensor.
- Return type:
Tensor
- Raises:
RuntimeError – if scaling_init is a non-scalar tensor and scaling_shape != scaling_init.shape.
Examples
>>> scaling_impl = ParameterScaling(6.0)
>>> scaling_impl(torch.empty(1))
tensor(6., grad_fn=<AbsBinarySignGradFnBackward>)
>>> scaling_impl = ParameterScaling(6.0, scaling_shape=(3,))
>>> scaling_impl(torch.empty(1))
tensor([6., 6., 6.], grad_fn=<AbsBinarySignGradFnBackward>)
>>> scaling_impl = ParameterScaling(6.0, scaling_shape=(3,), restrict_scaling_impl=PowerOfTwoRestrictValue())
>>> scaling_impl(torch.empty(1))
tensor([8., 8., 8.], grad_fn=<PowBackward1>)
Note
Set env variable BREVITAS_IGNORE_MISSING_KEYS=1 to avoid errors when retraining from a floating point state dict.
Note
The forward method accepts a single placeholder argument. This is required by (early versions of) TorchScript to be consistent across different scaling implementations.
Note
Maps to scaling_impl_type == ScalingImplType.PARAMETER == ‘PARAMETER’ == ‘parameter’ in higher-level APIs.
- forward(placeholder, threshold=None)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tensor
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
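The ParameterScaling examples follow the same arithmetic as ConstScaling, with the extra step of broadcasting a scalar init to scaling_shape; under a power-of-two restriction, 6.0 becomes 2 ** round(log2(6)) == 8. Below is a plain-Python sketch of that behavior, inferred from the doctest outputs rather than taken from the library code (and handling only 1-D shapes for brevity).

```python
import math

def parameter_scale(scaling_init, scaling_shape=None, power_of_two=False):
    value = float(scaling_init)
    if power_of_two:
        value = 2.0 ** round(math.log2(value))  # e.g. 6.0 -> 2**3 == 8.0
    if scaling_shape is not None:
        # Broadcast the scalar init to the requested shape (1-D only in this sketch).
        return [value] * scaling_shape[0]
    return value
```

In the real module the broadcast value is stored as a torch.nn.Parameter, so each of the per-element scales can subsequently be trained independently.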