An overview of QuantTensor and QuantConv2d#
In this initial tutorial, we take a first look at QuantTensor, a basic data structure in Brevitas, and at QuantConv2d, a typical quantized layer. QuantConv2d is a subclass of QuantWeightBiasInputOutputLayer (typically imported as QuantWBIOL), meaning that it supports quantization of its weight, bias, input and output. Other QuantWBIOL layers are QuantLinear, QuantConv1d, QuantConvTranspose1d and QuantConvTranspose2d, and they all follow the same principles.
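Although the rest of this tutorial focuses on QuantConv2d, the same quantization-related arguments apply to the other QuantWBIOL layers. As a minimal sketch (the layer sizes here are arbitrary), a QuantLinear can be instantiated in exactly the same way:
[ ]:
from brevitas.nn import QuantLinear
# same quantization-related arguments as QuantConv2d: weight_quant, bias_quant,
# input_quant, output_quant, return_quant_tensor
quant_linear = QuantLinear(in_features=4, out_features=8, bias=False)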
If we take a look at the __init__ method of QuantConv2d, we notice a few things:
[1]:
import inspect
from brevitas.nn import QuantConv2d
from brevitas.nn import QuantIdentity
from IPython.display import Markdown, display
import torch
# helpers
def assert_with_message(condition):
assert condition
print(condition)
def pretty_print_source(source):
display(Markdown('```python\n' + source + '\n```'))
# set manual seed for the notebook
torch.manual_seed(0)
source = inspect.getsource(QuantConv2d.__init__)
pretty_print_source(source)
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
padding_mode: str = 'zeros',
bias: Optional[bool] = True,
weight_quant: Optional[WeightQuantType] = Int8WeightPerTensorFloat,
bias_quant: Optional[BiasQuantType] = None,
input_quant: Optional[ActQuantType] = None,
output_quant: Optional[ActQuantType] = None,
return_quant_tensor: bool = False,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
**kwargs) -> None:
# avoid an init error in the super class by setting padding to 0
if padding_mode == 'zeros' and padding == 'same' and (stride > 1 if isinstance(
stride, int) else any(map(lambda x: x > 1, stride))):
padding = 0
is_same_padded_strided = True
else:
is_same_padded_strided = False
Conv2d.__init__(
self,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
padding_mode=padding_mode,
dilation=dilation,
groups=groups,
bias=bias,
device=device,
dtype=dtype)
QuantWBIOL.__init__(
self,
weight_quant=weight_quant,
bias_quant=bias_quant,
input_quant=input_quant,
output_quant=output_quant,
return_quant_tensor=return_quant_tensor,
**kwargs)
self.is_same_padded_strided = is_same_padded_strided
QuantConv2d subclasses both Conv2d and QuantWBIOL. Its initialization method exposes the usual arguments of a Conv2d, as well as: an extra flag to support same padding; four different arguments to set a quantizer for, respectively, weight, bias, input and output; a return_quant_tensor boolean flag; the **kwargs placeholder to intercept additional arbitrary keyword arguments.
By default weight_quant=Int8WeightPerTensorFloat, while bias_quant, input_quant and output_quant are set to None. That means that by default weights are quantized to 8-bit signed integer with a per-tensor floating-point scale factor (a very common type of quantization, adopted e.g. by the ONNX standard opset), while quantization of bias, input and output is disabled. We can easily verify all of this at runtime on an example:
[2]:
default_quant_conv = QuantConv2d(
in_channels=2, out_channels=3, kernel_size=(3,3), bias=False)
[3]:
print(f'Is weight quant enabled: {default_quant_conv.weight_quant.is_quant_enabled}')
print(f'Is bias quant enabled: {default_quant_conv.bias_quant.is_quant_enabled}')
print(f'Is input quant enabled: {default_quant_conv.input_quant.is_quant_enabled}')
print(f'Is output quant enabled: {default_quant_conv.output_quant.is_quant_enabled}')
Is weight quant enabled: True
Is bias quant enabled: False
Is input quant enabled: False
Is output quant enabled: False
If we now try to pass in a random floating-point tensor as input, as expected we get the output of the convolution:
[4]:
import torch
out = default_quant_conv(torch.randn(1, 2, 5, 5))
out
/proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/torch/_tensor.py:1255: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /opt/conda/conda-bld/pytorch_1670525541990/work/c10/core/TensorImpl.h:1758.)
return super(Tensor, self).rename(names)
[W NNPACK.cpp:53] Could not initialize NNPACK! Reason: Unsupported hardware.
/proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/torch/nn/modules/conv.py:459: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1670525541990/work/torch/csrc/utils/python_arg_parser.cpp:350.)
return F.conv2d(input, weight, bias, self.stride,
[4]:
tensor([[[[ 0.2908, -0.1793, -0.9610],
[-0.6542, -0.3532, 0.6361],
[ 1.0290, 0.2730, 0.0969]],
[[-0.3479, 0.6030, 0.4900],
[ 0.1607, 0.3547, -0.4283],
[-0.6696, 0.0652, 0.7300]],
[[-0.0769, -0.2424, 0.1860],
[ 0.1740, -0.1182, -0.7017],
[ 0.0963, 0.2375, -0.9439]]]], grad_fn=<ConvolutionBackward0>)
In this case we are computing the convolution between an unquantized input tensor and quantized weights, so in general the output is unquantized.
A QuantConv2d with quantization disabled everywhere behaves like a standard Conv2d. Again, we can easily verify this:
[5]:
from torch.nn import Conv2d
torch.manual_seed(0) # set a seed to make sure the random weight init is reproducible
disabled_quant_conv = QuantConv2d(
in_channels=2, out_channels=3, kernel_size=(3,3), bias=False, weight_quant=None)
torch.manual_seed(0) # reproduce the same random weight init as above
float_conv = Conv2d(
in_channels=2, out_channels=3, kernel_size=(3,3), bias=False)
inp = torch.randn(1, 2, 5, 5)
assert_with_message(torch.isclose(disabled_quant_conv(inp), float_conv(inp)).all().item())
True
QuantTensor#
We can directly observe the quantized weights by calling the weight quantizer on the layer's weights: default_quant_conv.weight_quant(default_quant_conv.weight), which for brevity is already available as default_quant_conv.quant_weight():
[6]:
default_quant_conv.quant_weight()
[6]:
IntQuantTensor(value=tensor([[[[-0.0018, 0.1273, -0.1937],
[-0.1734, -0.0904, 0.0627],
[-0.0055, 0.1863, -0.0203]],
[[ 0.0627, -0.0720, -0.0461],
[-0.2251, -0.1568, -0.0978],
[ 0.0092, 0.0941, 0.1421]]],
[[[-0.1605, -0.1033, 0.0849],
[ 0.1956, -0.0480, 0.1771],
[-0.0387, 0.0258, 0.2140]],
[[-0.2196, -0.1476, -0.0590],
[-0.0923, 0.2030, -0.1531],
[-0.1089, -0.1642, -0.2214]]],
[[[-0.1384, 0.2030, 0.1052],
[ 0.1144, 0.0129, -0.1199],
[ 0.0406, -0.2196, -0.1697]],
[[-0.1218, 0.1494, 0.1384],
[-0.1052, -0.0092, 0.1513],
[ 0.2343, 0.0941, 0.0314]]]], grad_fn=<MulBackward0>), scale=tensor(0.0018, grad_fn=<DivBackward0>), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))
Notice how the quantized weights are wrapped in a data structure implemented by Brevitas called QuantTensor. A QuantTensor is a way to represent an affine quantized tensor with all its metadata: the value of the quantized tensor in dequantized format, its scale, zero_point and bit_width, whether the quantized value is signed or not, and whether the tensor was generated in training mode.
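These fields can be inspected directly on the returned object. A minimal sketch, reusing the layer defined above:
[ ]:
qw = default_quant_conv.quant_weight()
print(qw.scale)       # per-tensor floating-point scale factor
print(qw.zero_point)  # 0. for this symmetric quantizer
print(qw.bit_width)   # tensor(8.)
print(qw.signed)      # True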
As expected, the quantized value (in dequantized format) can be computed from its integer representation, together with the zero-point and scale:
[7]:
int_weight = default_quant_conv.quant_weight().int()
zero_point = default_quant_conv.weight_quant.zero_point()
scale = default_quant_conv.weight_quant.scale()
quant_weight_manually = (int_weight - zero_point) * scale
assert_with_message(default_quant_conv.quant_weight().value.isclose(quant_weight_manually).all().item())
True
A valid QuantTensor correctly populates all its fields with values != None and respects the affine quantization invariant, i.e. value / scale + zero_point is (accounting for rounding errors) an integer that can be represented within the interval defined by the bit_width and signed fields of the QuantTensor. A non-valid one doesn't. We can observe that the quantized weights are indeed marked as valid:
[8]:
assert_with_message(default_quant_conv.quant_weight().is_valid)
True
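To make the invariant concrete, here is a minimal sketch that checks it by hand on the quantized weights (the range computation assumes signed quantization, as in the default weight quantizer):
[ ]:
qw = default_quant_conv.quant_weight()
int_repr = qw.value / qw.scale + qw.zero_point
max_rounding_error = (int_repr - int_repr.round()).abs().max()
min_int = -2 ** (qw.bit_width - 1)      # -128 for 8-bit signed
max_int = 2 ** (qw.bit_width - 1) - 1   # 127 for 8-bit signed
in_range = ((int_repr.round() >= min_int) & (int_repr.round() <= max_int)).all()
print(max_rounding_error.item(), in_range.item())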
Calling is_valid is relatively expensive, so it should be used sparingly, but there are a few cases where a non-valid QuantTensor might be generated that it is important to be aware of. Say we have two QuantTensors as output of the same quantized activation, and we want to sum them together:
[9]:
quant_act = QuantIdentity(return_quant_tensor=True)
out_tensor_0 = quant_act(torch.randn(1,2,5,5))
out_tensor_1 = quant_act(torch.randn(1,2,5,5))
assert_with_message(out_tensor_0.is_valid)
assert_with_message(out_tensor_1.is_valid)
print(out_tensor_0.scale)
print(out_tensor_1.scale)
True
True
tensor(0.0211, grad_fn=<DivBackward0>)
tensor(0.0162, grad_fn=<DivBackward0>)
Both QuantTensors are valid, but since the quantized activation is in training mode by default, their scale factors are going to be different. It is important to note that the behaviour is different at evaluation time, where the two scale factors will be the same.
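To illustrate the evaluation-time behaviour, here is a minimal sketch that switches the activation to eval mode (where, for the default activation quantizer, the scale is based on statistics collected during training rather than on the current batch) and checks that the two scale factors now match:
[ ]:
quant_act.eval()
eval_tensor_0 = quant_act(torch.randn(1, 2, 5, 5))
eval_tensor_1 = quant_act(torch.randn(1, 2, 5, 5))
assert_with_message((eval_tensor_0.scale == eval_tensor_1.scale).all().item())
quant_act.train()  # switch back to training mode for the rest of the tutorial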
[10]:
out_tensor = out_tensor_0 + out_tensor_1
print(out_tensor)
IntQuantTensor(value=tensor([[[[-0.1106, 1.1945, -0.4972, -2.0968, 0.7175],
[-2.5901, 0.0588, -0.2014, 2.1486, 1.6435],
[ 0.9067, -2.5212, 2.2193, 0.2352, -0.8395],
[-0.8351, 0.6341, -0.5551, 0.1040, -3.3151],
[-0.8979, -0.7092, 3.8232, 1.0875, 0.3954]],
[[ 1.4363, -1.3973, 1.3249, 2.6914, 0.3660],
[ 1.5057, 1.8094, 0.5100, -1.6874, 1.9981],
[ 1.2472, -1.7813, 0.0334, -1.2880, -2.9333],
[ 0.0180, -1.4298, -2.9978, 0.5494, -1.4548],
[ 1.6738, -0.3177, -0.3721, -0.1650, -1.1871]]]],
grad_fn=<AddBackward0>), scale=0.018651068210601807, zero_point=0.0, bit_width=9.0, signed_t=True, training_t=True)
Because training is set to True for both of them (the default in training mode), we are allowed to sum them even if they have different scale factors. The output QuantTensor will have the correct bit_width, and a scale which is the average of the two original scale factors. This is done only at training time, in order to propagate gradient information; the consequence, however, is that the resulting QuantTensor is no longer valid:
[11]:
assert_with_message(not out_tensor.is_valid)
True
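We can verify the averaging of the scale factors directly. A minimal sketch (torch.as_tensor is used only to be robust to the summed scale being returned as a Python float rather than a tensor):
[ ]:
avg_scale = (out_tensor_0.scale + out_tensor_1.scale) / 2
assert_with_message(torch.isclose(torch.as_tensor(out_tensor.scale), avg_scale).all().item())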
QuantTensor implements __torch_function__ to handle being called from torch functional operators (e.g. ops under torch.nn.functional). Passing a QuantTensor to supported ops that are invariant to quantization, e.g. max-pooling, preserves the validity of the QuantTensor. Example:
[12]:
import torch
quant_identity = QuantIdentity(return_quant_tensor=True)
quant_tensor = quant_identity(torch.randn(1, 3, 4, 4))
torch.nn.functional.max_pool2d(quant_tensor, kernel_size=2, stride=2)
[12]:
IntQuantTensor(value=tensor([[[[0.5191, 0.6402],
[2.1455, 0.5883]],
[[2.0417, 0.5883],
[1.2631, 0.3980]],
[[0.7959, 0.5191],
[0.8132, 1.3496]]]], grad_fn=<MaxPool2DWithIndicesBackward0>), scale=tensor(0.0173, grad_fn=<DivBackward0>), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))
For ops that are not invariant to quantization, a QuantTensor decays into a floating-point torch.Tensor. Example:
[13]:
torch.tanh(quant_tensor)
/tmp/ipykernel_81376/1377665000.py:1: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1670525541990/work/torch/csrc/utils/python_arg_parser.cpp:350.)
torch.tanh(quant_tensor)
[13]:
tensor([[[[ 0.4770, 0.2212, 0.0691, 0.5650],
[-0.0346, -0.6618, -0.4635, -0.3482],
[ 0.9730, -0.7245, -0.5881, -0.5287],
[-0.0863, 0.8857, 0.5287, -0.4498]],
[[ 0.9669, 0.5650, -0.6211, -0.4498],
[-0.2376, 0.6103, 0.5287, 0.2700],
[-0.6808, 0.8519, 0.2700, -0.5531],
[-0.0173, 0.8264, 0.3782, -0.1881]],
[[-0.6211, -0.9764, -0.5993, 0.4770],
[ 0.5033, 0.6618, -0.1881, -0.6211],
[-0.8031, 0.1375, 0.5287, 0.8740],
[-0.6714, 0.6714, -0.5650, 0.8611]]]], grad_fn=<TanhBackward0>)
Input Quantization#
We can obtain a valid output QuantTensor by making sure that both the input and the weight of QuantConv2d are quantized. To do so, we can set a quantizer for input_quant. In this example we pick a signed 8-bit quantizer with a per-tensor floating-point scale factor:
[14]:
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
input_quant_conv = QuantConv2d(
in_channels=2, out_channels=3, kernel_size=(3,3), bias=False,
input_quant=Int8ActPerTensorFloat, return_quant_tensor=True)
out_tensor = input_quant_conv(torch.randn(1, 2, 5, 5))
out_tensor
[14]:
IntQuantTensor(value=tensor([[[[-0.3568, -0.1883, 0.3589],
[-0.4470, 0.1039, -0.3945],
[-0.4190, 0.3723, 0.8384]],
[[-0.0510, 0.5514, -0.2751],
[-0.5668, 0.5824, 0.2328],
[ 0.1316, -0.2518, 1.0418]],
[[ 0.2734, 0.7268, -0.0249],
[-0.1732, 0.5197, 1.1158],
[ 0.3771, -0.3810, 0.2008]]]], grad_fn=<ConvolutionBackward0>), scale=tensor([[[[3.1958e-05]]]], grad_fn=<MulBackward0>), zero_point=tensor([0.]), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))
[15]:
assert_with_message(out_tensor.is_valid)
True
What happens internally is that the input tensor passed to input_quant_conv is quantized before being passed to the convolution operator. That means we are now computing a convolution between two quantized tensors, which implies that the output of the operation is also quantized. As expected, then, out_tensor is marked as valid.
Another important thing to notice is how the bit_width field of out_tensor is relatively high at 21 bits. In Brevitas, the assumption is always that the output bit-width of an operator reflects the worst-case size of the accumulator required by that operation. In other terms, given the size of the input and weight tensors and their bit-widths, 21 is the bit-width that would be required to represent the largest possible output value that could be generated. This makes sure that the affine quantization invariant is always respected.
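To give a feel for where 21 comes from, here is a rough back-of-the-envelope sketch of the worst-case reasoning for this specific example (8-bit input, 8-bit weights, 2 input channels, 3x3 kernel); it mirrors the idea, not necessarily Brevitas' exact internal formula:
[ ]:
import math
input_bits, weight_bits = 8, 8
reduce_size = 2 * 3 * 3  # in_channels * kernel_height * kernel_width
max_uint_input = 2 ** input_bits - 1
max_uint_weight = 2 ** weight_bits - 1
worst_case_acc = max_uint_input * max_uint_weight * reduce_size
print(math.ceil(math.log2(worst_case_acc)))  # 21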
We could have obtained a similar result by directly passing a QuantTensor as input. In this example we define a QuantTensor ourselves, but it could also be the output of a previous layer.
[17]:
from brevitas.quant_tensor import IntQuantTensor
scale = 0.0001
bit_width = 8
zero_point = 0.
int_value = torch.randint(low=- 2 ** (bit_width - 1), high=2 ** (bit_width - 1) - 1, size=(1, 2, 5, 5))
quant_value = (int_value - zero_point) * scale
quant_tensor_input = IntQuantTensor(
quant_value,
scale=torch.tensor(scale),
zero_point=torch.tensor(zero_point),
bit_width=torch.tensor(float(bit_width)),
signed=True,
training=True)
quant_tensor_input
[17]:
IntQuantTensor(value=tensor([[[[ 7.2000e-03, -3.7000e-03, 7.7000e-03, -2.4000e-03, -8.9000e-03],
[-1.2000e-02, -8.1000e-03, 7.2000e-03, -1.1300e-02, -9.7000e-03],
[-1.0000e-03, 1.0100e-02, 3.8000e-03, -1.1900e-02, 6.9000e-03],
[ 8.3000e-03, 1.0000e-04, -6.9000e-03, 3.9000e-03, -5.4000e-03],
[ 1.1300e-02, -6.0000e-03, 9.7000e-03, 0.0000e+00, 1.0900e-02]],
[[-1.0900e-02, 1.1400e-02, -6.4000e-03, 9.2000e-03, 7.1000e-03],
[-6.0000e-04, 9.2000e-03, -8.5000e-03, 5.0000e-03, 6.5000e-03],
[-8.3000e-03, -1.2000e-03, 7.4000e-03, 9.2000e-03, -6.0000e-04],
[-2.1000e-03, 9.5000e-03, 3.0000e-04, -2.9000e-03, -6.5000e-03],
[-1.1800e-02, -4.8000e-03, 5.4000e-03, -2.5000e-03, 9.0000e-04]]]]), scale=tensor(1.0000e-04), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))
[ ]:
assert_with_message(quant_tensor_input.is_valid)
True
Note how we are explicitly forcing value, scale, zero_point and bit_width to be floating-point torch.Tensors, as this is expected by Brevitas but currently not enforced automatically at initialization time.
If we now pass quant_tensor_input to return_quant_conv, we will see that indeed the output is a valid 21-bit QuantTensor:
[ ]:
return_quant_conv = QuantConv2d(
in_channels=2, out_channels=3, kernel_size=(3,3), bias=False, return_quant_tensor=True)
out_tensor = return_quant_conv(quant_tensor_input)
out_tensor
IntQuantTensor(value=tensor([[[[-0.0019, 0.0049, -0.0012],
[-0.0012, 0.0050, -0.0074],
[-0.0023, -0.0035, -0.0033]],
[[-0.0031, 0.0028, 0.0116],
[ 0.0079, 0.0046, 0.0022],
[ 0.0021, -0.0004, 0.0011]],
[[-0.0045, -0.0010, 0.0002],
[-0.0044, 0.0027, 0.0025],
[-0.0009, 0.0040, -0.0044]]]], grad_fn=<ConvolutionBackward0>), scale=tensor([[[[1.8307e-07]]]], grad_fn=<MulBackward0>), zero_point=tensor([0.]), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))
[ ]:
assert_with_message(out_tensor.is_valid)
True
We can also pass an input QuantTensor to a layer that has input_quant enabled. In that case, the input gets re-quantized:
[ ]:
input_quant_conv(quant_tensor_input)
IntQuantTensor(value=tensor([[[[-0.0073, 0.0040, -0.0011],
[-0.0033, 0.0078, -0.0028],
[ 0.0005, -0.0025, -0.0008]],
[[ 0.0021, -0.0021, 0.0035],
[ 0.0012, -0.0016, -0.0023],
[-0.0010, -0.0015, 0.0040]],
[[-0.0010, 0.0047, 0.0025],
[-0.0014, 0.0021, -0.0039],
[ 0.0036, -0.0003, 0.0026]]]], grad_fn=<ConvolutionBackward0>), scale=tensor([[[[1.7393e-07]]]], grad_fn=<MulBackward0>), zero_point=tensor([0.]), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))
Output Quantization#
Let's now look at what would have happened if we had instead enabled output quantization:
[ ]:
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
output_quant_conv = QuantConv2d(
in_channels=2, out_channels=3, kernel_size=(3,3), bias=False,
output_quant=Int8ActPerTensorFloat, return_quant_tensor=True)
out_tensor = output_quant_conv(torch.randn(1, 2, 5, 5))
out_tensor
IntQuantTensor(value=tensor([[[[-0.2117, -0.4811, 0.0385],
[-0.5100, -0.2502, -0.2213],
[-0.5773, 0.0192, -0.5485]],
[[ 0.1347, 0.8179, -1.2316],
[-0.6062, 0.4426, -0.3849],
[ 0.1732, -0.5100, -0.1251]],
[[ 1.0873, 0.2406, -0.2887],
[-0.4330, -0.4907, -0.2021],
[ 0.6447, 0.4811, 0.1347]]]], grad_fn=<MulBackward0>), scale=tensor(0.0096, grad_fn=<DivBackward0>), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))
[ ]:
assert_with_message(out_tensor.is_valid)
True
The output is again a valid QuantTensor. However, what happened internally is quite different from before, as can be seen from the output bit_width. In the previous case, the bit_width reflected the size of the output accumulator. In this case instead, we have bit_width=tensor(8.), which is what we expect since output_quant has been set to an Int8 quantizer.
Bias Quantization#
There is an important scenario where the various options we just saw make a practical difference, and it's the quantization of bias. In many contexts, such as in the ONNX standard opset and in FINN, bias is assumed to be quantized with a scale factor equal to input scale * weight scale, which means that we need a valid quantized input somehow. A predefined bias quantizer that reflects that assumption is brevitas.quant.scaled_int.Int8Bias. If we simply tried to set it on a QuantConv2d without any sort of input quantization, we would get an error:
[ ]:
from brevitas.quant.scaled_int import Int8Bias
bias_quant_conv = QuantConv2d(
in_channels=2, out_channels=3, kernel_size=(3,3), bias=True,
bias_quant=Int8Bias, return_quant_tensor=True)
bias_quant_conv(torch.randn(1, 2, 5, 5))
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[23], line 6
1 from brevitas.quant.scaled_int import Int8Bias
3 bias_quant_conv = QuantConv2d(
4 in_channels=2, out_channels=3, kernel_size=(3,3), bias=True,
5 bias_quant=Int8Bias, return_quant_tensor=True)
----> 6 bias_quant_conv(torch.randn(1, 2, 5, 5))
File /proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
1190 # If we don't have any hooks, we want to skip the rest of the logic in
1191 # this function, and just call forward.
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
File /proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/nn/quant_conv.py:194, in QuantConv2d.forward(self, input)
193 def forward(self, input: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
--> 194 return self.forward_impl(input)
File /proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/nn/quant_layer.py:152, in QuantWeightBiasInputOutputLayer.forward_impl(self, inp)
148 compute_output_quant_tensor = isinstance(quant_input, QuantTensor) and isinstance(
149 quant_weight, QuantTensor)
150 if not (compute_output_quant_tensor or
151 self.output_quant.is_quant_enabled) and self.return_quant_tensor:
--> 152 raise RuntimeError("QuantLayer is not correctly configured")
154 if self.bias is not None:
155 quant_bias = self.bias_quant(self.bias, quant_input, quant_weight)
RuntimeError: QuantLayer is not correctly configured
We can solve the issue by passing in a valid QuantTensor, e.g. the quant_tensor_input we defined above:
[ ]:
bias_quant_conv(quant_tensor_input)
IntQuantTensor(value=tensor([[[[-2.4238e-03, -5.6598e-03, 5.1882e-03],
[-6.5582e-03, 8.9274e-03, 4.9640e-04],
[ 9.6283e-03, -1.7466e-03, -4.8311e-03]],
[[ 2.9322e-03, -3.1358e-03, -6.2727e-04],
[ 2.8723e-06, -3.7981e-03, 1.0973e-02],
[-4.1031e-03, 6.5909e-03, -4.2369e-03]],
[[ 4.1967e-03, -7.0733e-03, 1.6456e-03],
[ 1.8197e-03, -3.1683e-03, 4.8200e-03],
[-3.2585e-04, 3.1055e-03, 1.9703e-03]]]],
grad_fn=<ConvolutionBackward0>), scale=tensor([[[[1.7953e-07]]]], grad_fn=<MulBackward0>), zero_point=tensor([0.]), bit_width=tensor(22.), signed_t=tensor(True), training_t=tensor(True))
Or by enabling input quantization and then passing in either a floating-point torch.Tensor or a QuantTensor:
[ ]:
input_bias_quant_conv = QuantConv2d(
in_channels=2, out_channels=3, kernel_size=(3,3), bias=True,
input_quant=Int8ActPerTensorFloat, bias_quant=Int8Bias, return_quant_tensor=True)
input_bias_quant_conv(torch.randn(1, 2, 5, 5))
IntQuantTensor(value=tensor([[[[-0.2816, -0.5271, -0.1748],
[-0.4247, -0.1575, 0.0681],
[ 0.6528, -0.5346, -0.0657]],
[[ 0.2993, -0.3383, 0.3035],
[-0.4595, -0.6796, -0.9720],
[-0.1948, -0.5169, -0.2175]],
[[ 0.5586, 0.0665, -0.5807],
[ 0.5565, 0.1780, -0.0555],
[-0.1080, 0.0791, -0.2262]]]], grad_fn=<ConvolutionBackward0>), scale=tensor([[[[4.2009e-05]]]], grad_fn=<MulBackward0>), zero_point=tensor([0.]), bit_width=tensor(22.), signed_t=tensor(True), training_t=tensor(True))
[ ]:
input_bias_quant_conv(quant_tensor_input)
IntQuantTensor(value=tensor([[[[-0.0058, 0.0030, 0.0030],
[-0.0013, -0.0002, 0.0043],
[-0.0061, 0.0033, -0.0001]],
[[ 0.0013, -0.0008, -0.0015],
[ 0.0011, 0.0012, -0.0012],
[-0.0013, -0.0020, 0.0002]],
[[-0.0061, 0.0053, -0.0004],
[ 0.0028, 0.0031, -0.0037],
[ 0.0027, -0.0048, -0.0044]]]], grad_fn=<ConvolutionBackward0>), scale=tensor([[[[1.7370e-07]]]], grad_fn=<MulBackward0>), zero_point=tensor([0.]), bit_width=tensor(22.), signed_t=tensor(True), training_t=tensor(True))
Notice how the output bit_width=tensor(22.). This is because, in the worst case, summing a 21-bit integer (the size of the accumulator before bias is added) and an 8-bit integer (the size of the quantized bias) gives a 22-bit integer.
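A small sketch of this worst-case bit count (an informal illustration, not necessarily Brevitas' exact internal formula): adding two integers can require one more bit than the wider of the two operands.
[ ]:
acc_bit_width, bias_bit_width = 21, 8
out_bit_width = max(acc_bit_width, bias_bit_width) + 1
print(out_bit_width)  # 22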
Let's now try to enable output quantization instead of input quantization. That doesn't solve the problem with bias quantization, as output quantization is performed after the bias is added:
[ ]:
output_bias_quant_conv = QuantConv2d(
in_channels=2, out_channels=3, kernel_size=(3,3), bias=True,
output_quant=Int8ActPerTensorFloat, bias_quant=Int8Bias, return_quant_tensor=True)
output_bias_quant_conv(torch.randn(1, 2, 5, 5))
/proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/proxy/parameter_quant.py:154: UserWarning: No quant bias cache found, set cache_inference_quant_bias=True and run an inference pass first
warn(
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[27], line 4
1 output_bias_quant_conv = QuantConv2d(
2 in_channels=2, out_channels=3, kernel_size=(3,3), bias=True,
3 output_quant=Int8ActPerTensorFloat, bias_quant=Int8Bias, return_quant_tensor=True)
----> 4 output_bias_quant_conv(torch.randn(1, 2, 5, 5))
File /proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
1190 # If we don't have any hooks, we want to skip the rest of the logic in
1191 # this function, and just call forward.
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
File /proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/nn/quant_conv.py:194, in QuantConv2d.forward(self, input)
193 def forward(self, input: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
--> 194 return self.forward_impl(input)
File /proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/nn/quant_layer.py:155, in QuantWeightBiasInputOutputLayer.forward_impl(self, inp)
152 raise RuntimeError("QuantLayer is not correctly configured")
154 if self.bias is not None:
--> 155 quant_bias = self.bias_quant(self.bias, quant_input, quant_weight)
156 else:
157 quant_bias = None
File /proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
1190 # If we don't have any hooks, we want to skip the rest of the logic in
1191 # this function, and just call forward.
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
File /proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/proxy/parameter_quant.py:330, in BiasQuantProxyFromInjector.forward(self, x, input, weight)
328 input_scale = self.scale()
329 if input_scale is None:
--> 330 raise RuntimeError("Input scale required")
331 elif self.requires_input_scale and input_scale is not None and self.is_quant_enabled:
332 input_scale = input_scale.view(-1)
RuntimeError: Input scale required
Not all scenarios require bias quantization to depend on the scale factor of the input. In those cases, biases can be quantized the same way weights are, with their own scale factor. In Brevitas, a predefined quantizer that reflects this other scenario is Int8BiasPerTensorFloatInternalScaling. In this case, a valid quantized input is not required:
[ ]:
from brevitas.quant.scaled_int import Int8BiasPerTensorFloatInternalScaling
bias_internal_scale_quant_conv = QuantConv2d(
in_channels=2, out_channels=3, kernel_size=(3,3), bias=True,
bias_quant=Int8BiasPerTensorFloatInternalScaling, return_quant_tensor=False)
bias_internal_scale_quant_conv(torch.randn(1, 2, 5, 5))
tensor([[[[-0.4360, -0.2674, -0.4194],
[-0.2412, -0.6360, -0.6838],
[-0.5227, -0.0199, -0.1445]],
[[-0.3524, 0.8025, 0.2844],
[ 0.9945, -0.4782, 0.8064],
[ 0.5732, 0.1249, 0.3110]],
[[ 0.3223, 0.2530, 0.2753],
[ 0.5764, -0.2533, -0.0181],
[-0.4147, 0.2049, -0.9944]]]], grad_fn=<ConvolutionBackward0>)
There are a couple of situations to be aware of concerning bias quantization that can lead to changes in the output zero_point.
Let's consider the scenario where we compute the convolution between a quantized input tensor and quantized weights. In the first case, we then add an unquantized bias on top of the output. In the second one, we add a bias quantized with its own scale factor, e.g. with the Int8BiasPerTensorFloatInternalScaling quantizer. In both cases, in order to make sure the output QuantTensor is valid (i.e. the affine quantization invariant is respected), the output zero_point becomes non-zero:
[ ]:
unquant_bias_input_quant_conv = QuantConv2d(
in_channels=2, out_channels=3, kernel_size=(3,3), bias=True,
input_quant=Int8ActPerTensorFloat, return_quant_tensor=True)
out_tensor = unquant_bias_input_quant_conv(torch.randn(1, 2, 5, 5))
out_tensor
IntQuantTensor(value=tensor([[[[-0.6912, 0.0086, 0.1628],
[-0.4786, -0.8073, 0.5224],
[ 0.4157, 0.4686, 0.2560]],
[[ 0.3170, -0.5486, -0.5216],
[ 0.1832, 1.0217, -0.3637],
[-0.1115, 0.6974, -0.0452]],
[[-0.6168, -0.5241, -0.6593],
[ 0.6408, 0.2674, 0.4537],
[-0.3744, -0.7771, -0.2848]]]], grad_fn=<ConvolutionBackward0>), scale=tensor([[[[3.0094e-05]]]], grad_fn=<MulBackward0>), zero_point=tensor([[[[ 339.3406]],
[[-4597.1797]],
[[-3452.3713]]]], grad_fn=<DivBackward0>), bit_width=tensor(21.), signed_t=tensor(True), training_t=tensor(True))
[ ]:
assert_with_message(out_tensor.is_valid)
True
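Even with a non-zero, per-channel zero_point, the affine invariant still holds for the returned QuantTensor. A minimal sketch that checks this by hand on out_tensor:
[ ]:
int_repr = out_tensor.value / out_tensor.scale + out_tensor.zero_point
print((int_repr - int_repr.round()).abs().max())  # close to zero, up to rounding errors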
Finally, an important point about QuantTensor. With the exception of learned bit-width (which will be the subject of a separate tutorial) and some of the bias quantization scenarios we have just seen, returning a QuantTensor is usually not necessary and can create extra complexity. This is why return_quant_tensor currently defaults to False. We can easily see it in an example:
[ ]:
bias_input_quant_conv = QuantConv2d(
in_channels=2, out_channels=3, kernel_size=(3,3), bias=True,
input_quant=Int8ActPerTensorFloat, bias_quant=Int8Bias)
bias_input_quant_conv(torch.randn(1, 2, 5, 5))
tensor([[[[-0.2327, 0.9267, 0.6294],
[ 0.0901, 0.1027, -0.0727],
[-0.5614, 0.6182, 0.5394]],
[[ 0.4179, -0.5184, -0.2016],
[ 0.1390, -0.3925, -0.6171],
[ 0.4782, 0.0814, 0.6124]],
[[ 0.2896, -0.3779, 0.9408],
[-0.1334, 0.6186, 0.2167],
[-0.5926, 0.3690, -0.0284]]]], grad_fn=<ConvolutionBackward0>)
Although not obvious, the output is actually implicitly quantized.