brevitas.core.quant package
Submodules
brevitas.core.quant.binary module
- class brevitas.core.quant.binary.BinaryQuant(scaling_impl, signed=True, quant_delay_steps=0)
Bases: Module
ScriptModule that implements scaled uniform binary quantization of an input tensor. Quantization is performed with binary_sign_ste().
- Parameters:
scaling_impl (Module) – Module that returns a scale factor.
quant_delay_steps (int) – Number of training steps to delay quantization for. Default: 0
- Returns:
Quantized output in de-quantized format, scale, zero-point, bit_width.
- Return type:
Tuple[Tensor, Tensor, Tensor, Tensor]
Examples
>>> from brevitas.core.scaling import ConstScaling
>>> binary_quant = BinaryQuant(ConstScaling(0.1))
>>> inp = torch.Tensor([0.04, -0.6, 3.3])
>>> out, scale, zero_point, bit_width = binary_quant(inp)
>>> out
tensor([ 0.1000, -0.1000,  0.1000])
>>> scale
tensor(0.1000)
>>> zero_point
tensor(0.)
>>> bit_width
tensor(1.)
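In effect each element is mapped to scale * sign(x): both the small 0.04 and the out-of-range 3.3 quantize to the same value 0.1.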
Note
Maps to quant_type == QuantType.BINARY == ‘BINARY’ == ‘binary’ when applied to weights in higher-level APIs.
Note
Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tuple[Tensor, Tensor, Tensor, Tensor]
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class brevitas.core.quant.binary.ClampedBinaryQuant(scaling_impl, tensor_clamp_impl=TensorClamp(), quant_delay_steps=0)
Bases: Module
ScriptModule that implements scaled uniform binary quantization of an input tensor. Before quantization, the input tensor is clamped to the range (-scale, scale), which on the backward pass zeroes the gradients of inputs outside that range. Quantization is performed with binary_sign_ste().
- Parameters:
scaling_impl (Module) – Module that returns a scale factor.
tensor_clamp_impl (Module) – Module that performs tensor-wise clamping. Default: TensorClamp()
quant_delay_steps (int) – Number of training steps to delay quantization for. Default: 0
- Returns:
Quantized output in de-quantized format, scale, zero-point, bit_width.
- Return type:
Tuple[Tensor, Tensor, Tensor, Tensor]
Examples
>>> from brevitas.core.scaling import ConstScaling
>>> binary_quant = ClampedBinaryQuant(ConstScaling(0.1))
>>> inp = torch.Tensor([0.04, -0.6, 3.3]).requires_grad_(True)
>>> out, scale, zero_point, bit_width = binary_quant(inp)
>>> out
tensor([ 0.1000, -0.1000,  0.1000], grad_fn=<MulBackward0>)
>>> out.backward(torch.Tensor([1.0, 1.0, 1.0]))
>>> inp.grad
tensor([0.1000, 0.0000, 0.0000])
>>> scale
tensor(0.1000)
>>> zero_point
tensor(0.)
>>> bit_width
tensor(1.)
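Only the first input (0.04) falls inside the clamping range (-0.1, 0.1), so it is the only element to receive a non-zero gradient; the clamp zeroes the gradients of -0.6 and 3.3, as inp.grad shows.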
Note
Maps to quant_type == QuantType.BINARY == ‘BINARY’ == ‘binary’ when applied to activations in higher-level APIs.
Note
Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tuple[Tensor, Tensor, Tensor, Tensor]
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
brevitas.core.quant.delay module
- class brevitas.core.quant.delay.DelayWrapper(quant_delay_steps)
Bases: Module
- forward(x, y)
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tensor
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
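DelayWrapper carries no docstring here. The sketch below is a hedged usage example, assuming (consistent with the quant_delay_steps parameter documented elsewhere in this package) that forward(x, y) returns the unquantized x while the delay is active and the quantized y once quant_delay_steps training steps have elapsed:
>>> from brevitas.core.quant.delay import DelayWrapper
>>> delay = DelayWrapper(quant_delay_steps=10)
>>> x = torch.tensor([0.04, -0.6, 3.3])  # original input
>>> y = torch.tensor([0.1, -0.1, 0.1])   # its quantized counterpart
>>> out = delay(x, y)  # x for the first 10 training steps, y afterwards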
brevitas.core.quant.int module
- class brevitas.core.quant.int.DecoupledRescalingIntQuant(decoupled_int_quant, pre_scaling_impl, scaling_impl, int_scaling_impl, pre_zero_point_impl, zero_point_impl, bit_width_impl)
Bases: Module
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
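No docstring is provided for this class. As a hedged sketch, assuming an already-constructed instance decoupled_quant and that the six outputs follow the four-tensor convention used elsewhere in this package with the pre-scale pair appended:
>>> # decoupled_quant is a hypothetical, already-constructed DecoupledRescalingIntQuant
>>> x = torch.Tensor([0.042, -0.053, 0.31, -0.44])
>>> out, scale, zero_point, bit_width, pre_scale, pre_zero_point = decoupled_quant(x)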
- class brevitas.core.quant.int.DecoupledRescalingIntQuantWithInput(decoupled_int_quant, pre_scaling_impl, scaling_impl, int_scaling_impl, pre_zero_point_impl, zero_point_impl, bit_width_impl)
Bases: DecoupledRescalingIntQuant
- forward(x, input_bit_width, input_is_signed)
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class brevitas.core.quant.int.PrescaledRestrictIntQuant(int_quant, bit_width_impl)
Bases: Module
- forward(x, scale)
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tuple[Tensor, Tensor, Tensor, Tensor]
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
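No docstring is provided for this class. The sketch below is a hedged example, assuming it behaves like PrescaledRestrictIntQuantWithInputBitWidth (documented next) except that the bit-width comes from bit_width_impl rather than from an input; BitWidthConst(4) here stands in for a concrete bit-width implementation, and IntQuant is constructed as in the documented examples below:
>>> from brevitas.core.quant import IntQuant
>>> from brevitas.core.bit_width import BitWidthConst
>>> int_quant = IntQuant(narrow_range=True, signed=True)
>>> quant_wrapper = PrescaledRestrictIntQuant(int_quant, BitWidthConst(4))
>>> inp, scale = torch.Tensor([0.042, -0.053, 0.31, -0.44]), torch.tensor(0.01)
>>> out, scale, zero_point, bit_width = quant_wrapper(inp, scale)  # should match the 4-bit example below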
- class brevitas.core.quant.int.PrescaledRestrictIntQuantWithInputBitWidth(int_quant, bit_width_impl)
Bases: Module
ScriptModule that wraps around an integer quantization implementation like IntQuant. Zero-point is set to zero, scale is taken as input, bit-width is computed from an input bit-width.
- Parameters:
int_quant (Module) – Module that implements integer quantization.
bit_width_impl (Module) – Module that takes in the input bit-width and returns the bit-width to be used for quantization.
- Returns:
Quantized output in de-quantized format, scale, zero-point, bit_width.
- Return type:
Tuple[Tensor, Tensor, Tensor, Tensor]
Examples
>>> from brevitas.core.scaling import ConstScaling
>>> from brevitas.core.function_wrapper import Identity
>>> from brevitas.core.quant import IntQuant
>>> int_quant = IntQuant(narrow_range=True, signed=True)
>>> int_quant_wrapper = PrescaledRestrictIntQuantWithInputBitWidth(int_quant, Identity())
>>> scale, input_bit_width = torch.tensor(0.01), torch.tensor(4.)
>>> inp = torch.Tensor([0.042, -0.053, 0.31, -0.44])
>>> out, scale, zero_point, bit_width = int_quant_wrapper(inp, scale, input_bit_width)
>>> out
tensor([ 0.0400, -0.0500,  0.0700, -0.0700])
>>> scale
tensor(0.0100)
>>> zero_point
tensor(0.)
>>> bit_width
tensor(4.)
Note
Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.
- forward(x, scale, input_bit_width)
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tuple[Tensor, Tensor, Tensor, Tensor]
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class brevitas.core.quant.int.RescalingIntQuant(int_quant, scaling_impl, int_scaling_impl, zero_point_impl, bit_width_impl)
Bases: Module
ScriptModule that wraps around an integer quantization implementation like IntQuant. Scale, zero-point and bit-width are returned from their respective implementations and passed on to the integer quantization implementation.
- Parameters:
int_quant (Module) – Module that implements integer quantization.
scaling_impl (Module) – Module that takes in the input to quantize and returns a scale factor, here interpreted as a threshold on the floating-point range of quantization.
int_scaling_impl (Module) – Module that takes in a bit-width and returns an integer scale factor, here interpreted as a threshold on the integer range of quantization.
zero_point_impl (Module) – Module that returns an integer zero-point.
bit_width_impl (Module) – Module that returns a bit-width.
- Returns:
Quantized output in de-quantized format, scale, zero-point, bit_width.
- Return type:
Tuple[Tensor, Tensor, Tensor, Tensor]
Examples
>>> from brevitas.core.scaling import ConstScaling
>>> from brevitas.core.zero_point import ZeroZeroPoint
>>> from brevitas.core.scaling import IntScaling
>>> from brevitas.core.quant import IntQuant
>>> from brevitas.core.bit_width import BitWidthConst
>>> int_quant_wrapper = RescalingIntQuant(
...     IntQuant(narrow_range=True, signed=True),
...     ConstScaling(0.1),
...     IntScaling(signed=True, narrow_range=True),
...     ZeroZeroPoint(),
...     BitWidthConst(4))
>>> inp = torch.Tensor([0.042, -0.053, 0.31, -0.44])
>>> out, scale, zero_point, bit_width = int_quant_wrapper(inp)
>>> out
tensor([ 0.0429, -0.0571,  0.1000, -0.1000])
>>> scale
tensor(0.0143)
>>> zero_point
tensor(0.)
>>> bit_width
tensor(4.)
Note
scale = scaling_impl(x) / int_scaling_impl(bit_width)
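For example, in the doctest above, ConstScaling(0.1) returns 0.1 and IntScaling(signed=True, narrow_range=True) maps a 4-bit width to 7, so scale = 0.1 / 7 ≈ 0.0143, the value returned.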
Note
Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tuple[Tensor, Tensor, Tensor, Tensor]
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class brevitas.core.quant.int.TruncIntQuant(float_to_int_impl, bit_width_impl, quant_delay_steps=0)
Bases: Module
- forward(x, scale, zero_point, input_bit_width)
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tuple[Tensor, Tensor, Tensor, Tensor]
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
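TruncIntQuant carries no docstring here. The sketch below is a hedged usage example: the input is expected to be an already-quantized tensor described by scale, zero_point and input_bit_width, while FloorSte() and BitWidthConst(4) are assumptions standing in for concrete float-to-int and output bit-width implementations:
>>> from brevitas.core.function_wrapper import FloorSte
>>> from brevitas.core.bit_width import BitWidthConst
>>> trunc_quant = TruncIntQuant(FloorSte(), BitWidthConst(4))
>>> scale, zero_point, input_bit_width = torch.tensor(0.01), torch.tensor(0.), torch.tensor(8.)
>>> x = torch.tensor([0.21, -0.62, 0.88])  # multiples of scale, i.e. already 8-bit quantized
>>> out, scale, zero_point, bit_width = trunc_quant(x, scale, zero_point, input_bit_width)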
brevitas.core.quant.int_base module
- class brevitas.core.quant.int_base.DecoupledIntQuant(narrow_range, signed, input_view_impl, float_to_int_impl=RoundSte(), tensor_clamp_impl=TensorClamp(), quant_delay_steps=0)
Bases: Module
ScriptModule that implements scaled, shifted, uniform integer quantization of an input tensor, according to an input pre-scale, scale, pre-zero-point, zero-point and bit-width.
- Parameters:
narrow_range (bool) – Flag that determines whether to restrict quantization to a narrow range or not.
signed (bool) – Flag that determines whether to quantize to a signed range or not.
float_to_int_impl (Module) – Module that performs the conversion from floating point to integer representation. Default: RoundSte()
tensor_clamp_impl (Module) – Module that performs clamping. Default: TensorClamp()
quant_delay_steps (int) – Number of training steps to delay quantization for. Default: 0
- Returns:
Quantized output in de-quantized format.
- Return type:
Tensor
Examples
>>> from brevitas.core.scaling import ConstScaling
>>> int_quant = DecoupledIntQuant(narrow_range=True, signed=True)
>>> scale, zero_point, bit_width = torch.tensor(0.01), torch.tensor(0.), torch.tensor(4.)
>>> pre_scale, pre_zero_point = torch.tensor(0.02), torch.tensor(0.)
>>> inp = torch.Tensor([0.042, -0.053, 0.31, -0.44])
>>> out = int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, inp)
>>> out
tensor([ 0.0200, -0.0300,  0.0700, -0.0700])
Note
Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.
- forward(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tensor
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class brevitas.core.quant.int_base.IntQuant(narrow_range, signed, input_view_impl, float_to_int_impl=RoundSte(), tensor_clamp_impl=TensorClamp(), quant_delay_steps=0)
Bases: Module
ScriptModule that implements scaled, shifted, uniform integer quantization of an input tensor, according to an input scale, zero-point and bit-width.
- Parameters:
narrow_range (bool) – Flag that determines whether to restrict quantization to a narrow range or not.
signed (bool) – Flag that determines whether to quantize to a signed range or not.
float_to_int_impl (Module) – Module that performs the conversion from floating point to integer representation. Default: RoundSte()
tensor_clamp_impl (Module) – Module that performs clamping. Default: TensorClamp()
quant_delay_steps (int) – Number of training steps to delay quantization for. Default: 0
- Returns:
Quantized output in de-quantized format.
- Return type:
Tensor
Examples
>>> from brevitas.core.scaling import ConstScaling
>>> int_quant = IntQuant(narrow_range=True, signed=True)
>>> scale, zero_point, bit_width = torch.tensor(0.01), torch.tensor(0.), torch.tensor(4.)
>>> inp = torch.Tensor([0.042, -0.053, 0.31, -0.44])
>>> out = int_quant(scale, zero_point, bit_width, inp)
>>> out
tensor([ 0.0400, -0.0500,  0.0700, -0.0700])
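With scale 0.01 and a narrow signed 4-bit range, the integer range is [-7, 7]: 0.042 and -0.053 round to integers 4 and -5 (0.04 and -0.05 after rescaling), while 0.31 and -0.44 saturate at ±0.07.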
Note
Maps to quant_type == QuantType.INT == ‘INT’ == ‘int’ in higher-level APIs.
Note
Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.
- forward(scale, zero_point, bit_width, x)
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tensor
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
brevitas.core.quant.ternary module
- class brevitas.core.quant.ternary.TernaryQuant(scaling_impl, threshold, quant_delay_steps=None)
Bases: Module
ScriptModule that implements scaled uniform ternary quantization of an input tensor. Quantization is performed with ternary_sign_ste().
- Parameters:
scaling_impl (Module) – Module that returns a scale factor.
threshold (float) – Ternarization threshold, relative to the scale factor.
quant_delay_steps (int) – Number of training steps to delay quantization for. Default: 0
- Returns:
Quantized output in de-quantized format, scale, zero-point, bit_width.
- Return type:
Tuple[Tensor, Tensor, Tensor, Tensor]
Examples
>>> from brevitas.core.scaling import ConstScaling
>>> ternary_quant = TernaryQuant(ConstScaling(1.0), 0.5)
>>> inp = torch.Tensor([0.04, -0.6, 3.3])
>>> out, scale, zero_point, bit_width = ternary_quant(inp)
>>> out
tensor([ 0., -1.,  1.])
>>> scale
tensor(1.)
>>> zero_point
tensor(0.)
>>> bit_width
tensor(2.)
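With scale 1.0 and threshold 0.5, inputs whose magnitude does not exceed 0.5 quantize to zero (0.04 → 0), while larger magnitudes map to ±scale (-0.6 → -1, 3.3 → 1).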
Note
Maps to quant_type == QuantType.TERNARY == ‘TERNARY’ == ‘ternary’ in higher-level APIs.
Note
Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tuple[Tensor, Tensor, Tensor, Tensor]
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.