brevitas.core.quant package#

Submodules#

brevitas.core.quant.binary module#

class brevitas.core.quant.binary.BinaryQuant(scaling_impl, signed=True, quant_delay_steps=0)[source]#

Bases: Module

ScriptModule that implements scaled uniform binary quantization of an input tensor. Quantization is performed with binary_sign_ste().

Parameters:
  • scaling_impl (Module) – Module that returns a scale factor.

  • signed (bool) – Flag that determines whether the quantized output is interpreted as signed. Default: True

  • quant_delay_steps (int) – Number of training steps to delay quantization for. Default: 0

Returns:

Quantized output in de-quantized format, scale, zero-point, bit_width.

Return type:

Tuple[Tensor, Tensor, Tensor, Tensor]

Examples

>>> import torch
>>> from brevitas.core.quant.binary import BinaryQuant
>>> from brevitas.core.scaling import ConstScaling
>>> binary_quant = BinaryQuant(ConstScaling(0.1))
>>> inp = torch.Tensor([0.04, -0.6, 3.3])
>>> out, scale, zero_point, bit_width = binary_quant(inp)
>>> out
tensor([ 0.1000, -0.1000,  0.1000])
>>> scale
tensor(0.1000)
>>> zero_point
tensor(0.)
>>> bit_width
tensor(1.)

Note

Maps to quant_type == QuantType.BINARY == ‘BINARY’ == ‘binary’ when applied to weights in higher-level APIs.

Note

Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.
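
The mapping itself is out = scale * sign(x), with binary_sign_ste() providing a straight-through gradient so that training can proceed through the non-differentiable sign. A minimal sketch of such an estimator (ToyBinarySignSte is a hypothetical stand-in, not the Brevitas implementation):

>>> import torch
>>> class ToyBinarySignSte(torch.autograd.Function):
...     @staticmethod
...     def forward(ctx, x):
...         # binarize: non-negative values map to +1, negative to -1
...         return torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x))
...     @staticmethod
...     def backward(ctx, grad_output):
...         return grad_output  # straight-through: pass the gradient unchanged
>>> x = torch.tensor([0.04, -0.6, 3.3], requires_grad=True)
>>> 0.1 * ToyBinarySignSte.apply(x)
tensor([ 0.1000, -0.1000,  0.1000], grad_fn=<MulBackward0>)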

forward(x)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tuple[Tensor, Tensor, Tensor, Tensor]

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class brevitas.core.quant.binary.ClampedBinaryQuant(scaling_impl, tensor_clamp_impl=TensorClamp(), quant_delay_steps=0)[source]#

Bases: Module

ScriptModule that implements scaled uniform binary quantization of an input tensor. Before quantization, the input tensor is clamped between (-scale, scale), which on the backward pass zeroes the gradients of inputs outside that range. Quantization is performed with binary_sign_ste().

Parameters:
  • scaling_impl (Module) – Module that returns a scale factor.

  • tensor_clamp_impl (Module) – Module that performs tensor-wise clamping. Default: TensorClamp()

  • quant_delay_steps (int) – Number of training steps to delay quantization for. Default: 0

Returns:

Quantized output in de-quantized format, scale, zero-point, bit_width.

Return type:

Tuple[Tensor, Tensor, Tensor, Tensor]

Examples

>>> import torch
>>> from brevitas.core.quant.binary import ClampedBinaryQuant
>>> from brevitas.core.scaling import ConstScaling
>>> binary_quant = ClampedBinaryQuant(ConstScaling(0.1))
>>> inp = torch.Tensor([0.04, -0.6, 3.3]).requires_grad_(True)
>>> out, scale, zero_point, bit_width = binary_quant(inp)
>>> out
tensor([ 0.1000, -0.1000,  0.1000], grad_fn=<MulBackward0>)
>>> out.backward(torch.Tensor([1.0, 1.0, 1.0]))
>>> inp.grad
tensor([0.1000, 0.0000, 0.0000])
>>> scale
tensor(0.1000)
>>> zero_point
tensor(0.)
>>> bit_width
tensor(1.)
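
In the example above, only 0.04 lies inside the clamping range (-0.1, 0.1), so it alone receives a gradient: the straight-through estimator passes the upstream gradient of 1.0 through the sign, and the multiplication by scale yields 0.1. The clamped entries get a gradient of 0.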

Note

Maps to quant_type == QuantType.BINARY == ‘BINARY’ == ‘binary’ when applied to activations in higher-level APIs.

Note

Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.

forward(x)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tuple[Tensor, Tensor, Tensor, Tensor]

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

brevitas.core.quant.delay module#

class brevitas.core.quant.delay.DelayWrapper(quant_delay_steps)[source]#

Bases: Module

forward(x, y)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
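
The class carries no docstring here. Going by the quant_delay_steps parameter documented on the quantizers in this package, a plausible reading is that forward(x, y) returns the floating-point input x for the first quant_delay_steps steps and the quantized value y afterwards. A minimal sketch of that assumed behavior (ToyDelay is a hypothetical stand-in, not the Brevitas implementation):

>>> import torch
>>> from torch import nn
>>> class ToyDelay(nn.Module):
...     def __init__(self, quant_delay_steps):
...         super().__init__()
...         self.steps_left = quant_delay_steps
...     def forward(self, x, y):
...         if self.steps_left > 0:
...             self.steps_left -= 1  # still delaying: pass the float value through
...             return x
...         return y  # delay elapsed: return the quantized value
>>> delay = ToyDelay(1)
>>> x, y = torch.tensor(0.042), torch.tensor(0.04)
>>> delay(x, y)
tensor(0.0420)
>>> delay(x, y)
tensor(0.0400)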

brevitas.core.quant.int module#

class brevitas.core.quant.int.DecoupledRescalingIntQuant(decoupled_int_quant, pre_scaling_impl, scaling_impl, int_scaling_impl, pre_zero_point_impl, zero_point_impl, bit_width_impl)[source]#

Bases: Module

forward(x)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
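
The class carries no docstring here. Going by the constructor arguments and by DecoupledIntQuant below, a plausible reading is that a pre-scale/pre-zero-point pair drives the rounding to integer while a separate scale/zero-point pair drives the de-quantization, with all statistics returned alongside the output. A hedged sketch of that assumed forward pass (attribute names and sub-module signatures are assumptions, not the actual implementation):

def decoupled_rescaling_forward(self, x):
    bit_width = self.bit_width_impl()
    int_threshold = self.int_scaling_impl(bit_width)
    # pre-scale/pre-zero-point: drive the rounding to integer
    pre_scale = self.pre_scaling_impl(x) / int_threshold
    pre_zero_point = self.pre_zero_point_impl(x, pre_scale, bit_width)
    # scale/zero-point: drive the de-quantization
    scale = self.scaling_impl(x) / int_threshold
    zero_point = self.zero_point_impl(x, scale, bit_width)
    y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
    return y, scale, zero_point, bit_width, pre_scale, pre_zero_point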

class brevitas.core.quant.int.DecoupledRescalingIntQuantWithInput(decoupled_int_quant, pre_scaling_impl, scaling_impl, int_scaling_impl, pre_zero_point_impl, zero_point_impl, bit_width_impl)[source]#

Bases: DecoupledRescalingIntQuant

forward(x, input_bit_width, input_is_signed)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class brevitas.core.quant.int.PrescaledRestrictIntQuant(int_quant, bit_width_impl)[source]#

Bases: Module

forward(x, scale)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tuple[Tensor, Tensor, Tensor, Tensor]

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class brevitas.core.quant.int.PrescaledRestrictIntQuantWithInputBitWidth(int_quant, bit_width_impl)[source]#

Bases: Module

ScriptModule that wraps around an integer quantization implementation like IntQuant. Zero-point is set to zero, scale is taken as input, bit-width is computed from an input bit-width.

Parameters:
  • int_quant (Module) – Module that implements integer quantization.

  • bit_width_impl (Module) – Module that takes the input bit-width in and returns the bit-width to be used for quantization.

Returns:

Quantized output in de-quantized format, scale, zero-point, bit_width.

Return type:

Tuple[Tensor, Tensor, Tensor, Tensor]

Examples

>>> import torch
>>> from brevitas.core.quant.int import PrescaledRestrictIntQuantWithInputBitWidth
>>> from brevitas.core.function_wrapper import Identity
>>> from brevitas.core.quant import IntQuant
>>> int_quant = IntQuant(narrow_range=True, signed=True, input_view_impl=Identity())
>>> int_quant_wrapper = PrescaledRestrictIntQuantWithInputBitWidth(int_quant, Identity())
>>> scale, input_bit_width = torch.tensor(0.01), torch.tensor(4.)
>>> inp = torch.Tensor([0.042, -0.053, 0.31, -0.44])
>>> out, scale, zero_point, bit_width = int_quant_wrapper(inp, scale, input_bit_width)
>>> out
tensor([ 0.0400, -0.0500,  0.0700, -0.0700])
>>> scale
tensor(0.0100)
>>> zero_point
tensor(0.)
>>> bit_width
tensor(4.)
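
In the example above, Identity() passes the 4-bit input bit-width straight through, and the narrow signed 4-bit integer range is [-7, 7]: round(0.042 / 0.01) = 4 and round(-0.053 / 0.01) = -5 are representable, while 31 and -44 are clamped to 7 and -7 before de-quantization.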

Note

Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.

forward(x, scale, input_bit_width)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tuple[Tensor, Tensor, Tensor, Tensor]

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class brevitas.core.quant.int.RescalingIntQuant(int_quant, scaling_impl, int_scaling_impl, zero_point_impl, bit_width_impl)[source]#

Bases: Module

ScriptModule that wraps around an integer quantization implementation like IntQuant. Scale, zero-point and bit-width are returned from their respective implementations and passed on to the integer quantization implementation.

Parameters:
  • int_quant (Module) – Module that implements integer quantization.

  • scaling_impl (Module) – Module that takes in the input to quantize and returns a scale factor, here interpreted as threshold on the floating-point range of quantization.

  • int_scaling_impl (Module) – Module that takes in a bit-width and returns an integer scale factor, here interpreted as threshold on the integer range of quantization.

  • zero_point_impl (Module) – Module that returns an integer zero-point.

  • bit_width_impl (Module) – Module that returns a bit-width.

Returns:

Quantized output in de-quantized format, scale, zero-point, bit_width.

Return type:

Tuple[Tensor, Tensor, Tensor, Tensor]

Examples

>>> import torch
>>> from brevitas.core.quant.int import RescalingIntQuant
>>> from brevitas.core.scaling import ConstScaling, IntScaling
>>> from brevitas.core.zero_point import ZeroZeroPoint
>>> from brevitas.core.quant import IntQuant
>>> from brevitas.core.function_wrapper import Identity
>>> from brevitas.core.bit_width import BitWidthConst
>>> int_quant_wrapper = RescalingIntQuant(
...                         IntQuant(narrow_range=True, signed=True, input_view_impl=Identity()),
...                         ConstScaling(0.1),
...                         IntScaling(signed=True, narrow_range=True),
...                         ZeroZeroPoint(),
...                         BitWidthConst(4))
>>> inp = torch.Tensor([0.042, -0.053, 0.31, -0.44])
>>> out, scale, zero_point, bit_width = int_quant_wrapper(inp)
>>> out
tensor([ 0.0429, -0.0571,  0.1000, -0.1000])
>>> scale
tensor(0.0143)
>>> zero_point
tensor(0.)
>>> bit_width
tensor(4.)

Note

scale = scaling_impl(x) / int_scaling_impl(bit_width)
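
For the example above, int_scaling_impl(bit_width) with signed=True and narrow_range=True at bit_width = 4 gives 2**(4 - 1) - 1 = 7, so scale = 0.1 / 7 ≈ 0.0143, matching the doctest output.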

Note

Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.

forward(x)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tuple[Tensor, Tensor, Tensor, Tensor]

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class brevitas.core.quant.int.TruncIntQuant(float_to_int_impl, bit_width_impl, quant_delay_steps=0)[source]#

Bases: Module

forward(x, scale, zero_point, input_bit_width)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tuple[Tensor, Tensor, Tensor, Tensor]

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
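
The class carries no docstring here. The name and the forward signature suggest it truncates an already-quantized input from input_bit_width down to the bit-width returned by bit_width_impl by dropping least-significant bits. A minimal sketch of that assumed arithmetic (trunc_sketch is hypothetical; floor stands in for float_to_int_impl):

>>> import torch
>>> def trunc_sketch(x, scale, zero_point, input_bit_width, output_bit_width):
...     y = x / scale + zero_point  # recover the integer representation
...     trunc_scale = 2.0 ** (input_bit_width - output_bit_width)
...     y = torch.floor(y / trunc_scale)  # drop the least-significant bits
...     return (y - zero_point) * scale
>>> trunc_sketch(torch.tensor(2.17), torch.tensor(0.01), torch.tensor(0.), 8, 4)
tensor(0.1300)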

brevitas.core.quant.int_base module#

class brevitas.core.quant.int_base.DecoupledIntQuant(narrow_range, signed, input_view_impl, float_to_int_impl=RoundSte(), tensor_clamp_impl=TensorClamp(), quant_delay_steps=0)[source]#

Bases: Module

ScriptModule that implements scaled, shifted, uniform integer quantization of an input tensor, according to an input pre-scale, scale, pre-zero-point, zero-point and bit-width.

Parameters:
  • narrow_range (bool) – Flag that determines whether to restrict quantization to a narrow range or not.

  • signed (bool) – Flag that determines whether to quantize to a signed range or not.

  • input_view_impl (Module) – Module that defines how the input tensor is viewed before quantization.

  • float_to_int_impl (Module) – Module that performs the conversion from floating point to integer representation. Default: RoundSte()

  • tensor_clamp_impl (Module) – Module that performs clamping. Default: TensorClamp()

  • quant_delay_steps (int) – Number of training steps to delay quantization for. Default: 0

Returns:

Quantized output in de-quantized format.

Return type:

Tensor

Examples

>>> import torch
>>> from brevitas.core.quant.int_base import DecoupledIntQuant
>>> from brevitas.core.function_wrapper import Identity
>>> int_quant = DecoupledIntQuant(narrow_range=True, signed=True, input_view_impl=Identity())
>>> scale, zero_point, bit_width = torch.tensor(0.01), torch.tensor(0.), torch.tensor(4.)
>>> pre_scale, pre_zero_point = torch.tensor(0.02), torch.tensor(0.)
>>> inp = torch.Tensor([0.042, -0.053, 0.31, -0.44])
>>> out = int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, inp)
>>> out
tensor([ 0.0200, -0.0300,  0.0700, -0.0700])
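
As the example shows, integers are obtained with the pre-scale (round(0.042 / 0.02) = 2, round(0.31 / 0.02) = 16 clamped to the narrow 4-bit maximum of 7) but de-quantized with the scale (2 * 0.01 = 0.02, 7 * 0.01 = 0.07).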

Note

Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.

forward(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

max_int(bit_width)[source]#

min_int(bit_width)[source]#

to_int(pre_scale, pre_zero_point, bit_width, x)[source]#

Return type:

Tensor

class brevitas.core.quant.int_base.IntQuant(narrow_range, signed, input_view_impl, float_to_int_impl=RoundSte(), tensor_clamp_impl=TensorClamp(), quant_delay_steps=0)[source]#

Bases: Module

ScriptModule that implements scaled, shifted, uniform integer quantization of an input tensor, according to an input scale, zero-point and bit-width.

Parameters:
  • narrow_range (bool) – Flag that determines whether to restrict quantization to a narrow range or not.

  • signed (bool) – Flag that determines whether to quantize to a signed range or not.

  • input_view_impl (Module) – Module that defines how the input tensor is viewed before quantization.

  • float_to_int_impl (Module) – Module that performs the conversion from floating point to integer representation. Default: RoundSte()

  • tensor_clamp_impl (Module) – Module that performs clamping. Default: TensorClamp()

  • quant_delay_steps (int) – Number of training steps to delay quantization for. Default: 0

Returns:

Quantized output in de-quantized format.

Return type:

Tensor

Examples

>>> import torch
>>> from brevitas.core.quant.int_base import IntQuant
>>> from brevitas.core.function_wrapper import Identity
>>> int_quant = IntQuant(narrow_range=True, signed=True, input_view_impl=Identity())
>>> scale, zero_point, bit_width = torch.tensor(0.01), torch.tensor(0.), torch.tensor(4.)
>>> inp = torch.Tensor([0.042, -0.053, 0.31, -0.44])
>>> out = int_quant(scale, zero_point, bit_width, inp)
>>> out
tensor([ 0.0400, -0.0500,  0.0700, -0.0700])
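
With scale = 0.01 and the narrow signed 4-bit range [-7, 7], round(0.042 / 0.01) = 4 and round(-0.053 / 0.01) = -5 map back to 0.04 and -0.05, while 31 and -44 are clamped to 7 and -7, giving 0.07 and -0.07.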

Note

Maps to quant_type == QuantType.INT == ‘INT’ == ‘int’ in higher-level APIs.

Note

Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.

forward(scale, zero_point, bit_width, x)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

max_int(bit_width)[source]#

min_int(bit_width)[source]#

to_int(scale, zero_point, bit_width, x)[source]#

Return type:

Tensor

brevitas.core.quant.ternary module#

class brevitas.core.quant.ternary.TernaryQuant(scaling_impl, threshold, quant_delay_steps=None)[source]#

Bases: Module

ScriptModule that implements scaled uniform ternary quantization of an input tensor. Quantization is performed with ternary_sign_ste().

Parameters:
  • scaling_impl (Module) – Module that returns a scale factor.

  • threshold (float) – Ternarization threshold w.r.t. the scale factor.

  • quant_delay_steps (int) – Number of training steps to delay quantization for. Default: None

Returns:

Quantized output in de-quantized format, scale, zero-point, bit_width.

Return type:

Tuple[Tensor, Tensor, Tensor, Tensor]

Examples

>>> import torch
>>> from brevitas.core.quant.ternary import TernaryQuant
>>> from brevitas.core.scaling import ConstScaling
>>> ternary_quant = TernaryQuant(ConstScaling(1.0), 0.5)
>>> inp = torch.Tensor([0.04, -0.6, 3.3])
>>> out, scale, zero_point, bit_width = ternary_quant(inp)
>>> out
tensor([ 0., -1.,  1.])
>>> scale
tensor(1.)
>>> zero_point
tensor(0.)
>>> bit_width
tensor(2.)
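
Inputs with magnitude below threshold * scale = 0.5 quantize to zero (0.04 above), while the rest map to ±scale: -0.6 to -1.0 and 3.3 to 1.0.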

Note

Maps to quant_type == QuantType.TERNARY == ‘TERNARY’ == ‘ternary’ in higher-level APIs.

Note

Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.

forward(x)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tuple[Tensor, Tensor, Tensor, Tensor]

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Module contents#