Anatomy of a Quantizer#

What’s in a Quantizer?#

In a broad sense, a quantizer is anything that implements a quantization technique, and the flexibility of Brevitas means that there are different ways to do so.
However, to keep our terminology straight, we reserve the term quantizer for the specific way of implementing quantization that Brevitas prefers and adopts by default. That is, a quantizer is a subclass of brevitas.inject.ExtendedInjector that carries a tensor_quant attribute, which resolves to an instance of a torch Module implementing the quantization logic.

We have seen in previous tutorials quantizers being imported from brevitas.quant and passed on to quantized layers. We can easily verify what we just said on one of them:

[1]:
from brevitas.inject import ExtendedInjector
from brevitas.quant.scaled_int import Int8ActPerTensorFloat

issubclass(Int8ActPerTensorFloat, ExtendedInjector)
[1]:
True
[2]:
Int8ActPerTensorFloat.tensor_quant
[2]:
RescalingIntQuant(
  (int_quant): IntQuant(
    (float_to_int_impl): RoundSte()
    (tensor_clamp_impl): TensorClamp()
    (delay_wrapper): DelayWrapper(
      (delay_impl): _NoDelay()
    )
    (input_view_impl): Identity()
  )
  (scaling_impl): ParameterFromRuntimeStatsScaling(
    (stats_input_view_shape_impl): OverTensorView()
    (stats): _Stats(
      (stats_impl): AbsPercentile()
    )
    (restrict_scaling): _RestrictValue(
      (restrict_value_impl): FloatRestrictValue()
    )
    (clamp_scaling): _ClampValue(
      (clamp_min_ste): ScalarClampMinSte()
    )
    (restrict_inplace_preprocess): Identity()
    (restrict_preprocess): Identity()
  )
  (int_scaling_impl): IntScaling()
  (zero_point_impl): ZeroZeroPoint(
    (zero_point): StatelessBuffer()
  )
  (msb_clamp_bit_width_impl): BitWidthConst(
    (bit_width): StatelessBuffer()
  )
)

Note how we said subclass and not instance. To understand why that’s the case, we have to understand what an ExtendedInjector is and why it’s used in the first place.

Quantization with auto-wiring Dependency Injection#

Pytorch has exploded in popularity thanks to its straightforward numpy-like define-by-run execution model. However, when it comes to applying quantization, this style of programming poses a problem.

Many quantization methods make decisions based on the (in Pytorch terms) state_dict of the original floating-point model that is being finetuned with quantization. However, when we instantiate a model in Pytorch, we can’t know on the spot whether a state_dict is going to be loaded a few lines of code later or not. Yet, because Pytorch is define-by-run, we need our model to work consistently both before and after a state_dict is possibly loaded. In a traditional scenario that wouldn’t pose a problem. With quantization in the loop, however, the way a quantizer is defined might change before and after a pretrained state_dict is loaded.

That means that we need a way to define our quantized model such that it can react appropriately in case the state_dict changes. In a Python-only world that wouldn’t be too hard. However, in order to mitigate the performance impact of quantization-aware training, Brevitas makes extensive use of Pytorch’s JIT compiler for a custom subset of Python, TorchScript. That means that in most scenarios, when a state_dict is loaded, we need to recompile parts of the model. Because compilation is in general a lossy process, a TorchScript component cannot simply re-compile itself based on new input information.

We then need a way to declare a quantization method such that it can be re-initialized and JIT compiled any time the state_dict changes. Because we want to support arbitrarily complex user-defined quantization algorithms, this method has to be generic, i.e. it cannot depend on the specifics of the quantization algorithm implemented.

Implementing a quantizer with an ExtendedInjector is a way to do so. Specifically, an ExtendedInjector extends an Injector from an older version (0.2.1) of the excellent dependency-injection library dependencies with support for a couple of extra features that are specific to Brevitas’ needs.

An Injector (and an ExtendedInjector) lets us take what might be a very complicated graph of intertwined objects and turn it into a flat list of variables that are capable of auto-assembly by matching variable names to argument names. This technique typically goes under the name of auto-wiring dependency injection.

In the context of Brevitas, the goal is to gather all the modules and hyperparameters that contribute to a quantization implementation such that they can be re-assembled automatically on demand. What comes out of this process is a tensor_quant object.
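
To make the name-matching concrete, here is a toy sketch (the class names are made up for illustration and are not part of Brevitas) of how attributes of an ExtendedInjector are expected to line up with constructor arguments:

class Scale:
    def __init__(self, value):
        self.value = value

class Quant:
    def __init__(self, scale):
        self.scale = scale

class ToyInjector(ExtendedInjector):
    quant = Quant   # Quant needs a 'scale' argument...
    scale = Scale   # ...which resolves to an instance of Scale...
    value = 0.5     # ...whose 'value' argument resolves to this number

toy_quant = ToyInjector.quant  # auto-wired as Quant(scale=Scale(value=0.5))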

A Practical Example: Binary Quantization#

To make things practical, let’s look at how we can implement a simple variant of binary quantization. All the components typically used to implement quantization can be found under brevitas.core. As mentioned before, Brevitas makes heavy use of TorchScript. In particular, all the components found under brevitas.core are implemented as ScriptModules that can be assembled together. The core ScriptModule that implements binarization can be found under brevitas.core.quant:

[3]:
import inspect
from IPython.display import Markdown, display

# helpers
def assert_with_message(condition):
    assert condition
    print(condition)

def pretty_print_source(source):
    display(Markdown('```python\n' + source + '\n```'))
[4]:
from brevitas.core.quant import BinaryQuant

source = inspect.getsource(BinaryQuant)
pretty_print_source(source)
class BinaryQuant(brevitas.jit.ScriptModule):
    """
    ScriptModule that implements scaled uniform binary quantization of an input tensor.
    Quantization is performed with :func:`~brevitas.function.ops_ste.binary_sign_ste`.

    Args:
        scaling_impl (Module): Module that returns a scale factor.
        quant_delay_steps (int): Number of training steps to delay quantization for. Default: 0

    Returns:
        Tuple[Tensor, Tensor, Tensor, Tensor]: Quantized output in de-quantized format, scale, zero-point, bit_width.

    Examples:
        >>> from brevitas.core.scaling import ConstScaling
        >>> binary_quant = BinaryQuant(ConstScaling(0.1))
        >>> inp = torch.Tensor([0.04, -0.6, 3.3])
        >>> out, scale, zero_point, bit_width = binary_quant(inp)
        >>> out
        tensor([ 0.1000, -0.1000,  0.1000])
        >>> scale
        tensor(0.1000)
        >>> zero_point
        tensor(0.)
        >>> bit_width
        tensor(1.)

    Note:
        Maps to quant_type == QuantType.BINARY == 'BINARY' == 'binary' when applied to weights in higher-level APIs.

    Note:
        Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.
    """

    def __init__(self, scaling_impl: Module, signed: bool = True, quant_delay_steps: int = 0):
        super(BinaryQuant, self).__init__()
        assert signed, "Unsigned binary quant not supported"
        self.scaling_impl = scaling_impl
        self.bit_width = BitWidthConst(1)
        self.zero_point = StatelessBuffer(torch.tensor(0.0))
        self.delay_wrapper = DelayWrapper(quant_delay_steps)

    @brevitas.jit.script_method
    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        scale = self.scaling_impl(x)
        y = binary_sign_ste(x) * scale
        y = self.delay_wrapper(x, y)
        return y, scale, self.zero_point(), self.bit_width()

The implementation is quite simple. Apart from quant_delay_steps, which allows quantization to be delayed by a certain number of training steps (default = 0), the only other argument that BinaryQuant accepts is an implementation to compute the scale factor. bit_width is fixed to 1 and zero-point is fixed to 0.

We pick as scale factor implementation a ScriptModule called ParameterScaling, which implements a learned parameter with user-defined initialization. It can be found under brevitas.core.scaling:

[5]:
from brevitas.core.scaling import ParameterScaling
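
Before assembling anything, we can sanity-check ParameterScaling on its own; under the assumption that it simply returns its learned value and ignores the content of its placeholder input, we would expect something close to the following:

import torch

scaling = ParameterScaling(scaling_init=0.1)
# the input acts only as a placeholder; the returned scale is the learned parameter
scaling(torch.randn(4, 4))  # expected: roughly tensor(0.1000, grad_fn=...)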

Manual Binary Quantization#

As a first step, we simply instantiate BinaryQuant with ParameterScaling, using a scaling_init equal to 0.1, and we call it on a random floating-point input tensor:

[6]:
import torch

# set seed for notebook
torch.manual_seed(0)

manual_tensor_quant = BinaryQuant(scaling_impl=ParameterScaling(scaling_init=0.1))
manual_tensor_quant(torch.randn(4, 4))
[6]:
(tensor([[-0.1000, -0.1000, -0.1000, -0.1000],
         [ 0.1000,  0.1000, -0.1000, -0.1000],
         [ 0.1000, -0.1000,  0.1000,  0.1000],
         [ 0.1000,  0.1000,  0.1000, -0.1000]], grad_fn=<MulBackward0>),
 tensor(0.1000, grad_fn=<AbsBinarySignGradFnBackward>),
 tensor(0.),
 tensor(1.))

Nothing too surprising here: as expected, the tensor is binarized with the scale factor we defined. Note, however, how manual_tensor_quant returns a tuple and not a QuantTensor. This is because support for custom data structures in TorchScript is still quite limited, so QuantTensors are allocated only in Python-world abstractions.
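
Since the return value is the plain tuple documented in the BinaryQuant docstring, we can unpack it directly:

out, scale, zero_point, bit_width = manual_tensor_quant(torch.randn(4, 4))
# the de-quantized output only takes the values -scale and +scale
print(out.unique(), scale, zero_point, bit_width)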

Binary Quantization with an ExtendedInjector#

Let’s now declare tensor_quant through an ExtendedInjector:

[7]:
from brevitas.inject import ExtendedInjector

class MyBinaryQuantizer(ExtendedInjector):
    tensor_quant = BinaryQuant
    scaling_impl=ParameterScaling
    scaling_init=0.1

inj_tensor_quant = MyBinaryQuantizer.tensor_quant
inj_tensor_quant(torch.randn(4, 4))
[7]:
(tensor([[-0.1000, -0.1000,  0.1000,  0.1000],
         [ 0.1000, -0.1000, -0.1000,  0.1000],
         [ 0.1000, -0.1000, -0.1000,  0.1000],
         [ 0.1000,  0.1000,  0.1000, -0.1000]], grad_fn=<MulBackward0>),
 tensor(0.1000, grad_fn=<AbsBinarySignGradFnBackward>),
 tensor(0.),
 tensor(1.))

Any time MyBinaryQuantizer.tensor_quant is accessed, a new instance of BinaryQuant is created. Note how the attributes of MyBinaryQuantizer are named to match the constructor arguments of the components they are injected into, except for tensor_quant, which is what we are interested in retrieving from the outside.
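
We can check this directly: since dependency injection re-assembles the object graph at every access, two consecutive accesses should give two distinct instances:

assert_with_message(MyBinaryQuantizer.tensor_quant is not MyBinaryQuantizer.tensor_quant)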

Inheritance and Composition of Quantizers#

Expressing a quantizer through a Python class also means that we can leverage both inheritance and composition. For example, we can inherit from MyBinaryQuantizer and override scaling_init with a new value:

[8]:
class MyChildBinaryQuantizer(MyBinaryQuantizer):
    scaling_init=1.0

child_inj_tensor_quant = MyChildBinaryQuantizer.tensor_quant
child_inj_tensor_quant(torch.randn(4, 4))
[8]:
(tensor([[-1.,  1., -1.,  1.],
         [ 1.,  1.,  1.,  1.],
         [-1.,  1., -1.,  1.],
         [ 1.,  1., -1., -1.]], grad_fn=<MulBackward0>),
 tensor(1., grad_fn=<AbsBinarySignGradFnBackward>),
 tensor(0.),
 tensor(1.))

Or we can leverage composition by assembling together various classes containing different pieces of a quantizer:

[9]:
class MyBinaryImpl(ExtendedInjector):
    tensor_quant = BinaryQuant

class MyScalingImpl(ExtendedInjector):
    scaling_impl=ParameterScaling
    scaling_init=0.1

class MyComposedBinaryQuantizer(MyBinaryImpl, MyScalingImpl):
    pass

comp_inj_tensor_quant = MyComposedBinaryQuantizer.tensor_quant
comp_inj_tensor_quant(torch.randn(4, 4))
[9]:
(tensor([[-0.1000,  0.1000,  0.1000,  0.1000],
         [-0.1000, -0.1000,  0.1000, -0.1000],
         [-0.1000,  0.1000,  0.1000,  0.1000],
         [-0.1000, -0.1000,  0.1000, -0.1000]], grad_fn=<MulBackward0>),
 tensor(0.1000, grad_fn=<AbsBinarySignGradFnBackward>),
 tensor(0.),
 tensor(1.))

Interfacing a Quantizer with a Quantized Layer#

Before we can pass the quantizer to a quantized layer such as QuantConv2d, we need to define one last component: a proxy. A proxy (found under brevitas.proxy) is an nn.Module that serves as an interface between a quantizer and a quantized layer.

While a quantizer lives mostly in JIT-land, a proxy lives mostly in Python-land, and as such can afford much more flexibility. Proxies take care of returning a QuantTensor and of re-initializing the quantizer whenever a new state_dict is loaded.

Proxies are specific to the kind of tensor being quantized, as in weights vs biases vs activations. For convenience, they are declared as part of the quantizer itself under the attribute proxy_class. For example, for weights we can use WeightQuantProxyFromInjector:

[10]:
from brevitas.proxy import WeightQuantProxyFromInjector

class MyBinaryWeightQuantizer(MyBinaryQuantizer):
    proxy_class = WeightQuantProxyFromInjector

We can now use MyBinaryWeightQuantizer as the weight quantizer of a layer:

[11]:
from brevitas.nn import QuantConv2d

binary_weight_quant_conv = QuantConv2d(3, 2, (3,3), weight_quant=MyBinaryWeightQuantizer)
try:
    quant_weight = binary_weight_quant_conv.quant_weight()
except TypeError:
    pass

Note however that we cannot compute the quantized weight, as the signed attribute is None.

signed is one of those attributes that, in the case of binary quantization, have to be explicitly defined by the user. The idea is that it informs the proxy on whether the values generated by our quantizer should be considered signed or not. We can do so by simply setting it in the quantizer:

[12]:
class MySignedBinaryWeightQuantizer(MyBinaryWeightQuantizer):
    signed = True

binary_weight_quant_conv = QuantConv2d(3, 2, (3,3), weight_quant=MySignedBinaryWeightQuantizer)
signed_quant_weight = binary_weight_quant_conv.quant_weight()
signed_quant_weight
[12]:
IntQuantTensor(value=tensor([[[[ 0.1000,  0.1000, -0.1000],
          [ 0.1000, -0.1000, -0.1000],
          [-0.1000, -0.1000, -0.1000]],

         [[-0.1000, -0.1000,  0.1000],
          [-0.1000, -0.1000,  0.1000],
          [-0.1000,  0.1000, -0.1000]],

         [[ 0.1000,  0.1000,  0.1000],
          [ 0.1000,  0.1000, -0.1000],
          [-0.1000, -0.1000,  0.1000]]],


        [[[-0.1000, -0.1000, -0.1000],
          [-0.1000, -0.1000, -0.1000],
          [ 0.1000,  0.1000, -0.1000]],

         [[ 0.1000, -0.1000, -0.1000],
          [-0.1000, -0.1000, -0.1000],
          [-0.1000, -0.1000, -0.1000]],

         [[ 0.1000, -0.1000,  0.1000],
          [-0.1000, -0.1000,  0.1000],
          [-0.1000,  0.1000,  0.1000]]]], grad_fn=<MulBackward0>), scale=tensor(0.1000, grad_fn=<AbsBinarySignGradFnBackward>), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))
[13]:
assert_with_message(signed_quant_weight.is_valid)
True

And now the quant weights are valid.

When we want to add or override a single attribute of a quantizer passed to a layer, defining a whole new quantizer can be too verbose. There is a simpler syntax to achieve the same goal. Let’s say we want to add the signed attribute to MyBinaryWeightQuantizer, as we just did. We could have also simply done the following:

[14]:
small_scale_quant_conv = QuantConv2d(3, 2, (3,3), weight_quant=MyBinaryWeightQuantizer, weight_signed=True)
small_scale_quant_conv.quant_weight()
[14]:
IntQuantTensor(value=tensor([[[[-0.1000, -0.1000,  0.1000],
          [-0.1000,  0.1000, -0.1000],
          [ 0.1000,  0.1000,  0.1000]],

         [[ 0.1000, -0.1000, -0.1000],
          [-0.1000,  0.1000,  0.1000],
          [ 0.1000, -0.1000,  0.1000]],

         [[-0.1000, -0.1000, -0.1000],
          [ 0.1000, -0.1000, -0.1000],
          [ 0.1000,  0.1000, -0.1000]]],


        [[[-0.1000, -0.1000,  0.1000],
          [-0.1000,  0.1000,  0.1000],
          [ 0.1000, -0.1000, -0.1000]],

         [[ 0.1000, -0.1000, -0.1000],
          [ 0.1000, -0.1000, -0.1000],
          [ 0.1000, -0.1000, -0.1000]],

         [[ 0.1000,  0.1000,  0.1000],
          [-0.1000,  0.1000, -0.1000],
          [-0.1000, -0.1000, -0.1000]]]], grad_fn=<MulBackward0>), scale=tensor(0.1000, grad_fn=<AbsBinarySignGradFnBackward>), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))

What we did was to take the name of the attribute signed, add the prefix weight_, and pass it as a keyword argument to QuantConv2d. What happens in the background is that the keyword arguments prefixed with weight_ are set as attributes of weight_quant, possibly overriding any pre-existing value. The same principle applies to input_, output_ and bias_.

This is the reason why, as mentioned in the first tutorial, quantized layers can accept arbitrary keyword arguments: it’s really just a way to support different styles of syntax when defining a quantizer.
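
As a hedged sketch of the same mechanism applied to an attribute other than signed, we could override the scaling_init of MySignedBinaryWeightQuantizer for a single layer (the expected scale below assumes the override behaves as just described):

override_scale_conv = QuantConv2d(
    3, 2, (3, 3), weight_quant=MySignedBinaryWeightQuantizer, weight_scaling_init=0.5)
# the prefixed keyword argument overrides scaling_init=0.1 defined in the quantizer
override_scale_conv.weight_quant.scale()  # expected: roughly tensor(0.5000, ...)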

Passing a custom quantizer to QuantIdentity#

We can do a similar thing with quantized activations:

[15]:
from brevitas.proxy import ActQuantProxyFromInjector
from brevitas.nn import QuantIdentity

class MySignedBinaryActQuantizer(MyBinaryQuantizer):
    proxy_class = ActQuantProxyFromInjector
    signed = True

binary_relu = QuantIdentity(act_quant=MySignedBinaryActQuantizer, return_quant_tensor=True)
binary_relu(torch.randn(4, 4))
[15]:
IntQuantTensor(value=tensor([[-0.1000,  0.1000, -0.1000, -0.1000],
        [ 0.1000,  0.1000,  0.1000,  0.1000],
        [-0.1000, -0.1000,  0.1000, -0.1000],
        [-0.1000,  0.1000, -0.1000,  0.1000]], grad_fn=<MulBackward0>), scale=tensor(0.1000, grad_fn=<AbsBinarySignGradFnBackward>), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))

So there isn’t really much difference between a quantizer for weights and a quantizer for activations: they are just wrapped by different proxies. Also, with activations a prefix is not required when passing keyword arguments. For example, we can override the existing scaling_init defined in MyBinaryQuantizer with a new value passed in as a keyword argument:

[16]:
small_scale_binary_identity = QuantIdentity(
    act_quant=MySignedBinaryActQuantizer, scaling_init=0.001, return_quant_tensor=True)
small_scale_binary_identity(torch.randn(4, 4))
[16]:
IntQuantTensor(value=tensor([[ 0.0010, -0.0010, -0.0010,  0.0010],
        [ 0.0010,  0.0010,  0.0010,  0.0010],
        [ 0.0010, -0.0010,  0.0010,  0.0010],
        [ 0.0010, -0.0010, -0.0010, -0.0010]], grad_fn=<MulBackward0>), scale=tensor(0.0010, grad_fn=<AbsBinarySignGradFnBackward>), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))

A Custom Quantizer initialized with Weight Statistics#

So far we have seen use cases where an ExtendedInjector provides, at best, a different kind of syntax to define a quantizer, without any other particular advantage. Let’s now make things a bit more complicated to show the sort of situation where it really shines.

Let’s say we want to define a binary weight quantizer where scaling_impl is still ParameterScaling. However, instead of being user-defined, we want scaling_init to be the maximum absolute value found in the weight tensor of the quantized layer. To support this sort of use case, where the quantizer depends on the layer, a quantized layer automatically passes itself to all its quantizers under the name module. With only a few lines of code, then, we can achieve our goal:

[17]:
from brevitas.inject import value

class ParamFromMaxWeightQuantizer(MySignedBinaryWeightQuantizer):

    @value
    def scaling_init(module):
        return module.weight.abs().max()

Note how we are leveraging the @value decorator to define a function that is executed at dependency-injection (DI) time. This kind of behaviour is similar in spirit to defining a @property instead of an attribute, with the difference that a @value function can depend on other attributes of the Injector, which are automatically passed in as arguments of the function during DI.
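
As an aside, a @value function can also mix module with other attributes of the same quantizer, which are injected by name in exactly the same way. The quantizer below is a hypothetical illustration and is not used in the rest of the tutorial:

class HalvedMaxWeightQuantizer(MySignedBinaryWeightQuantizer):
    shrink_factor = 0.5

    @value
    def scaling_init(module, shrink_factor):
        # both 'module' and 'shrink_factor' are resolved and passed in at DI time
        return shrink_factor * module.weight.abs().max()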

Let’s now pass ParamFromMaxWeightQuantizer to a QuantConv2d and retrieve its quantized weights:

[18]:
param_from_max_quant_conv = QuantConv2d(3, 2, (3, 3), weight_quant=ParamFromMaxWeightQuantizer)
param_from_max_quant_conv.quant_weight()
[18]:
IntQuantTensor(value=tensor([[[[ 0.1820, -0.1820, -0.1820],
          [ 0.1820,  0.1820,  0.1820],
          [-0.1820, -0.1820, -0.1820]],

         [[ 0.1820, -0.1820, -0.1820],
          [ 0.1820, -0.1820, -0.1820],
          [ 0.1820,  0.1820, -0.1820]],

         [[-0.1820,  0.1820,  0.1820],
          [ 0.1820, -0.1820,  0.1820],
          [-0.1820, -0.1820, -0.1820]]],


        [[[ 0.1820,  0.1820, -0.1820],
          [-0.1820, -0.1820,  0.1820],
          [-0.1820,  0.1820, -0.1820]],

         [[-0.1820,  0.1820, -0.1820],
          [ 0.1820,  0.1820, -0.1820],
          [ 0.1820,  0.1820,  0.1820]],

         [[-0.1820, -0.1820, -0.1820],
          [ 0.1820, -0.1820,  0.1820],
          [ 0.1820, -0.1820,  0.1820]]]], grad_fn=<MulBackward0>), scale=tensor(0.1820, grad_fn=<AbsBinarySignGradFnBackward>), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))

Indeed we can verify that weight_quant.scale() is equal to weight.abs().max():

[19]:
assert_with_message((param_from_max_quant_conv.weight_quant.scale() == param_from_max_quant_conv.weight.abs().max()).item())
True

Let’s say now that we want to load a pretrained floating-point weight tensor on top of our quantized model. We simulate this scenario by defining a separate nn.Conv2d layer with the same weight shape:

[20]:
from torch import nn

float_conv = nn.Conv2d(3, 2, (3, 3))
float_conv.weight.abs().max()
[20]:
tensor(0.1924, grad_fn=<MaxBackward1>)

and then we load it on top of param_from_max_quant_conv:

[21]:
param_from_max_quant_conv.load_state_dict(float_conv.state_dict())
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[21], line 1
----> 1 param_from_max_quant_conv.load_state_dict(float_conv.state_dict())

File /proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/torch/nn/modules/module.py:1671, in Module.load_state_dict(self, state_dict, strict)
   1666         error_msgs.insert(
   1667             0, 'Missing key(s) in state_dict: {}. '.format(
   1668                 ', '.join('"{}"'.format(k) for k in missing_keys)))
   1670 if len(error_msgs) > 0:
-> 1671     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   1672                        self.__class__.__name__, "\n\t".join(error_msgs)))
   1673 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for QuantConv2d:
        Missing key(s) in state_dict: "weight_quant.tensor_quant.scaling_impl.value".

Ouch, we get an error. This is because ParameterScaling contains a learned torch.nn.Parameter, and Pytorch expects all learned parameters of a model to be contained in the state_dict that is being loaded. We can work around the issue either by setting the IGNORE_MISSING_KEYS config flag in Brevitas, or by passing strict=False to load_state_dict. We go with the former, as setting strict=False is too forgiving of other kinds of problems:

[22]:
from brevitas import config
config.IGNORE_MISSING_KEYS = True

param_from_max_quant_conv.load_state_dict(float_conv.state_dict())
[22]:
<All keys matched successfully>

Note that we could have also achieved the same goal by setting the env variable BREVITAS_IGNORE_MISSING_KEYS=1.

And now if we take a look at the quantized weights again:

[23]:
param_from_max_quant_conv.quant_weight()
[23]:
IntQuantTensor(value=tensor([[[[ 0.1924,  0.1924, -0.1924],
          [ 0.1924,  0.1924,  0.1924],
          [ 0.1924,  0.1924,  0.1924]],

         [[-0.1924, -0.1924, -0.1924],
          [ 0.1924,  0.1924, -0.1924],
          [ 0.1924,  0.1924,  0.1924]],

         [[ 0.1924,  0.1924, -0.1924],
          [-0.1924,  0.1924,  0.1924],
          [-0.1924,  0.1924,  0.1924]]],


        [[[ 0.1924, -0.1924,  0.1924],
          [-0.1924,  0.1924, -0.1924],
          [ 0.1924,  0.1924,  0.1924]],

         [[-0.1924,  0.1924, -0.1924],
          [ 0.1924, -0.1924, -0.1924],
          [-0.1924,  0.1924,  0.1924]],

         [[ 0.1924,  0.1924, -0.1924],
          [-0.1924, -0.1924, -0.1924],
          [ 0.1924, -0.1924, -0.1924]]]], grad_fn=<MulBackward0>), scale=tensor(0.1924, grad_fn=<AbsBinarySignGradFnBackward>), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))

We see that, as expected, the scale factor has been updated to the new weight.abs().max().

What happens internally is that after load_state_dict is called on the layer, ParamFromMaxWeightQuantizer.tensor_quant gets called again to re-initialize BinaryQuant, and in turn ParameterScaling is re-initialized with a new scaling_init value computed based on the updated module.weight tensor. This whole process wouldn’t have been possible without an ExtendedInjector behind it.
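
We can double-check this against the floating-point layer we just loaded from:

assert_with_message(
    (param_from_max_quant_conv.weight_quant.scale() == float_conv.weight.abs().max()).item())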

Sharing a Quantizer#

There are two ways to share a quantizer between multiple layers, with important differences.

The first one, which we have seen so far, is to simply pass the same ExtendedInjector to multiple layers. What that does is share the same quantization strategy among different layers. Each layer still gets its own instance of the quantization implementation.

[24]:
quant_conv1 = QuantConv2d(3, 2, (3, 3), weight_quant=MySignedBinaryWeightQuantizer)
quant_conv2 = QuantConv2d(3, 2, (3, 3), weight_quant=MySignedBinaryWeightQuantizer)

assert_with_message(quant_conv1.weight_quant is not quant_conv2.weight_quant)
True

Sharing a proxy#

The second one, which we are introducing now, allows the same quantization instance to be shared among multiple layers. This is done by simply sharing the proxy wrapping it. This can be useful in those scenarios where, for example, we want different layers to share the same scale factor. The syntax goes as follows:

[25]:
quant_conv1 = QuantConv2d(3, 2, (3, 3), weight_quant=MySignedBinaryWeightQuantizer)
quant_conv2 = QuantConv2d(3, 2, (3, 3), weight_quant=quant_conv1.weight_quant)

assert_with_message(quant_conv1.weight_quant is quant_conv2.weight_quant)
True
[26]:
assert_with_message((quant_conv1.weight_quant.scale() == quant_conv2.weight_quant.scale()).item())
True

What happens in the background is that the weight quantizer now has access to both quant_conv1 and quant_conv2. So let’s say we want to build a quantizer similar to ParamFromMaxWeightQuantizer, but in this case we want the scale factor to be initialized with the mean absolute value computed across both weight tensors. When a quantizer has access to multiple parent modules, they are passed in at dependency-injection time as a tuple, under the same name module as before. So we can do the following:

[27]:
class SharedParamFromMeanWeightQuantizer(MySignedBinaryWeightQuantizer):

    @value
    def scaling_init(module):
        if isinstance(module, tuple):
            return torch.cat((module[0].weight.view(-1), module[1].weight.view(-1))).abs().mean()
        else:
            return module.weight.abs().mean()

quant_conv1 = QuantConv2d(3, 2, (3, 3), weight_quant=SharedParamFromMeanWeightQuantizer)
old_quant_conv1_scale = quant_conv1.weight_quant.scale()
quant_conv2 = QuantConv2d(3, 2, (3, 3), weight_quant=quant_conv1.weight_quant)
new_quant_conv1_scale = quant_conv1.weight_quant.scale()

assert_with_message(not (old_quant_conv1_scale == new_quant_conv1_scale).item())
True
[28]:
assert_with_message((new_quant_conv1_scale == quant_conv2.weight_quant.scale()).item())
True

Note how, when quant_conv2 is initialized using the weight_quant of quant_conv1, weight quantization is re-initialized for both layers such that they end up having the same scale.

We can see in this example how Brevitas works consistently with Pytorch’s eager execution model. When we initialize quant_conv1 we still don’t know that its weight quantizer is going to be shared with quant_conv2, and the semantics of Pytorch impose that quant_conv1 should work correctly both before and after quant_conv2 is declared. The way we take advantage of dependency injection allows us to do so.

Sharing an instance of Activation Quantization#

Sharing an instance of activation quantization is easier, because in most scenarios it’s enough to simply share the whole layer itself, e.g. by calling the same QuantReLU from multiple places in the forward pass, as in the sketch below.
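
A minimal sketch of this pattern, using a made-up module for illustration and reusing the binary weight quantizer defined earlier, could look as follows:

from brevitas.nn import QuantReLU

class SharedReluBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = QuantConv2d(3, 2, (3, 3), weight_quant=MySignedBinaryWeightQuantizer)
        self.conv2 = QuantConv2d(3, 2, (3, 3), weight_quant=MySignedBinaryWeightQuantizer)
        # a single QuantReLU instance, so its activation quantizer is shared
        self.shared_relu = QuantReLU()

    def forward(self, x):
        # the same QuantReLU (and hence the same act_quant) is called twice
        return self.shared_relu(self.conv1(x)) + self.shared_relu(self.conv2(x))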

For those scenarios where sharing the whole layer is not possible, there is something important to keep in mind. Instances of activation quantization include (for performance reasons) the implementation of the non-linear activation itself (if any). So, for example, using a QuantReLU.act_quant to initialize a QuantConv2d.output_quant should be avoided, as we would share not only the quantizer, but also the relu activation function.
In general, then, instances of activation quantization should be shared only between activations of the same kind.

Note: we say kind and not type because input_quant, output_quant and IdentityQuant count as being the same kind of activation, even though they belong to different types of layers.

Dealing with Weight Initialization#

There is a type of situation that Brevitas cannot deal with automatically. That is, when the initialization of the quantizer depends on the layer to which it is applied (like with the ParamFromMaxWeightQuantizer or SharedParamFromMeanWeightQuantizer quantizers), but the layer gets modified after it is initialized.

The typical example is weight initialization when training from scratch (rather than loading from a floating-point state_dict):

[29]:
quant_conv_w_init = QuantConv2d(3, 2, (3, 3), weight_quant=ParamFromMaxWeightQuantizer)
torch.nn.init.uniform_(quant_conv_w_init.weight)

assert_with_message(not (quant_conv_w_init.weight.abs().max() == quant_conv_w_init.weight_quant.scale()).item())
True

We can see how the scale factor is not initialized correctly anymore. In this case we can simply trigger re-initialization of the weight quantizer manually:

[30]:
quant_conv_w_init.weight_quant.init_tensor_quant()

assert_with_message((quant_conv_w_init.weight.abs().max() == quant_conv_w_init.weight_quant.scale()).item())
True

Note: because the way weights are initialized is often indistinguishable from the way an optimizer performs the weight update step, there are currently no plans to perform re-initialization automatically (as happens e.g. when a state_dict is loaded), since it wouldn’t be possible to tell the two scenarios apart.

Building a Custom Quantization API#

Finally, let’s go through an even more complicated example. We are going to look at a scenario that illustrates the differences between a standard Injector (implemented in the dependencies library) and our ExtendedInjector extension.

Let’s say we want to build two quantizers, for weights and activations respectively, and build a simple API on top of them. In particular, we want to be able to switch between BinaryQuant and ClampedBinaryQuant (a variant of binary quantization with clamping), and we want to optionally perform per-channel scaling. To do so, we are going to implement the controlling logic through a hierarchy of ExtendedInjector classes, leaving two boolean flags exposed as arguments of the quantizers, with the idea that the flags can then be set through keyword arguments of the respective quantized layers.

We can go as follows:

[31]:
from brevitas.core.quant import ClampedBinaryQuant
from brevitas.proxy import WeightQuantProxyFromInjector, ActQuantProxyFromInjector
from brevitas.inject import this


class CommonQuantizer(ExtendedInjector):
    scaling_impl = ParameterScaling
    signed=True

    @value
    def tensor_quant(is_clamped):
        # returning a class to auto-wire from a value function
        # wouldn't be allowed in a standard Injector
        if is_clamped:
            return ClampedBinaryQuant
        else:
            return BinaryQuant

    @value
    def scaling_shape(scaling_per_output_channel):
        if scaling_per_output_channel:
            # returning this.something from a value function
            # wouldn't be allowed in a standard Injector
            return this.per_channel_broadcastable_shape
        else:
            return ()


class AdvancedWeightQuantizer(CommonQuantizer):
    proxy_class = WeightQuantProxyFromInjector

    @value
    def per_channel_broadcastable_shape(module):
        return (module.weight.shape[0], 1, 1, 1)

    @value
    def scaling_init(module, scaling_per_output_channel):
        if scaling_per_output_channel:
            num_ch = module.weight.shape[0]
            return module.weight.abs().view(num_ch, -1).max(dim=1)[0].view(-1, 1, 1, 1)
        else:
            return module.weight.abs().max()


class AdvancedActQuantizer(CommonQuantizer):
    scaling_init = 0.01
    proxy_class = ActQuantProxyFromInjector

There are a bunch of things going on here to unpack.

The first one is that a @value function can return a class to auto-wire and inject, as seen in the definition of tensor_quant. This wouldn’t normally be possible with a standard Injector, but it’s possible with an ExtendedInjector. This way we can switch between different implementations of tensor_quant.

The second one is the special object this. this is already present in the dependencies library, and it’s used as a way to retrieve attributes of the quantizer from within the quantizer itself. However, normally it wouldn’t be possible to return a reference to this from a @value function. Again, this is something that only an ExtendedInjector supports, and it allows different attributes to be chained in such a way that the chained values are computed only when necessary.
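
For completeness: using this directly as a class-level attribute (outside of a @value function) is plain Injector behavior. A toy sketch, again hypothetical and not used below, could redirect scaling_init to another attribute of the same quantizer:

class RedirectedInitQuantizer(MySignedBinaryWeightQuantizer):
    # scaling_init is resolved lazily from another attribute of this quantizer
    scaling_init = this.my_init_value
    my_init_value = 0.25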

Let’s see the quantizers applied to a layer:

[32]:
per_channel_quant_conv = QuantConv2d(
    3, 2, (3, 3),
    weight_quant=AdvancedWeightQuantizer,
    weight_is_clamped=False,
    weight_scaling_per_output_channel=True)
per_channel_quant_conv.quant_weight()
[32]:
IntQuantTensor(value=tensor([[[[ 0.1612, -0.1612, -0.1612],
          [-0.1612, -0.1612, -0.1612],
          [ 0.1612,  0.1612,  0.1612]],

         [[-0.1612,  0.1612, -0.1612],
          [-0.1612,  0.1612,  0.1612],
          [-0.1612, -0.1612,  0.1612]],

         [[-0.1612,  0.1612,  0.1612],
          [ 0.1612,  0.1612, -0.1612],
          [ 0.1612,  0.1612,  0.1612]]],


        [[[ 0.1924,  0.1924,  0.1924],
          [-0.1924, -0.1924,  0.1924],
          [-0.1924,  0.1924, -0.1924]],

         [[ 0.1924, -0.1924,  0.1924],
          [ 0.1924,  0.1924, -0.1924],
          [ 0.1924, -0.1924, -0.1924]],

         [[-0.1924, -0.1924,  0.1924],
          [ 0.1924, -0.1924, -0.1924],
          [ 0.1924, -0.1924,  0.1924]]]], grad_fn=<MulBackward0>), scale=tensor([[[[0.1612]]],


        [[[0.1924]]]], grad_fn=<AbsBinarySignGradFnBackward>), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))

As expected the weight scale is now a vector. Everything we said so far about quantizers still applies, so for example we can load the floating-point state dict we defined before and observe how it triggers an update of the weight scale:

[33]:
per_channel_quant_conv.load_state_dict(float_conv.state_dict())
per_channel_quant_conv.quant_weight()
[33]:
IntQuantTensor(value=tensor([[[[ 0.1924,  0.1924, -0.1924],
          [ 0.1924,  0.1924,  0.1924],
          [ 0.1924,  0.1924,  0.1924]],

         [[-0.1924, -0.1924, -0.1924],
          [ 0.1924,  0.1924, -0.1924],
          [ 0.1924,  0.1924,  0.1924]],

         [[ 0.1924,  0.1924, -0.1924],
          [-0.1924,  0.1924,  0.1924],
          [-0.1924,  0.1924,  0.1924]]],


        [[[ 0.1899, -0.1899,  0.1899],
          [-0.1899,  0.1899, -0.1899],
          [ 0.1899,  0.1899,  0.1899]],

         [[-0.1899,  0.1899, -0.1899],
          [ 0.1899, -0.1899, -0.1899],
          [-0.1899,  0.1899,  0.1899]],

         [[ 0.1899,  0.1899, -0.1899],
          [-0.1899, -0.1899, -0.1899],
          [ 0.1899, -0.1899, -0.1899]]]], grad_fn=<MulBackward0>), scale=tensor([[[[0.1924]]],


        [[[0.1899]]]], grad_fn=<AbsBinarySignGradFnBackward>), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))

In this case we have a per-channel quantizer, so the original floating-point weight tensor is now quantized per channel.

Similarly, we can apply our custom activation quantizer to e.g. a QuantIdentity layer:

[34]:
from brevitas.nn import QuantIdentity

quant_identity = QuantIdentity(
    act_quant=AdvancedActQuantizer, is_clamped=True, scaling_per_output_channel=False)
quant_identity(torch.randn(4, 4))
[34]:
tensor([[ 0.0100,  0.0100, -0.0100,  0.0100],
        [ 0.0100,  0.0100, -0.0100,  0.0100],
        [-0.0100,  0.0100, -0.0100, -0.0100],
        [ 0.0100, -0.0100, -0.0100, -0.0100]], grad_fn=<MulBackward0>)

Note how AdvancedActQuantizer doesn’t define a per_channel_broadcastable_shape, yet no errors are triggered. This is because this.per_channel_broadcastable_shape is required only when scaling_per_output_channel is True, while in this case scaling_per_output_channel is False. Let’s try to set it to True then:

[35]:
from brevitas.nn import QuantIdentity

quant_identity = QuantIdentity(
    act_quant=AdvancedActQuantizer, is_clamped=True, scaling_per_output_channel=True)
---------------------------------------------------------------------------
DependencyError                           Traceback (most recent call last)
Cell In[35], line 3
      1 from brevitas.nn import QuantIdentity
----> 3 quant_identity = QuantIdentity(
      4     act_quant=AdvancedActQuantizer, is_clamped=True, scaling_per_output_channel=True)

File /proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/nn/quant_activation.py:113, in QuantIdentity.__init__(self, act_quant, return_quant_tensor, **kwargs)
    108 def __init__(
    109         self,
    110         act_quant: Optional[ActQuantType] = Int8ActPerTensorFloat,
    111         return_quant_tensor: bool = False,
    112         **kwargs):
--> 113     QuantNLAL.__init__(
    114         self,
    115         input_quant=None,
    116         act_impl=None,
    117         passthrough_act=True,
    118         act_quant=act_quant,
    119         return_quant_tensor=return_quant_tensor,
    120         **kwargs)

File /proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/nn/quant_layer.py:34, in QuantNonLinearActLayer.__init__(self, act_impl, passthrough_act, input_quant, act_quant, return_quant_tensor, **kwargs)
     32 QuantLayerMixin.__init__(self, return_quant_tensor)
     33 QuantInputMixin.__init__(self, input_quant, **kwargs)
---> 34 QuantNonLinearActMixin.__init__(self, act_impl, passthrough_act, act_quant, **kwargs)

File /proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/nn/mixin/act.py:66, in QuantNonLinearActMixin.__init__(self, act_impl, passthrough_act, act_quant, act_proxy_prefix, act_kwargs_prefix, **kwargs)
     55 def __init__(
     56         self,
     57         act_impl: Optional[Type[Module]],
   (...)
     61         act_kwargs_prefix='',
     62         **kwargs):
     63     prefixed_kwargs = {
     64         act_kwargs_prefix + 'act_impl': act_impl,
     65         act_kwargs_prefix + 'passthrough_act': passthrough_act}
---> 66     QuantProxyMixin.__init__(
     67         self,
     68         quant=act_quant,
     69         proxy_prefix=act_proxy_prefix,
     70         kwargs_prefix=act_kwargs_prefix,
     71         proxy_protocol=ActQuantProxyProtocol,
     72         none_quant_injector=NoneActQuant,
     73         **prefixed_kwargs,
     74         **kwargs)

File /proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/nn/mixin/base.py:48, in QuantProxyMixin.__init__(self, quant, proxy_protocol, none_quant_injector, proxy_prefix, kwargs_prefix, **kwargs)
     46     quant_injector = quant
     47     quant_injector = quant_injector.let(**filter_kwargs(kwargs_prefix, kwargs))
---> 48     quant = quant_injector.proxy_class(self, quant_injector)
     49 else:
     50     if not isinstance(quant, proxy_protocol):

File /proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/proxy/runtime_quant.py:198, in ActQuantProxyFromInjector.__init__(self, quant_layer, quant_injector)
    197 def __init__(self, quant_layer, quant_injector):
--> 198     super().__init__(quant_layer, quant_injector)
    199     self.cache_class = _CachedIO

File /proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/proxy/runtime_quant.py:93, in ActQuantProxyFromInjectorBase.__init__(self, quant_layer, quant_injector)
     92 def __init__(self, quant_layer, quant_injector):
---> 93     QuantProxyFromInjector.__init__(self, quant_layer, quant_injector)
     94     ActQuantProxyProtocol.__init__(self)
     95     self.is_passthrough_act = _is_passthrough_act(quant_injector)

File /proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/proxy/quant_proxy.py:80, in QuantProxyFromInjector.__init__(self, quant_layer, quant_injector)
     78 # Use a normal list and not a ModuleList since this is a pointer to parent modules
     79 self.tracked_module_list = []
---> 80 self.add_tracked_module(quant_layer)
     81 self.disable_quant = False
     82 # Torch.compile compatibility requires this

File /proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/proxy/quant_proxy.py:120, in QuantProxyFromInjector.add_tracked_module(self, module)
    118     self.tracked_module_list.append(module)
    119     self.update_tracked_modules()
--> 120     self.init_tensor_quant()
    121 else:
    122     raise RuntimeError("Trying to add None as a parent module.")

File /proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/proxy/runtime_quant.py:127, in ActQuantProxyFromInjectorBase.init_tensor_quant(self)
    126 def init_tensor_quant(self):
--> 127     tensor_quant = self.quant_injector.tensor_quant
    128     if 'act_impl' in self.quant_injector:
    129         act_impl = self.quant_injector.act_impl

    [... skipping hidden 1 frame]

File /proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/_dependencies/this.py:51, in _ThisSpec.__call__(self, __self__)
     49 if kind == ".":
     50     try:
---> 51         result = getattr(result, symbol)
     52     except DependencyError:
     53         message = (
     54             "You tried to shift this more times than Injector has levels"
     55         )

File /proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/inject/__init__.py:129, in _ExtendedInjectorType.__getattr__(cls, attrname)
    126     else:
    127         message = "{!r} can not resolve attribute {!r}".format(
    128             cls.__name__, current_attr)
--> 129     raise DependencyError(message)
    131 marker, attribute, args, have_defaults = spec
    133 if set(args).issubset(cached):

DependencyError: 'AdvancedActQuantizer' can not resolve attribute 'per_channel_broadcastable_shape'

As expected we get an error saying that the quantizer cannot resolve per_channel_broadcastable_shape. If we pass it in then we can get a per-channel quantizer:

[36]:
quant_identity = QuantIdentity(
    act_quant=AdvancedActQuantizer, is_clamped=True, scaling_per_output_channel=True,
    per_channel_broadcastable_shape=(4, 1), return_quant_tensor=True)
quant_identity(torch.randn(4, 4))
[36]:
IntQuantTensor(value=tensor([[ 0.0100,  0.0100,  0.0100, -0.0100],
        [ 0.0100, -0.0100, -0.0100,  0.0100],
        [-0.0100, -0.0100, -0.0100,  0.0100],
        [ 0.0100,  0.0100,  0.0100,  0.0100]], grad_fn=<MulBackward0>), scale=tensor([[0.0100],
        [0.0100],
        [0.0100],
        [0.0100]], grad_fn=<AbsBinarySignGradFnBackward>), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))

We have seen how powerful dependency injection is. In a way, it’s even too expressive. For users that are not interested in building completely custom quantizers, it can be hard to make sense of how the various components available under brevitas.core can be assembled together according to best practices.