Anatomy of a Quantizer#
What’s in a Quantizer?#
We have seen in previous tutorials quantizers being imported from brevitas.quant and passed on to quantized layers. A quantizer is a subclass of brevitas.inject.ExtendedInjector that carries a tensor_quant attribute, which points to an instance of a torch Module that implements quantization. We can easily verify what we just said on one of them:
[1]:
from brevitas.inject import ExtendedInjector
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
issubclass(Int8ActPerTensorFloat, ExtendedInjector)
[1]:
True
[2]:
Int8ActPerTensorFloat.tensor_quant
[2]:
RescalingIntQuant(
(int_quant): IntQuant(
(float_to_int_impl): RoundSte()
(tensor_clamp_impl): TensorClamp()
(delay_wrapper): DelayWrapper(
(delay_impl): _NoDelay()
)
(input_view_impl): Identity()
)
(scaling_impl): ParameterFromRuntimeStatsScaling(
(stats_input_view_shape_impl): OverTensorView()
(stats): _Stats(
(stats_impl): AbsPercentile()
)
(restrict_scaling): _RestrictValue(
(restrict_value_impl): FloatRestrictValue()
)
(clamp_scaling): _ClampValue(
(clamp_min_ste): ScalarClampMinSte()
)
(restrict_inplace_preprocess): Identity()
(restrict_preprocess): Identity()
)
(int_scaling_impl): IntScaling()
(zero_point_impl): ZeroZeroPoint(
(zero_point): StatelessBuffer()
)
(msb_clamp_bit_width_impl): BitWidthConst(
(bit_width): StatelessBuffer()
)
)
Note how we said subclass and not instance. To understand why that’s the case, we have to understand what an ExtendedInjector
is and why it’s used in the first place.
Quantization with auto-wiring Dependency Injection#
Pytorch has exploded in popularity thanks to its straightforward numpy-like define-by-run execution model. However, when it comes to applying quantization, this style of programming poses a problem.
Many quantization methods make decisions based on the (in Pytorch terms) state_dict of the original floating-point model that is being fine-tuned with quantization. However, when we instantiate a model in Pytorch we can't know on the spot whether a state_dict is going to be loaded a few lines of code later or not. Yet, because Pytorch is define-by-run, we need our model to work consistently both before and after a state_dict is possibly loaded. In a traditional floating-point scenario that wouldn't pose a problem. However, with quantization in the loop, the way a quantizer is defined might change before and after a pretrained state_dict is loaded.
That means that we need a way to define our quantized model such that it can react appropriately in case the state_dict changes. In a Python-only world that wouldn't be too hard. However, in order to mitigate the performance impact of quantization-aware training, Brevitas makes extensive use of Pytorch's JIT compiler for a custom subset of Python, TorchScript. That means that in most scenarios, when a state_dict is loaded, we need to recompile parts of the model. Because compilation in general is a lossy process, a TorchScript component cannot simply re-compile itself based on new input information.
We then need a way to declare a quantization method such that it can be re-initialized and JIT compiled any time the state_dict changes. Because we want to support arbitrarily complex user-defined quantization algorithms, this method has to be generic, i.e. it cannot depend on the specifics of the quantization algorithm implemented.
Implementing a quantizer with an ExtendedInjector
is a way to do so. Specifically, an ExtendedInjector
extends an Injector
from an older version (0.2.1) of the excellent dependency-injection library dependencies with support for a couple of extra features that are specific to Brevitas’ needs.
An Injector (and an ExtendedInjector) allows us to take what might be a very complicated graph of intertwined objects and turn it into a flat list of variables that are capable of auto-assembly by matching variable names to argument names. This technique typically goes under the name of auto-wiring dependency injection.
In the context of Brevitas, the goal is to gather all the modules and hyperparameters that contribute to a quantization implementation such that they can be re-assembled automatically on demand. What comes out of this process is a tensor_quant object.
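As a toy illustration of auto-wiring (the classes below are hypothetical and not part of Brevitas), a container can resolve constructor arguments simply by looking up attributes with matching names:
# Toy sketch of name-based auto-wiring, not Brevitas code: the container builds
# Quantizer by matching its constructor argument `scaling_impl` against the
# attribute of the same name.
class ConstScale:
    def __call__(self):
        return 0.1

class Quantizer:
    def __init__(self, scaling_impl):
        self.scaling_impl = scaling_impl

class ToyContainer:
    scaling_impl = ConstScale
    quantizer = Quantizer

    @classmethod
    def build(cls):
        # what an Injector automates: instantiate dependencies and wire them by name
        return cls.quantizer(scaling_impl=cls.scaling_impl())

toy_quantizer = ToyContainer.build()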
A Practical Example: Binary Quantization#
To make things practical, let's look at how we can implement a simple variant of binary quantization. All the components typically used to implement quantization can be found under brevitas.core. As mentioned before, Brevitas makes heavy use of TorchScript. In particular, all the components found under brevitas.core are implemented as ScriptModules that can be assembled together. The core ScriptModule that implements binarization can be found under brevitas.core.quant:
[3]:
import inspect
from IPython.display import Markdown, display
# helpers
def assert_with_message(condition):
assert condition
print(condition)
def pretty_print_source(source):
display(Markdown('```python\n' + source + '\n```'))
[4]:
from brevitas.core.quant import BinaryQuant
source = inspect.getsource(BinaryQuant)
pretty_print_source(source)
class BinaryQuant(brevitas.jit.ScriptModule):
"""
ScriptModule that implements scaled uniform binary quantization of an input tensor.
Quantization is performed with :func:`~brevitas.function.ops_ste.binary_sign_ste`.
Args:
scaling_impl (Module): Module that returns a scale factor.
quant_delay_steps (int): Number of training steps to delay quantization for. Default: 0
Returns:
Tuple[Tensor, Tensor, Tensor, Tensor]: Quantized output in de-quantized format, scale, zero-point, bit_width.
Examples:
>>> from brevitas.core.scaling import ConstScaling
>>> binary_quant = BinaryQuant(ConstScaling(0.1))
>>> inp = torch.Tensor([0.04, -0.6, 3.3])
>>> out, scale, zero_point, bit_width = binary_quant(inp)
>>> out
tensor([ 0.1000, -0.1000, 0.1000])
>>> scale
tensor(0.1000)
>>> zero_point
tensor(0.)
>>> bit_width
tensor(1.)
Note:
Maps to quant_type == QuantType.BINARY == 'BINARY' == 'binary' when applied to weights in higher-level APIs.
Note:
Set env variable BREVITAS_JIT=1 to enable TorchScript compilation of this module.
"""
def __init__(self, scaling_impl: Module, signed: bool = True, quant_delay_steps: int = 0):
super(BinaryQuant, self).__init__()
assert signed, "Unsigned binary quant not supported"
self.scaling_impl = scaling_impl
self.bit_width = BitWidthConst(1)
self.zero_point = StatelessBuffer(torch.tensor(0.0))
self.delay_wrapper = DelayWrapper(quant_delay_steps)
@brevitas.jit.script_method
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
scale = self.scaling_impl(x)
y = binary_sign_ste(x) * scale
y = self.delay_wrapper(x, y)
return y, scale, self.zero_point(), self.bit_width()
The implementation is quite simple. Apart from quant_delay_steps, which delays quantization by a given number of training steps (default 0), and the signed flag (which must be left at its default True), the only other argument BinaryQuant accepts is an implementation to compute the scale factor. bit_width is fixed to 1 and zero_point is fixed to 0.
We pick as scale factor implementation a ScriptModule
called ParameterScaling
, which implements a learned parameter with user-defined initialization. It can be found under brevitas.core.scaling
:
[5]:
from brevitas.core.scaling import ParameterScaling
Manual Binary Quantization#
As a first step, we simply instantiate BinaryQuant with ParameterScaling, using scaling_init equal to 0.1, and call it on a random floating-point input tensor:
[6]:
import torch
# set seed for notebook
torch.manual_seed(0)
manual_tensor_quant = BinaryQuant(scaling_impl=ParameterScaling(scaling_init=0.1))
manual_tensor_quant(torch.randn(4, 4))
[6]:
(tensor([[-0.1000, -0.1000, -0.1000, -0.1000],
[ 0.1000, 0.1000, -0.1000, -0.1000],
[ 0.1000, -0.1000, 0.1000, 0.1000],
[ 0.1000, 0.1000, 0.1000, -0.1000]], grad_fn=<MulBackward0>),
tensor(0.1000, grad_fn=<AbsBinarySignGradFnBackward>),
tensor(0.),
tensor(1.))
Nothing too surprising here: as expected, the tensor is binarized with the scale factor we defined. Note, however, that manual_tensor_quant returns a tuple and not a QuantTensor. This is because support for custom data structures in TorchScript is still quite limited, so QuantTensor instances are allocated only in Python-world abstractions.
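Since the output is a plain tuple, we can unpack it directly; the field order matches the docstring above (dequantized value, scale, zero-point, bit-width). A minimal sketch:
# Unpack the 4-tuple returned by BinaryQuant.
out, scale, zero_point, bit_width = manual_tensor_quant(torch.randn(4, 4))
print(out.shape, scale.item(), zero_point.item(), bit_width.item())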
Binary Quantization with an ExtendedInjector#
Let’s now declare tensor_quant
through an ExtendedInjector
:
[7]:
from brevitas.inject import ExtendedInjector
class MyBinaryQuantizer(ExtendedInjector):
tensor_quant = BinaryQuant
scaling_impl=ParameterScaling
scaling_init=0.1
inj_tensor_quant = MyBinaryQuantizer.tensor_quant
inj_tensor_quant(torch.randn(4, 4))
[7]:
(tensor([[-0.1000, -0.1000, 0.1000, 0.1000],
[ 0.1000, -0.1000, -0.1000, 0.1000],
[ 0.1000, -0.1000, -0.1000, 0.1000],
[ 0.1000, 0.1000, 0.1000, -0.1000]], grad_fn=<MulBackward0>),
tensor(0.1000, grad_fn=<AbsBinarySignGradFnBackward>),
tensor(0.),
tensor(1.))
Any time MyBinaryQuantizer.tensor_quant is accessed, a new instance of BinaryQuant is created. Note how the attributes of MyBinaryQuantizer are named so as to match the argument names of the components being assembled, except for tensor_quant, which is what we are interested in retrieving from the outside.
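We can verify the statement above with a quick check: two consecutive accesses return two distinct objects.
# Each access to MyBinaryQuantizer.tensor_quant assembles a brand new BinaryQuant
# (with its own ParameterScaling inside).
assert_with_message(MyBinaryQuantizer.tensor_quant is not MyBinaryQuantizer.tensor_quant)
assert_with_message(isinstance(MyBinaryQuantizer.tensor_quant, BinaryQuant))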
Inheritance and Composition of Quantizers#
Expressing a quantizer as a Python class also means that we can leverage both inheritance and composition. For example, we can inherit from MyBinaryQuantizer and override scaling_init with a new value:
[8]:
class MyChildBinaryQuantizer(MyBinaryQuantizer):
scaling_init=1.0
child_inj_tensor_quant = MyChildBinaryQuantizer.tensor_quant
child_inj_tensor_quant(torch.randn(4, 4))
[8]:
(tensor([[-1., 1., -1., 1.],
[ 1., 1., 1., 1.],
[-1., 1., -1., 1.],
[ 1., 1., -1., -1.]], grad_fn=<MulBackward0>),
tensor(1., grad_fn=<AbsBinarySignGradFnBackward>),
tensor(0.),
tensor(1.))
Or we can leverage composition by assembling together various classes containing different pieces of a quantizer:
[9]:
class MyBinaryImpl(ExtendedInjector):
tensor_quant = BinaryQuant
class MyScalingImpl(ExtendedInjector):
scaling_impl=ParameterScaling
scaling_init=0.1
class MyComposedBinaryQuantizer(MyBinaryImpl, MyScalingImpl):
pass
comp_inj_tensor_quant = MyComposedBinaryQuantizer.tensor_quant
comp_inj_tensor_quant(torch.randn(4, 4))
[9]:
(tensor([[-0.1000, 0.1000, 0.1000, 0.1000],
[-0.1000, -0.1000, 0.1000, -0.1000],
[-0.1000, 0.1000, 0.1000, 0.1000],
[-0.1000, -0.1000, 0.1000, -0.1000]], grad_fn=<MulBackward0>),
tensor(0.1000, grad_fn=<AbsBinarySignGradFnBackward>),
tensor(0.),
tensor(1.))
Interfacing a Quantizer with a Quantized Layer#
Before we can pass the quantizer to a quantized layer such as QuantConv2d, we need to define one last component, a proxy. A proxy (found under brevitas.proxy) is an nn.Module that serves as the interface between a quantizer and a quantized layer.
While a quantizer lives mostly in JIT-land, a proxy lives mostly in Python-land, and as such can afford much more flexibility. Proxies take care of returning a QuantTensor and of re-initializing the output of the quantizer whenever a new state_dict is loaded.
Proxies are specific to the kind of tensor being quantized, as in weights vs biases vs activations. For convenience, they are declared as part of the quantizer itself under the attribute proxy_class
. For example, for weights we can use WeightQuantProxyFromInjector
:
[10]:
from brevitas.proxy import WeightQuantProxyFromInjector
class MyBinaryWeightQuantizer(MyBinaryQuantizer):
proxy_class = WeightQuantProxyFromInjector
We can now use MyBinaryWeightQuantizer
as the weight quantizer of a layer:
[11]:
from brevitas.nn import QuantConv2d
binary_weight_quant_conv = QuantConv2d(3, 2, (3,3), weight_quant=MyBinaryWeightQuantizer)
try:
quant_weight = binary_weight_quant_conv.quant_weight()
except TypeError:
pass
Note, however, that we cannot compute the quantized weight, as the signed attribute is None.
signed is one of those attributes that, in the case of binary quantization, has to be explicitly defined by the user: it informs the proxy whether the values generated by our quantizer should be considered signed or not. We can fix this by simply setting it in the quantizer:
[12]:
class MySignedBinaryWeightQuantizer(MyBinaryWeightQuantizer):
signed = True
binary_weight_quant_conv = QuantConv2d(3, 2, (3,3), weight_quant=MySignedBinaryWeightQuantizer)
signed_quant_weight = binary_weight_quant_conv.quant_weight()
signed_quant_weight
[12]:
IntQuantTensor(value=tensor([[[[ 0.1000, 0.1000, -0.1000],
[ 0.1000, -0.1000, -0.1000],
[-0.1000, -0.1000, -0.1000]],
[[-0.1000, -0.1000, 0.1000],
[-0.1000, -0.1000, 0.1000],
[-0.1000, 0.1000, -0.1000]],
[[ 0.1000, 0.1000, 0.1000],
[ 0.1000, 0.1000, -0.1000],
[-0.1000, -0.1000, 0.1000]]],
[[[-0.1000, -0.1000, -0.1000],
[-0.1000, -0.1000, -0.1000],
[ 0.1000, 0.1000, -0.1000]],
[[ 0.1000, -0.1000, -0.1000],
[-0.1000, -0.1000, -0.1000],
[-0.1000, -0.1000, -0.1000]],
[[ 0.1000, -0.1000, 0.1000],
[-0.1000, -0.1000, 0.1000],
[-0.1000, 0.1000, 0.1000]]]], grad_fn=<MulBackward0>), scale=tensor(0.1000, grad_fn=<AbsBinarySignGradFnBackward>), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))
[13]:
assert_with_message(signed_quant_weight.is_valid)
True
And now the quantized weights are valid.
When we want to add or override a single attribute of a quantizer passed to a layer, defining a whole new quantizer can be too verbose. There is a simpler syntax to achieve the same goal. Let's say we want to add the signed attribute to MyBinaryWeightQuantizer, as we just did. We could have also simply done the following:
[14]:
small_scale_quant_conv = QuantConv2d(3, 2, (3,3), weight_quant=MyBinaryWeightQuantizer, weight_signed=True)
small_scale_quant_conv.quant_weight()
[14]:
IntQuantTensor(value=tensor([[[[-0.1000, -0.1000, 0.1000],
[-0.1000, 0.1000, -0.1000],
[ 0.1000, 0.1000, 0.1000]],
[[ 0.1000, -0.1000, -0.1000],
[-0.1000, 0.1000, 0.1000],
[ 0.1000, -0.1000, 0.1000]],
[[-0.1000, -0.1000, -0.1000],
[ 0.1000, -0.1000, -0.1000],
[ 0.1000, 0.1000, -0.1000]]],
[[[-0.1000, -0.1000, 0.1000],
[-0.1000, 0.1000, 0.1000],
[ 0.1000, -0.1000, -0.1000]],
[[ 0.1000, -0.1000, -0.1000],
[ 0.1000, -0.1000, -0.1000],
[ 0.1000, -0.1000, -0.1000]],
[[ 0.1000, 0.1000, 0.1000],
[-0.1000, 0.1000, -0.1000],
[-0.1000, -0.1000, -0.1000]]]], grad_fn=<MulBackward0>), scale=tensor(0.1000, grad_fn=<AbsBinarySignGradFnBackward>), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))
What we did was to take the name of the attribute signed
, add the prefix weight_
, and pass it as a keyword argument to QuantConv2d
. What happens in the background is that the keyword arguments prefixed with weight_
are set as attributes of weight_quant
, possibly overriding any pre-existing value. The same principle applies to input_
, output_
and bias_
.
This is the reason why, as mentioned in the first tutorial, quantized layers can accept arbitrary keyword arguments: it's really just a way to support different styles of syntax when defining a quantizer.
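For example, the scaling_init inherited from MyBinaryQuantizer could also be overridden directly from the layer with a prefixed keyword argument (a sketch, with an arbitrarily chosen value and variable name):
# weight_scaling_init is stripped of its weight_ prefix and set as the
# scaling_init attribute of the weight quantizer, overriding the value of 0.1.
overridden_scale_conv = QuantConv2d(
    3, 2, (3, 3), weight_quant=MySignedBinaryWeightQuantizer, weight_scaling_init=0.3)
overridden_scale_conv.quant_weight()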
Passing a custom quantizer to QuantIdentity#
We can do a similar thing with quantized activations:
[15]:
from brevitas.proxy import ActQuantProxyFromInjector
from brevitas.nn import QuantIdentity
class MySignedBinaryActQuantizer(MyBinaryQuantizer):
proxy_class = ActQuantProxyFromInjector
signed = True
binary_relu = QuantIdentity(act_quant=MySignedBinaryActQuantizer, return_quant_tensor=True)
binary_relu(torch.randn(4, 4))
/proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/torch/_tensor.py:1255: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /opt/conda/conda-bld/pytorch_1670525541990/work/c10/core/TensorImpl.h:1758.)
return super(Tensor, self).rename(names)
[15]:
IntQuantTensor(value=tensor([[-0.1000, 0.1000, -0.1000, -0.1000],
[ 0.1000, 0.1000, 0.1000, 0.1000],
[-0.1000, -0.1000, 0.1000, -0.1000],
[-0.1000, 0.1000, -0.1000, 0.1000]], grad_fn=<MulBackward0>), scale=tensor(0.1000, grad_fn=<AbsBinarySignGradFnBackward>), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))
So there isn't really much difference between a quantizer for weights and a quantizer for activations; they are just wrapped by different proxies. Also, with activations a prefix is not required when passing keyword arguments. For example, we can override the existing scaling_init defined in MyBinaryQuantizer with a new value passed in as a keyword argument:
[16]:
small_scale_binary_identity = QuantIdentity(
act_quant=MySignedBinaryActQuantizer, scaling_init=0.001, return_quant_tensor=True)
small_scale_binary_identity(torch.randn(4, 4))
[16]:
IntQuantTensor(value=tensor([[ 0.0010, -0.0010, -0.0010, 0.0010],
[ 0.0010, 0.0010, 0.0010, 0.0010],
[ 0.0010, -0.0010, 0.0010, 0.0010],
[ 0.0010, -0.0010, -0.0010, -0.0010]], grad_fn=<MulBackward0>), scale=tensor(0.0010, grad_fn=<AbsBinarySignGradFnBackward>), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))
A Custom Quantizer initialized with Weight Statistics#
So far we have seen use-cases where an ExtendedInjector provides, at best, a different kind of syntax to define a quantizer, without any other particular advantage. Let's now make things a bit more complicated to show the sort of situations where it really shines.
Let's say we want to define a binary weight quantizer where scaling_impl is still ParameterScaling. However, instead of being user-defined, we want scaling_init to be the maximum absolute value found in the weight tensor of the quantized layer. To support this sort of use case, where the quantizer depends on the layer it is applied to, a quantized layer automatically passes itself to all its quantizers under the name module. With only a few lines of code, then, we can achieve our goal:
[17]:
from brevitas.inject import value
class ParamFromMaxWeightQuantizer(MySignedBinaryWeightQuantizer):
@value
def scaling_init(module):
return module.weight.abs().max()
Note how we are leveraging the @value
decorator to define a function that is executed at dependency-injection (DI) time. This kind of behaviour is similar in spirit to defining a @property
instead of an attribute, with the difference that a @value
function can depend on other attributes of the Injector, which are automatically passed in as arguments of the function during DI.
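As a further sketch of this mechanism (the quantizer below is hypothetical, not part of Brevitas), a @value function can also depend on a plain attribute of the same quantizer, resolved and passed in by name at DI time:
# Hypothetical quantizer: scaling_init is computed at DI time from another
# attribute of the same quantizer (scaling_base), matched by argument name.
class ScaledDownBinaryWeightQuantizer(MySignedBinaryWeightQuantizer):
    scaling_base = 0.2

    @value
    def scaling_init(scaling_base):
        return scaling_base / 2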
Let's now pass ParamFromMaxWeightQuantizer to a QuantConv2d and retrieve its quantized weights:
[18]:
param_from_max_quant_conv = QuantConv2d(3, 2, (3, 3), weight_quant=ParamFromMaxWeightQuantizer)
param_from_max_quant_conv.quant_weight()
[18]:
IntQuantTensor(value=tensor([[[[ 0.1820, -0.1820, -0.1820],
[ 0.1820, 0.1820, 0.1820],
[-0.1820, -0.1820, -0.1820]],
[[ 0.1820, -0.1820, -0.1820],
[ 0.1820, -0.1820, -0.1820],
[ 0.1820, 0.1820, -0.1820]],
[[-0.1820, 0.1820, 0.1820],
[ 0.1820, -0.1820, 0.1820],
[-0.1820, -0.1820, -0.1820]]],
[[[ 0.1820, 0.1820, -0.1820],
[-0.1820, -0.1820, 0.1820],
[-0.1820, 0.1820, -0.1820]],
[[-0.1820, 0.1820, -0.1820],
[ 0.1820, 0.1820, -0.1820],
[ 0.1820, 0.1820, 0.1820]],
[[-0.1820, -0.1820, -0.1820],
[ 0.1820, -0.1820, 0.1820],
[ 0.1820, -0.1820, 0.1820]]]], grad_fn=<MulBackward0>), scale=tensor(0.1820, grad_fn=<AbsBinarySignGradFnBackward>), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))
Indeed we can verify that weight_quant.scale()
is equal to weight.abs().max()
:
[19]:
assert_with_message((param_from_max_quant_conv.weight_quant.scale() == param_from_max_quant_conv.weight.abs().max()).item())
True
Let's say now that we want to load a pretrained floating-point weight tensor on top of our quantized model. We simulate this scenario by defining a separate nn.Conv2d layer with the same weight shape:
[20]:
from torch import nn
float_conv = nn.Conv2d(3, 2, (3, 3))
float_conv.weight.abs().max()
[20]:
tensor(0.1924, grad_fn=<MaxBackward1>)
and then we load it on top of param_from_max_quant_conv
:
[21]:
param_from_max_quant_conv.load_state_dict(float_conv.state_dict())
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[21], line 1
----> 1 param_from_max_quant_conv.load_state_dict(float_conv.state_dict())
File /proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/torch/nn/modules/module.py:1671, in Module.load_state_dict(self, state_dict, strict)
1666 error_msgs.insert(
1667 0, 'Missing key(s) in state_dict: {}. '.format(
1668 ', '.join('"{}"'.format(k) for k in missing_keys)))
1670 if len(error_msgs) > 0:
-> 1671 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
1672 self.__class__.__name__, "\n\t".join(error_msgs)))
1673 return _IncompatibleKeys(missing_keys, unexpected_keys)
RuntimeError: Error(s) in loading state_dict for QuantConv2d:
Missing key(s) in state_dict: "weight_quant.tensor_quant.scaling_impl.value".
Ouch, we get an error. This is because ParameterScaling
contains a learned torch.nn.Parameter
, and Pytorch expects all learned parameters of a model to be contained in a state_dict
that is being loaded. We can work around the issue either by setting the IGNORE_MISSING_KEYS config flag in Brevitas, or by passing strict=False to load_state_dict. We go with the former, as setting strict=False is too forgiving of other kinds of problems:
[22]:
from brevitas import config
config.IGNORE_MISSING_KEYS = True
param_from_max_quant_conv.load_state_dict(float_conv.state_dict())
[22]:
<All keys matched successfully>
Note that we could have also achieved the same goal by setting the env variable BREVITAS_IGNORE_MISSING_KEYS=1.
And now if we take a look at the quantized weights again:
[23]:
param_from_max_quant_conv.quant_weight()
[23]:
IntQuantTensor(value=tensor([[[[ 0.1924, 0.1924, -0.1924],
[ 0.1924, 0.1924, 0.1924],
[ 0.1924, 0.1924, 0.1924]],
[[-0.1924, -0.1924, -0.1924],
[ 0.1924, 0.1924, -0.1924],
[ 0.1924, 0.1924, 0.1924]],
[[ 0.1924, 0.1924, -0.1924],
[-0.1924, 0.1924, 0.1924],
[-0.1924, 0.1924, 0.1924]]],
[[[ 0.1924, -0.1924, 0.1924],
[-0.1924, 0.1924, -0.1924],
[ 0.1924, 0.1924, 0.1924]],
[[-0.1924, 0.1924, -0.1924],
[ 0.1924, -0.1924, -0.1924],
[-0.1924, 0.1924, 0.1924]],
[[ 0.1924, 0.1924, -0.1924],
[-0.1924, -0.1924, -0.1924],
[ 0.1924, -0.1924, -0.1924]]]], grad_fn=<MulBackward0>), scale=tensor(0.1924, grad_fn=<AbsBinarySignGradFnBackward>), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))
We see that, as expected, the scale factor has been updated to the new weight.abs().max()
.
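As a quick sanity check, analogous to the one above, we can assert that the scale again equals the maximum absolute value of the (new) weight tensor:
assert_with_message((param_from_max_quant_conv.weight_quant.scale() == param_from_max_quant_conv.weight.abs().max()).item())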
What happens internally is that after load_state_dict
is called on the layer, ParamFromMaxWeightQuantizer.tensor_quant
gets called again to re-initialize BinaryQuant
, and in turn ParameterScaling
is re-initialized with a new scaling_init
value computed based on the updated module.weight
tensor. This whole process wouldn’t have been possible without an ExtendedInjector
behind it.
Sharing a Quantizer#
There are two ways to share a quantizer between multiple layers, with important differences.
The first one, which we have seen so far, is to simply pass the same ExtendedInjector to multiple layers. What that does is share the same quantization strategy among different layers; each layer still gets its own instance of the quantization implementation.
[24]:
quant_conv1 = QuantConv2d(3, 2, (3, 3), weight_quant=MySignedBinaryWeightQuantizer)
quant_conv2 = QuantConv2d(3, 2, (3, 3), weight_quant=MySignedBinaryWeightQuantizer)
assert_with_message(quant_conv1.weight_quant is not quant_conv2.weight_quant)
True
Sharing a proxy#
The second one, which we are introducing now, allows sharing the same quantization instance among multiple layers. This is done by simply sharing the proxy wrapping it. This can be useful in those scenarios where, for example, we want different layers to share the same scale factor. The syntax goes as follows:
[25]:
quant_conv1 = QuantConv2d(3, 2, (3, 3), weight_quant=MySignedBinaryWeightQuantizer)
quant_conv2 = QuantConv2d(3, 2, (3, 3), weight_quant=quant_conv1.weight_quant)
assert_with_message(quant_conv1.weight_quant is quant_conv2.weight_quant)
True
[26]:
assert_with_message((quant_conv1.weight_quant.scale() == quant_conv2.weight_quant.scale()).item())
True
What happens in the background is that the weight quantizer now has access to both quant_conv1 and quant_conv2. So let's say we want to build a quantizer similar to ParamFromMaxWeightQuantizer, but in this case we want the scale factor to be initialized with the mean absolute value across both weight tensors. When a quantizer has access to multiple parent modules, they are passed in at dependency-injection time as a tuple, under the same name module as before. So we can do the following:
[27]:
class SharedParamFromMeanWeightQuantizer(MySignedBinaryWeightQuantizer):
@value
def scaling_init(module):
if isinstance(module, tuple):
return torch.cat((module[0].weight.view(-1), module[1].weight.view(-1))).abs().mean()
else:
return module.weight.abs().mean()
quant_conv1 = QuantConv2d(3, 2, (3, 3), weight_quant=SharedParamFromMeanWeightQuantizer)
old_quant_conv1_scale = quant_conv1.weight_quant.scale()
quant_conv2 = QuantConv2d(3, 2, (3, 3), weight_quant=quant_conv1.weight_quant)
new_quant_conv1_scale = quant_conv1.weight_quant.scale()
assert_with_message(not (old_quant_conv1_scale == new_quant_conv1_scale).item())
True
[28]:
assert_with_message((new_quant_conv1_scale == quant_conv2.weight_quant.scale()).item())
True
Note how, when quant_conv2
is initialized using the weight_quant
of quant_conv1
, weight quantization is re-initialized for both layers such that they end up having the same scale.
We can see in this example how Brevitas works consistently with Pytorch's eager execution model. When we initialize quant_conv1 we still don't know that its weight quantizer is going to be shared with quant_conv2, and the semantics of Pytorch impose that quant_conv1 should work correctly both before and after quant_conv2 is declared. The way we take advantage of dependency injection allows us to do so.
Sharing an instance of Activation Quantization#
Sharing an instance of activation quantization is easier, because for most scenarios it's enough to simply share the whole layer itself, e.g. calling the same QuantReLU from multiple places in the forward pass (a minimal sketch follows below).
Sharing the quantizer of one kind of layer with a different kind, e.g. using QuantReLU.act_quant to initialize a QuantConv2d.output_quant, should be avoided, as we would share not only the quantizer but also the relu activation function.
Note: we say kind and not type because input_quant, output_quant and IdentityQuant count as being the same kind of activation, even though they belong to different types of layers.
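Here is the sketch of the layer-sharing pattern (the module structure and layer sizes are chosen arbitrarily for illustration):
# Reuse the same QuantReLU instance at two points of the forward pass, so both
# activations go through a single shared activation quantizer.
from torch import nn
from brevitas.nn import QuantConv2d, QuantReLU

class SharedActQuantModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = QuantConv2d(3, 8, (3, 3))
        self.conv2 = QuantConv2d(8, 8, (3, 3))
        self.shared_relu = QuantReLU()

    def forward(self, x):
        x = self.shared_relu(self.conv1(x))  # first use of the shared quantizer
        x = self.shared_relu(self.conv2(x))  # second use, same quantization instance
        return x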
Dealing with Weight Initialization#
There is one type of situation that Brevitas cannot deal with automatically: when the initialization of the quantizer depends on the layer to which it is applied (as with the ParamFromMaxWeightQuantizer or SharedParamFromMeanWeightQuantizer quantizers), but the layer gets modified after it is initialized.
The typical example is weight initialization when training from scratch (rather than loading a floating-point state_dict):
[29]:
quant_conv_w_init = QuantConv2d(3, 2, (3, 3), weight_quant=ParamFromMaxWeightQuantizer)
torch.nn.init.uniform_(quant_conv_w_init.weight)
assert_with_message(not (quant_conv_w_init.weight.abs().max() == quant_conv_w_init.weight_quant.scale()).item())
True
We can see how the scale factor is not initialized correctly anymore. In this case we can simply trigger re-initialization of the weight quantizer manually:
[30]:
quant_conv_w_init.weight_quant.init_tensor_quant()
assert_with_message((quant_conv_w_init.weight.abs().max() == quant_conv_w_init.weight_quant.scale()).item())
True
Note: because the way weights are initialized is often the same as the way an optimizer performs the weight update step, there are currently no plans to perform re-initialization automatically (as happens e.g. when a state_dict is loaded), since it wouldn't be possible to distinguish between the two scenarios.
Building a Custom Quantization API#
Finally, let’s go through an even more complicated example. We are going to look at a scenario that illustrates the differences between a standard Injector
(implemented in the dependencies library) and our ExtendedInjector
extension.
Let's say we want to build two quantizers, for weights and activations respectively, and build a simple API on top of them. In particular, we want to be able to switch between BinaryQuant and ClampedBinaryQuant (a variant of binary quantization with clamping), and we want to optionally perform per-channel scaling. To do so, we are going to implement the controlling logic through a hierarchy of ExtendedInjectors, leaving two boolean flags exposed as attributes of the quantizers, with the idea that the flags can then be set through keyword arguments of the respective quantized layers.
We can go as follows:
[31]:
from brevitas.core.quant import ClampedBinaryQuant
from brevitas.proxy import WeightQuantProxyFromInjector, ActQuantProxyFromInjector
from brevitas.inject import this
class CommonQuantizer(ExtendedInjector):
scaling_impl = ParameterScaling
signed=True
@value
def tensor_quant(is_clamped):
# returning a class to auto-wire from a value function
# wouldn't be allowed in a standard Injector
if is_clamped:
return ClampedBinaryQuant
else:
return BinaryQuant
@value
def scaling_shape(scaling_per_output_channel):
if scaling_per_output_channel:
# returning this.something from a value function
# wouldn't be allowed in a standard Injector
return this.per_channel_broadcastable_shape
else:
return ()
class AdvancedWeightQuantizer(CommonQuantizer):
proxy_class = WeightQuantProxyFromInjector
@value
def per_channel_broadcastable_shape(module):
return (module.weight.shape[0], 1, 1, 1)
@value
def scaling_init(module, scaling_per_output_channel):
if scaling_per_output_channel:
num_ch = module.weight.shape[0]
return module.weight.abs().view(num_ch, -1).max(dim=1)[0].view(-1, 1, 1, 1)
else:
return module.weight.abs().max()
class AdvancedActQuantizer(CommonQuantizer):
scaling_init = 0.01
proxy_class = ActQuantProxyFromInjector
There are a bunch of things going on here to unpack.
The first one is that a @value
function can return a class to auto-wire and inject, as seen in the definition of tensor_quant
. This wouldn’t normally be possible with a standard Injector
, but it’s possible with an ExtendedInjector
. This way we can switch between different implementations of tensor_quant
.
The second one is the special object this. this is already present in the dependencies library, and it's used as a way to retrieve attributes of the quantizer from within the quantizer itself. However, normally it wouldn't be possible to return a reference to this from a @value function. Again, this is something that only an ExtendedInjector supports, and it allows chaining different attributes such that the chained values are computed only when necessary.
Let’s see the quantizers applied to a layer:
[32]:
per_channel_quant_conv = QuantConv2d(
3, 2, (3, 3),
weight_quant=AdvancedWeightQuantizer,
weight_is_clamped=False,
weight_scaling_per_output_channel=True)
per_channel_quant_conv.quant_weight()
[32]:
IntQuantTensor(value=tensor([[[[ 0.1612, -0.1612, -0.1612],
[-0.1612, -0.1612, -0.1612],
[ 0.1612, 0.1612, 0.1612]],
[[-0.1612, 0.1612, -0.1612],
[-0.1612, 0.1612, 0.1612],
[-0.1612, -0.1612, 0.1612]],
[[-0.1612, 0.1612, 0.1612],
[ 0.1612, 0.1612, -0.1612],
[ 0.1612, 0.1612, 0.1612]]],
[[[ 0.1924, 0.1924, 0.1924],
[-0.1924, -0.1924, 0.1924],
[-0.1924, 0.1924, -0.1924]],
[[ 0.1924, -0.1924, 0.1924],
[ 0.1924, 0.1924, -0.1924],
[ 0.1924, -0.1924, -0.1924]],
[[-0.1924, -0.1924, 0.1924],
[ 0.1924, -0.1924, -0.1924],
[ 0.1924, -0.1924, 0.1924]]]], grad_fn=<MulBackward0>), scale=tensor([[[[0.1612]]],
[[[0.1924]]]], grad_fn=<AbsBinarySignGradFnBackward>), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))
As expected the weight scale is now a vector. Everything we said so far about quantizers still applies, so for example we can load the floating-point state dict we defined before and observe how it triggers an update of the weight scale:
[33]:
per_channel_quant_conv.load_state_dict(float_conv.state_dict())
per_channel_quant_conv.quant_weight()
[33]:
IntQuantTensor(value=tensor([[[[ 0.1924, 0.1924, -0.1924],
[ 0.1924, 0.1924, 0.1924],
[ 0.1924, 0.1924, 0.1924]],
[[-0.1924, -0.1924, -0.1924],
[ 0.1924, 0.1924, -0.1924],
[ 0.1924, 0.1924, 0.1924]],
[[ 0.1924, 0.1924, -0.1924],
[-0.1924, 0.1924, 0.1924],
[-0.1924, 0.1924, 0.1924]]],
[[[ 0.1899, -0.1899, 0.1899],
[-0.1899, 0.1899, -0.1899],
[ 0.1899, 0.1899, 0.1899]],
[[-0.1899, 0.1899, -0.1899],
[ 0.1899, -0.1899, -0.1899],
[-0.1899, 0.1899, 0.1899]],
[[ 0.1899, 0.1899, -0.1899],
[-0.1899, -0.1899, -0.1899],
[ 0.1899, -0.1899, -0.1899]]]], grad_fn=<MulBackward0>), scale=tensor([[[[0.1924]]],
[[[0.1899]]]], grad_fn=<AbsBinarySignGradFnBackward>), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))
In this case we have a per-channel quantizer, so the original floating-point weight tensor is now quantized per channel.
Similarly, we can apply our custom activation quantizer to e.g. a QuantIdentity
layer:
[34]:
from brevitas.nn import QuantIdentity
quant_identity = QuantIdentity(
act_quant=AdvancedActQuantizer, is_clamped=True, scaling_per_output_channel=False)
quant_identity(torch.randn(4, 4))
[34]:
tensor([[ 0.0100, 0.0100, -0.0100, 0.0100],
[ 0.0100, 0.0100, -0.0100, 0.0100],
[-0.0100, 0.0100, -0.0100, -0.0100],
[ 0.0100, -0.0100, -0.0100, -0.0100]], grad_fn=<MulBackward0>)
Note how AdvancedActQuantizer
doesn’t define a per_channel_broadcastable_shape
, yet no errors are triggered. This is because this.per_channel_broadcastable_shape
is required only when scaling_per_output_channel
is True
, while in this case scaling_per_output_channel
is False
. Let’ try to set it to True
then:
[35]:
from brevitas.nn import QuantIdentity
quant_identity = QuantIdentity(
act_quant=AdvancedActQuantizer, is_clamped=True, scaling_per_output_channel=True)
---------------------------------------------------------------------------
DependencyError Traceback (most recent call last)
Cell In[35], line 3
1 from brevitas.nn import QuantIdentity
----> 3 quant_identity = QuantIdentity(
4 act_quant=AdvancedActQuantizer, is_clamped=True, scaling_per_output_channel=True)
File /proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/nn/quant_activation.py:113, in QuantIdentity.__init__(self, act_quant, return_quant_tensor, **kwargs)
108 def __init__(
109 self,
110 act_quant: Optional[ActQuantType] = Int8ActPerTensorFloat,
111 return_quant_tensor: bool = False,
112 **kwargs):
--> 113 QuantNLAL.__init__(
114 self,
115 input_quant=None,
116 act_impl=None,
117 passthrough_act=True,
118 act_quant=act_quant,
119 return_quant_tensor=return_quant_tensor,
120 **kwargs)
File /proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/nn/quant_layer.py:34, in QuantNonLinearActLayer.__init__(self, act_impl, passthrough_act, input_quant, act_quant, return_quant_tensor, **kwargs)
32 QuantLayerMixin.__init__(self, return_quant_tensor)
33 QuantInputMixin.__init__(self, input_quant, **kwargs)
---> 34 QuantNonLinearActMixin.__init__(self, act_impl, passthrough_act, act_quant, **kwargs)
File /proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/nn/mixin/act.py:66, in QuantNonLinearActMixin.__init__(self, act_impl, passthrough_act, act_quant, act_proxy_prefix, act_kwargs_prefix, **kwargs)
55 def __init__(
56 self,
57 act_impl: Optional[Type[Module]],
(...)
61 act_kwargs_prefix='',
62 **kwargs):
63 prefixed_kwargs = {
64 act_kwargs_prefix + 'act_impl': act_impl,
65 act_kwargs_prefix + 'passthrough_act': passthrough_act}
---> 66 QuantProxyMixin.__init__(
67 self,
68 quant=act_quant,
69 proxy_prefix=act_proxy_prefix,
70 kwargs_prefix=act_kwargs_prefix,
71 proxy_protocol=ActQuantProxyProtocol,
72 none_quant_injector=NoneActQuant,
73 **prefixed_kwargs,
74 **kwargs)
File /proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/nn/mixin/base.py:48, in QuantProxyMixin.__init__(self, quant, proxy_protocol, none_quant_injector, proxy_prefix, kwargs_prefix, **kwargs)
46 quant_injector = quant
47 quant_injector = quant_injector.let(**filter_kwargs(kwargs_prefix, kwargs))
---> 48 quant = quant_injector.proxy_class(self, quant_injector)
49 else:
50 if not isinstance(quant, proxy_protocol):
File /proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/proxy/runtime_quant.py:198, in ActQuantProxyFromInjector.__init__(self, quant_layer, quant_injector)
197 def __init__(self, quant_layer, quant_injector):
--> 198 super().__init__(quant_layer, quant_injector)
199 self.cache_class = _CachedIO
File /proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/proxy/runtime_quant.py:93, in ActQuantProxyFromInjectorBase.__init__(self, quant_layer, quant_injector)
92 def __init__(self, quant_layer, quant_injector):
---> 93 QuantProxyFromInjector.__init__(self, quant_layer, quant_injector)
94 ActQuantProxyProtocol.__init__(self)
95 self.is_passthrough_act = _is_passthrough_act(quant_injector)
File /proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/proxy/quant_proxy.py:80, in QuantProxyFromInjector.__init__(self, quant_layer, quant_injector)
78 # Use a normal list and not a ModuleList since this is a pointer to parent modules
79 self.tracked_module_list = []
---> 80 self.add_tracked_module(quant_layer)
81 self.disable_quant = False
82 # Torch.compile compatibility requires this
File /proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/proxy/quant_proxy.py:120, in QuantProxyFromInjector.add_tracked_module(self, module)
118 self.tracked_module_list.append(module)
119 self.update_tracked_modules()
--> 120 self.init_tensor_quant()
121 else:
122 raise RuntimeError("Trying to add None as a parent module.")
File /proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/proxy/runtime_quant.py:127, in ActQuantProxyFromInjectorBase.init_tensor_quant(self)
126 def init_tensor_quant(self):
--> 127 tensor_quant = self.quant_injector.tensor_quant
128 if 'act_impl' in self.quant_injector:
129 act_impl = self.quant_injector.act_impl
[... skipping hidden 1 frame]
File /proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/_dependencies/this.py:51, in _ThisSpec.__call__(self, __self__)
49 if kind == ".":
50 try:
---> 51 result = getattr(result, symbol)
52 except DependencyError:
53 message = (
54 "You tried to shift this more times than Injector has levels"
55 )
File /proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/inject/__init__.py:129, in _ExtendedInjectorType.__getattr__(cls, attrname)
126 else:
127 message = "{!r} can not resolve attribute {!r}".format(
128 cls.__name__, current_attr)
--> 129 raise DependencyError(message)
131 marker, attribute, args, have_defaults = spec
133 if set(args).issubset(cached):
DependencyError: 'AdvancedActQuantizer' can not resolve attribute 'per_channel_broadcastable_shape'
As expected we get an error saying that the quantizer cannot resolve per_channel_broadcastable_shape
. If we pass it in then we can get a per-channel quantizer:
[36]:
quant_identity = QuantIdentity(
act_quant=AdvancedActQuantizer, is_clamped=True, scaling_per_output_channel=True,
per_channel_broadcastable_shape=(4, 1), return_quant_tensor=True)
quant_identity(torch.randn(4, 4))
[36]:
IntQuantTensor(value=tensor([[ 0.0100, 0.0100, 0.0100, -0.0100],
[ 0.0100, -0.0100, -0.0100, 0.0100],
[-0.0100, -0.0100, -0.0100, 0.0100],
[ 0.0100, 0.0100, 0.0100, 0.0100]], grad_fn=<MulBackward0>), scale=tensor([[0.0100],
[0.0100],
[0.0100],
[0.0100]], grad_fn=<AbsBinarySignGradFnBackward>), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))
We have seen how powerful dependency injection is. In a way, it's even too expressive: for users that are not interested in building completely custom quantizers, it can be hard to make sense of how the various components available under brevitas.core can be assembled together according to best practices.