Quantized RNNs and LSTMs#
With version 0.8, Brevitas introduces support for quantized recurrent layers through QuantRNN and QuantLSTM. As with other Brevitas quantized layers, QuantRNN and QuantLSTM can be used as drop-in replacements for their floating-point variants, but they also go further and support some additional structural recurrent options not found in upstream PyTorch. Similarly to other quantized layers, both QuantRNN and QuantLSTM can take in different quantizers for the different tensors involved in their computation.
QuantRNN#
We start by looking at QuantRNN:
[1]:
import inspect
from brevitas.nn import QuantRNN
from IPython.display import Markdown, display
import torch
torch.manual_seed(0)
# helpers
def assert_with_message(condition):
    assert condition
    print(condition)

def pretty_print_source(source):
    display(Markdown('```python\n' + source + '\n```'))
source = inspect.getsource(QuantRNN.__init__)
pretty_print_source(source)
def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        nonlinearity: str = 'tanh',
        bias: Optional[bool] = True,
        batch_first: bool = False,
        bidirectional: bool = False,
        weight_quant=Int8WeightPerTensorFloat,
        bias_quant=Int32Bias,
        io_quant=Int8ActPerTensorFloat,
        gate_acc_quant=Int8ActPerTensorFloat,
        shared_input_hidden_weights=False,
        return_quant_tensor: bool = False,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        **kwargs):
    super(QuantRNN, self).__init__(
        layer_impl=_QuantRNNLayer,
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        nonlinearity=nonlinearity,
        bias=bias,
        batch_first=batch_first,
        bidirectional=bidirectional,
        weight_quant=weight_quant,
        bias_quant=bias_quant,
        io_quant=io_quant,
        gate_acc_quant=gate_acc_quant,
        shared_input_hidden_weights=shared_input_hidden_weights,
        return_quant_tensor=return_quant_tensor,
        dtype=dtype,
        device=device,
        **kwargs)
QuantRNN supports all arguments of torch.nn.RNN, plus it exposes four different quantizers: weight_quant controls quantization of the weight tensor, bias_quant controls quantization of the bias, io_quant controls quantization of the input/output, and gate_acc_quant controls quantization of the output of the gate, before the nonlinearity is applied.
Compared to other layers like QuantLinear, a couple of things can be observed. First, input and output quantization are fused together into io_quant. This is because of the recurrent structure of RNN layers, where the output is fed back as input. Second, all quantizers are already set by default. This is different from a layer like QuantLinear, where only weight_quant has a default quantizer.
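For reference, this is a minimal sketch of how the four quantizer arguments can be passed explicitly. The quantizers below are the defaults shown in the signature above, assuming they are importable from brevitas.quant; bias quantization is disabled here purely for illustration:

from brevitas.nn import QuantRNN
from brevitas.quant import Int8WeightPerTensorFloat, Int8ActPerTensorFloat

quant_rnn = QuantRNN(
    input_size=10,
    hidden_size=20,
    weight_quant=Int8WeightPerTensorFloat,  # quantizes the weight tensors
    bias_quant=None,                        # disables bias quantization
    io_quant=Int8ActPerTensorFloat,         # shared input/output quantization
    gate_acc_quant=Int8ActPerTensorFloat)   # gate output, before the nonlinearity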
As with torch.nn.RNN, QuantRNN defines a stack of potentially multiple layers, controlled by setting num_layers, that can be set to bidirectional with bidirectional=True. Internally, QuantRNN is organized into a two-level nesting of ModuleList, one for the different layer(s), and one for the direction(s):
[2]:
def rnn_sublayer(module, sublayer_number, right_to_left_direction):
    return module.layers[sublayer_number][1 if right_to_left_direction else 0]
quant_rnn = QuantRNN(input_size=10, hidden_size=20, num_layers=2, bidirectional=True)
quant_rnn_0_left_to_right = rnn_sublayer(quant_rnn, sublayer_number=0, right_to_left_direction=False)
quant_rnn_0_right_to_left = rnn_sublayer(quant_rnn, sublayer_number=0, right_to_left_direction=True)
quant_rnn_1_left_to_right = rnn_sublayer(quant_rnn, sublayer_number=1, right_to_left_direction=False)
quant_rnn_1_right_to_left = rnn_sublayer(quant_rnn, sublayer_number=1, right_to_left_direction=True)
/proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/nn/mixin/base.py:55: UserWarning: Keyword arguments are being passed but they not being used.
warn('Keyword arguments are being passed but they not being used.')
Setting num_layers > 1 and/or bidirectional=True has different implications on different quantizers. For weight_quant, gate_acc_quant and bias_quant, the same quantizer definition is shared among different layers/directions, but each layer/direction is allocated its own instance of the quantizer.
[3]:
assert_with_message(not quant_rnn_0_left_to_right.gate_params.input_weight.weight_quant is quant_rnn_1_right_to_left.gate_params.input_weight.weight_quant)
True
[4]:
assert_with_message(not quant_rnn_0_left_to_right.cell.gate_acc_quant is quant_rnn_1_right_to_left.cell.gate_acc_quant)
True
[5]:
assert_with_message(not quant_rnn_0_left_to_right.gate_params.bias_quant is quant_rnn_1_right_to_left.gate_params.bias_quant)
True
Conversely, for io_quant the same instance is shared among all layers and directions. This makes sure that input/output tensors that are internally concatenated together share the same quantization scale/zero-point/bit-width.
[6]:
assert_with_message(quant_rnn_0_left_to_right.io_quant is quant_rnn_1_right_to_left.io_quant)
True
Finally, QuantRNN supports an additional flag, shared_input_hidden_weights. Whenever bidirectional=True, this allows the input-to-hidden weights to be shared between the two directions, an optimization first introduced by DeepSpeech to save on the number of parameters, with minimal impact on the quality of results.
[7]:
from brevitas.nn import QuantRNN
def count_weights(model):
    return sum(p.numel() for n, p in model.named_parameters() if 'weight' in n)
quant_rnn_single_direction = QuantRNN(input_size=10, hidden_size=20, bidirectional=False, shared_input_hidden_weights=False)
quant_rnn_bidirectional = QuantRNN(input_size=10, hidden_size=20, bidirectional=True, shared_input_hidden_weights=False)
quant_rnn_bidirectional_shared_input_hidden = QuantRNN(input_size=10, hidden_size=20, bidirectional=True, shared_input_hidden_weights=True)
print(f"Number of weights for single direction QuantRNN: {count_weights(quant_rnn_single_direction)}")
print(f"Number of weights for bidirectional QuantRNN: {count_weights(quant_rnn_bidirectional)}")
print(f"Number of weights for bidirectional QuantRNN with shared input-hidden weights: {count_weights(quant_rnn_bidirectional_shared_input_hidden)}")
Number of weights for single direction QuantRNN: 600
Number of weights for bidirectional QuantRNN: 1200
Number of weights for bidirectional QuantRNN with shared input-hidden weights: 1000
As with other Brevitas layers, it’s possible to directly modify a quantizer by passing keyword arguments with a matching prefix. For example, to set 4b per-channel weights and 6b io quantization:
[8]:
quant_rnn_4b = QuantRNN(input_size=10, hidden_size=20, weight_bit_width=4, weight_scaling_per_output_channel=True, io_bit_width=6)
quant_rnn_4b_0_left_to_right = rnn_sublayer(quant_rnn_4b, sublayer_number=0, right_to_left_direction=False)
input_hidden_weight = quant_rnn_4b_0_left_to_right.gate_params.input_weight.quant_weight()
hidden_hidden_weight = quant_rnn_4b_0_left_to_right.gate_params.hidden_weight.quant_weight()
print(f"Input-hidden weight bit-width: {input_hidden_weight.bit_width}")
print(f"Hidden-hidden weight bit-width: {hidden_hidden_weight.bit_width}")
print(f"I/O quant bit-width: {quant_rnn_4b_0_left_to_right.io_quant.bit_width()}")
print(f"Input-hidden weight scale: {input_hidden_weight.scale}")
print(f"Hidden-hidden weight scale: {hidden_hidden_weight.scale}")
Input-hidden weight bit-width: 4.0
Hidden-hidden weight bit-width: 4.0
I/O quant bit-width: 6.0
Input-hidden weight scale: tensor([[0.0297],
[0.0311],
[0.0298],
[0.0295],
[0.0316],
[0.0311],
[0.0318],
[0.0309],
[0.0317],
[0.0309],
[0.0316],
[0.0319],
[0.0319],
[0.0318],
[0.0315],
[0.0310],
[0.0319],
[0.0319],
[0.0318],
[0.0312]], grad_fn=<DivBackward0>)
Hidden-hidden weight scale: tensor([[0.0297],
[0.0311],
[0.0298],
[0.0295],
[0.0316],
[0.0311],
[0.0318],
[0.0309],
[0.0317],
[0.0309],
[0.0316],
[0.0319],
[0.0319],
[0.0318],
[0.0315],
[0.0310],
[0.0319],
[0.0319],
[0.0318],
[0.0312]], grad_fn=<DivBackward0>)
QuantRNN follows the same forward interface as torch.nn.RNN, with a couple of exceptions: packed variable-length inputs and unbatched inputs are currently not supported. Other than that, everything else is the same.
Inputs are expected to have shape (sequence, batch, features) for batch_first=False, or (batch, sequence, features) for batch_first=True. The layer returns a tuple (outputs, hidden_states), where outputs has shape (sequence, batch, hidden_size * num_directions) for batch_first=False, or (batch, sequence, hidden_size * num_directions) for batch_first=True, with num_directions=2 when bidirectional=True, while hidden_states has shape (num_directions * num_layers, batch, hidden_size).
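As a sketch based on the shapes described above (assuming a bidirectional, 2-layer stack with batch_first=False and arbitrary sizes):

import torch
from brevitas.nn import QuantRNN

quant_rnn_shapes = QuantRNN(input_size=10, hidden_size=20, num_layers=2, bidirectional=True)
outputs, hidden_states = quant_rnn_shapes(torch.randn(5, 2, 10))  # (sequence, batch, features)
print(outputs.shape)        # expected (5, 2, 40): (sequence, batch, hidden_size * num_directions)
print(hidden_states.shape)  # expected (4, 2, 20): (num_directions * num_layers, batch, hidden_size)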
[9]:
import torch
from brevitas.nn import QuantRNN
quant_rnn = QuantRNN(input_size=10, hidden_size=20, batch_first=True)
outputs, hidden_states = quant_rnn(torch.randn(2, 5, 10))
print(f"Output size: {outputs.shape}")
print(f"Hidden states size: {hidden_states.shape}")
Output size: torch.Size([2, 5, 20])
Hidden states size: torch.Size([1, 2, 20])
As with other quantized layers, it’s possible to return a QuantTensor with return_quant_tensor=True. As a reminder, a QuantTensor is just a data structure that captures the quantization metadata associated with a quantized tensor:
[10]:
import torch
from brevitas.nn import QuantRNN
quant_rnn = QuantRNN(input_size=10, hidden_size=20, batch_first=True, return_quant_tensor=True)
quant_rnn(torch.randn(2, 5, 10))
/proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/nn/mixin/base.py:216: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1670525541990/work/torch/csrc/utils/python_arg_parser.cpp:350.)
return torch.cat(outputs, dim=seq_dim)
[10]:
(IntQuantTensor(value=tensor([[[-0.0062, -0.2872, 0.7931, 0.4309, 0.5495, -0.4558, 0.2373,
0.6807, 0.4621, 0.6120, -0.1124, 0.3872, 0.3060, 0.7681,
-0.3684, 0.0437, -0.7369, -0.3247, 0.7743, 0.3372],
[ 0.5450, 0.2962, -0.3969, 0.3555, -0.5628, 0.2429, -0.4976,
0.1777, -0.1244, 0.0296, -0.2607, 0.0948, 0.5036, -0.3673,
0.5213, -0.2962, 0.7524, 0.0770, -0.0948, -0.0948],
[ 0.2691, -0.6624, -0.5434, 0.4968, -0.6624, 0.0983, 0.1345,
0.1242, -0.0517, -0.3726, 0.3053, 0.1604, 0.3208, 0.0983,
0.3105, 0.4243, 0.2794, 0.1604, 0.1035, -0.0724],
[ 0.1284, -0.3337, -0.5263, -0.0449, -0.5263, 0.3081, -0.1733,
0.5648, 0.4942, -0.1412, 0.1733, 0.3337, 0.6225, 0.3401,
0.5070, -0.1412, 0.0642, -0.3722, 0.2888, 0.1155],
[ 0.0579, -0.0058, -0.4054, -0.1564, -0.5560, -0.3301, 0.3533,
0.0058, -0.1622, -0.3765, 0.1216, 0.0695, -0.4054, 0.0927,
0.6139, -0.1390, 0.7066, 0.1274, 0.1622, -0.2896]],
[[ 0.1374, 0.5745, 0.0624, -0.2373, 0.3060, 0.3310, -0.5183,
0.1186, 0.1124, 0.2997, 0.0375, 0.6369, -0.5308, 0.6307,
-0.5683, 0.7556, 0.2997, -0.4933, 0.3934, -0.4871],
[ 0.1066, -0.1244, -0.1718, 0.4266, 0.5569, 0.0178, 0.1185,
-0.3910, 0.2133, 0.0178, -0.1066, -0.2903, 0.1837, -0.2547,
-0.2903, 0.0770, 0.3495, 0.2547, 0.2311, -0.6161],
[-0.0880, -0.1966, 0.3001, -0.0569, 0.4140, -0.1552, -0.1345,
0.4554, 0.5175, 0.1242, -0.2898, 0.1966, -0.0414, 0.3985,
-0.1708, -0.0621, -0.1708, 0.0828, 0.2225, 0.0517],
[ 0.2118, 0.5648, -0.2824, -0.0449, 0.5840, 0.3209, -0.5648,
0.3530, 0.4043, -0.4942, -0.3786, 0.0257, 0.5327, -0.1990,
-0.1348, -0.8215, 0.3016, 0.5327, 0.5648, -0.1155],
[-0.0290, -0.1738, 0.0695, 0.3765, 0.1738, 0.0579, -0.4054,
-0.2664, 0.4923, 0.2143, -0.4170, 0.4112, 0.5502, 0.7066,
-0.6024, 0.7356, 0.0348, 0.1043, -0.1911, -0.4518]]],
grad_fn=<CatBackward0>), scale=tensor(0.0059, grad_fn=<DivBackward0>), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)),
IntQuantTensor(value=tensor([[[ 0.0579, -0.0058, -0.4054, -0.1564, -0.5560, -0.3301, 0.3533,
0.0058, -0.1622, -0.3765, 0.1216, 0.0695, -0.4054, 0.0927,
0.6139, -0.1390, 0.7066, 0.1274, 0.1622, -0.2896],
[-0.0290, -0.1738, 0.0695, 0.3765, 0.1738, 0.0579, -0.4054,
-0.2664, 0.4923, 0.2143, -0.4170, 0.4112, 0.5502, 0.7066,
-0.6024, 0.7356, 0.0348, 0.1043, -0.1911, -0.4518]]],
grad_fn=<UnsqueezeBackward0>), scale=tensor(0.0058, grad_fn=<DivBackward0>), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)))
Similarly, a QuantTensor can be passed in as input. However, whenever io_quant is set (which it is by default), the input will be re-quantized:
[11]:
from brevitas.nn import QuantIdentity
quant_identity = QuantIdentity(return_quant_tensor=True)
quant_rnn(quant_identity(torch.randn(2, 5, 10)))
/proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/torch/_tensor.py:1255: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /opt/conda/conda-bld/pytorch_1670525541990/work/c10/core/TensorImpl.h:1758.)
return super(Tensor, self).rename(names)
[11]:
(IntQuantTensor(value=tensor([[[ 0.2111, 0.1267, 0.0060, 0.6153, -0.7721, -0.3740, -0.5188,
0.6273, 0.4162, 0.2051, 0.2292, 0.7239, 0.6032, 0.2533,
0.5067, 0.6635, 0.1206, -0.5730, 0.0483, 0.3318],
[ 0.5742, 0.0194, -0.3807, -0.0710, -0.6000, 0.1807, 0.1355,
0.4129, 0.3807, 0.3936, -0.0903, 0.1549, 0.1032, 0.0645,
0.4775, -0.0645, 0.1161, -0.0065, 0.0194, -0.1097],
[ 0.0453, -0.4533, 0.1036, -0.0194, -0.2979, 0.3432, 0.0777,
0.6346, -0.0842, 0.3302, 0.4727, 0.4856, -0.4144, 0.7382,
-0.0453, 0.5439, 0.2266, -0.4792, 0.4403, -0.1036],
[ 0.3198, 0.2741, -0.6395, 0.0971, -0.6052, -0.5196, 0.1770,
-0.5025, -0.1256, 0.2056, 0.2684, -0.6395, -0.0285, -0.7309,
0.7194, -0.7194, 0.1542, -0.3426, -0.6509, 0.0343],
[ 0.0000, -0.4004, 0.3151, -0.0263, -0.5842, -0.1641, -0.3939,
0.0263, -0.2429, 0.6499, -0.5186, 0.1247, -0.2101, 0.8337,
-0.1444, 0.6762, -0.1641, -0.5317, -0.1707, -0.0197]],
[[ 0.2111, -0.2111, -0.3197, -0.0241, -0.5067, -0.0241, -0.2895,
0.1749, -0.4283, 0.0000, -0.3680, 0.5308, -0.1267, 0.5248,
0.1206, 0.2654, 0.6394, -0.1327, 0.2292, -0.3800],
[ 0.6775, -0.3355, -0.1807, 0.2774, -0.8259, -0.2000, -0.0065,
0.5678, 0.4000, 0.2258, 0.4387, 0.2710, 0.5355, 0.1290,
0.6710, -0.0645, -0.2710, -0.3613, 0.6388, 0.5226],
[-0.0065, -0.0777, -0.6475, -0.1684, -0.3820, 0.3885, 0.0065,
0.1943, -0.3238, -0.2525, -0.1230, -0.0453, -0.0777, 0.3432,
0.4921, -0.1101, 0.8224, 0.2396, 0.1554, -0.3885],
[-0.0514, -0.4111, -0.4625, -0.1713, -0.3369, 0.2512, -0.2969,
-0.4111, -0.2341, 0.3597, -0.1998, 0.0000, 0.2741, 0.7137,
-0.1256, 0.1370, -0.0742, -0.5938, -0.5424, -0.4168],
[ 0.3479, 0.5974, -0.3939, 0.1444, -0.6762, 0.1969, -0.6499,
0.4136, 0.5383, -0.3085, 0.4070, 0.4070, 0.6630, -0.0263,
0.2823, -0.1510, 0.1313, -0.5186, 0.4464, -0.0066]]],
grad_fn=<CatBackward0>), scale=tensor(0.0062, grad_fn=<DivBackward0>), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)),
IntQuantTensor(value=tensor([[[ 0.0000, -0.4004, 0.3151, -0.0263, -0.5842, -0.1641, -0.3939,
0.0263, -0.2429, 0.6499, -0.5186, 0.1247, -0.2101, 0.8337,
-0.1444, 0.6762, -0.1641, -0.5317, -0.1707, -0.0197],
[ 0.3479, 0.5974, -0.3939, 0.1444, -0.6762, 0.1969, -0.6499,
0.4136, 0.5383, -0.3085, 0.4070, 0.4070, 0.6630, -0.0263,
0.2823, -0.1510, 0.1313, -0.5186, 0.4464, -0.0066]]],
grad_fn=<UnsqueezeBackward0>), scale=tensor(0.0066, grad_fn=<DivBackward0>), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)))
As with torch.nn.RNN, by default the initial hidden state is initialized to 0, but a custom hidden state of shape (num_directions * num_layers, batch, hidden_size) can be passed in:
[12]:
quant_rnn(torch.randn(2, 5, 10), torch.randn(1, 2, 20))
[12]:
(IntQuantTensor(value=tensor([[[-0.3777, -0.2074, 0.7184, 0.9110, 0.0148, -0.1926, -0.7110,
0.1926, -0.4222, -0.9480, 0.2592, 0.2222, -0.2370, -0.5407,
0.5851, -0.2370, 0.3555, 0.1703, 0.4444, -0.2222],
[ 0.4814, -0.7355, -0.1605, 0.3878, -0.5282, 0.2073, 0.0000,
0.3677, 0.1805, -0.1204, -0.4614, 0.2474, 0.7021, 0.0401,
0.4346, 0.4480, -0.3143, 0.0401, 0.6887, 0.6753],
[ 0.5038, -0.3650, -0.6936, 0.0146, -0.9345, 0.0000, 0.1679,
-0.3066, 0.1825, 0.4089, 0.0949, -0.2555, 0.3870, -0.2482,
0.5914, -0.0803, 0.1314, -0.4235, -0.3797, 0.1168],
[ 0.1795, 0.1795, 0.0449, 0.0449, 0.2308, 0.0898, -0.1282,
0.5579, 0.1731, -0.1795, 0.1603, 0.3142, 0.1090, 0.5835,
-0.1475, 0.0449, 0.1795, -0.0256, 0.8143, -0.2437],
[-0.0066, 0.4804, 0.0066, -0.1184, 0.6843, -0.0197, 0.1448,
0.1842, 0.6383, -0.1908, -0.0066, -0.1053, -0.1316, 0.0461,
-0.0066, -0.2764, 0.3751, 0.3619, 0.5001, -0.1316]],
[[ 0.5110, -0.3555, 0.6443, -0.8221, 0.4888, -0.2074, 0.0444,
0.4888, 0.5999, 0.4370, 0.0000, 0.5036, -0.7628, 0.9332,
-0.6147, 0.7332, 0.3629, 0.9184, 0.7702, -0.8887],
[ 0.8492, -0.3410, -0.3878, 0.1404, -0.3410, 0.3143, -0.1204,
0.5817, 0.4413, 0.5550, 0.6486, -0.1070, 0.6285, -0.4948,
0.2006, 0.1605, 0.0535, -0.4079, 0.3811, 0.4948],
[ 0.6060, 0.7666, -0.8688, -0.6863, -0.5111, -0.0803, -0.6425,
-0.0146, -0.3577, 0.3431, -0.6571, 0.5622, 0.0000, 0.7374,
-0.1314, -0.3650, 0.7520, 0.2336, -0.2847, -0.8250],
[ 0.3014, 0.2950, -0.0898, -0.3142, 0.4040, 0.4681, -0.0705,
-0.2052, 0.8143, -0.1603, 0.3334, -0.6733, 0.0834, 0.0898,
-0.4937, 0.1924, 0.0064, 0.4104, 0.6348, -0.3527],
[-0.6449, 0.5856, -0.0263, -0.0197, 0.8357, -0.5856, 0.0395,
-0.3422, 0.8028, 0.0855, -0.7238, -0.6317, 0.2764, -0.0461,
-0.4211, -0.5988, 0.2632, 0.4014, -0.7501, -0.5659]]],
grad_fn=<CatBackward0>), scale=tensor(0.0069, grad_fn=<DivBackward0>), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)),
IntQuantTensor(value=tensor([[[-0.0066, 0.4804, 0.0066, -0.1184, 0.6843, -0.0197, 0.1448,
0.1842, 0.6383, -0.1908, -0.0066, -0.1053, -0.1316, 0.0461,
-0.0066, -0.2764, 0.3751, 0.3619, 0.5001, -0.1316],
[-0.6449, 0.5856, -0.0263, -0.0197, 0.8357, -0.5856, 0.0395,
-0.3422, 0.8028, 0.0855, -0.7238, -0.6317, 0.2764, -0.0461,
-0.4211, -0.5988, 0.2632, 0.4014, -0.7501, -0.5659]]],
grad_fn=<UnsqueezeBackward0>), scale=tensor(0.0066, grad_fn=<DivBackward0>), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)))
As with other Brevitas layers, QuantRNN can be initialized from a pretrained floating-point torch.nn.RNN. For the purpose of this tutorial, we simulate that with an untrained torch.nn.RNN. As for other quantized layers, setting brevitas.config.IGNORE_MISSING_KEYS might be necessary (depending on which quantizers are set). With the default quantizers, an error on missing activation scale keys would be triggered, so we set it to True:
[13]:
from torch.nn import RNN
from brevitas.nn import QuantRNN
from brevitas import config
config.IGNORE_MISSING_KEYS = True
float_rnn = RNN(input_size=10, hidden_size=20)
quant_rnn = QuantRNN(input_size=10, hidden_size=20)
quant_rnn.load_state_dict(float_rnn.state_dict())
[13]:
<All keys matched successfully>
Similar to other quantized layers, quantization of a certain tensor can be disabled by setting its quantizer to None. Setting all quantizers to None recovers the same behaviour as the floating-point variant:
[14]:
from torch.nn import RNN
from brevitas.nn import QuantRNN
from brevitas import config
ATOL = 1e-6
config.IGNORE_MISSING_KEYS = True
torch.manual_seed(123456)
float_rnn = RNN(input_size=10, hidden_size=20)
quant_rnn = QuantRNN(input_size=10, hidden_size=20, weight_quant=None, io_quant=None, gate_acc_quant=None, bias_quant=None)
# Set both layers to the same state_dict
quant_rnn.load_state_dict(float_rnn.state_dict())
# Generate random input
inp = torch.randn(5, 2, 10)
# Check outputs are the same
assert_with_message(torch.allclose(quant_rnn(inp)[0], float_rnn(inp)[0], atol=ATOL))
# Check hidden states are the same
assert_with_message(torch.allclose(quant_rnn(inp)[1], float_rnn(inp)[1], atol=ATOL))
True
True
As with other quantized layers, we can leverage other prebuilt quantizers too. For example, to perform binary weight quantization:
[15]:
from brevitas.quant.binary import SignedBinaryWeightPerTensorConst
binary_rnn = QuantRNN(input_size=10, hidden_size=20, weight_quant=SignedBinaryWeightPerTensorConst)
binary_rnn(torch.randn(5, 2, 10))
[15]:
(tensor([[[-0.3684, -0.0946, -0.4480, 0.0050, 0.1543, 0.6322, 0.1643,
0.1693, 0.2937, 0.5227, 0.2290, -0.3534, -0.3883, 0.4331,
0.0000, 0.1693, -0.4331, 0.3634, -0.0050, 0.1941],
[-0.2240, -0.0199, -0.3534, 0.0946, 0.3485, 0.3534, 0.1941,
0.1643, 0.1145, 0.4082, 0.2987, -0.0647, -0.0946, 0.1543,
0.1145, -0.0498, 0.0647, 0.1493, 0.0299, -0.1195]],
[[ 0.0776, -0.0776, -0.5670, 0.4178, -0.0239, 0.4476, 0.2029,
-0.0836, 0.3521, 0.7042, 0.6326, 0.4058, -0.4118, -0.0477,
-0.2387, -0.0179, -0.4416, -0.4237, -0.3282, -0.1074],
[-0.2626, 0.3581, 0.2328, -0.2268, -0.2686, -0.3103, 0.4536,
0.3461, 0.3103, 0.3163, 0.3282, -0.3163, -0.7639, 0.0179,
0.0060, 0.0776, -0.5849, -0.5252, 0.1790, 0.2984]],
[[-0.5411, 0.3147, 0.6184, -0.3037, -0.1877, -0.3755, 0.1767,
-0.1767, -0.1491, -0.1049, 0.2871, -0.0552, -0.0883, 0.0331,
0.4749, -0.3147, 0.0331, 0.1767, 0.7013, -0.2264],
[-0.0773, -0.1877, 0.4749, -0.2264, -0.4583, 0.0166, -0.3534,
-0.5743, 0.5411, 0.1160, -0.0442, -0.0442, 0.3037, 0.0166,
-0.1325, -0.1657, -0.0718, 0.1215, 0.6240, 0.3092]],
[[-0.0627, -0.1882, -0.4642, -0.1443, 0.4705, 0.3137, -0.2447,
0.0063, -0.1129, 0.3011, 0.1882, 0.2572, 0.2384, -0.0376,
0.1129, -0.1380, 0.1380, 0.3011, -0.0251, -0.0063],
[-0.6399, 0.5771, 0.2133, 0.2572, 0.7967, 0.1631, -0.2384,
-0.4078, -0.3199, 0.0753, 0.6524, 0.0690, -0.1819, -0.2258,
0.3889, -0.4078, -0.3764, 0.2258, 0.5458, -0.1756]],
[[-0.5704, 0.6139, -0.1209, -0.5173, 0.4447, 0.0048, 0.3481,
-0.5946, -0.5221, 0.1644, -0.2949, -0.1789, -0.1982, 0.2707,
0.2900, -0.5124, -0.4399, -0.0725, 0.4351, 0.6091],
[ 0.0435, 0.2030, -0.4447, -0.2659, 0.1547, 0.0580, 0.4254,
0.5559, 0.1740, 0.4254, 0.4592, 0.2369, -0.4496, -0.3336,
0.3046, 0.1354, -0.3626, -0.2659, -0.2079, -0.4641]]],
grad_fn=<CatBackward0>),
tensor([[[-0.5704, 0.6139, -0.1209, -0.5173, 0.4447, 0.0048, 0.3481,
-0.5946, -0.5221, 0.1644, -0.2949, -0.1789, -0.1982, 0.2707,
0.2900, -0.5124, -0.4399, -0.0725, 0.4351, 0.6091],
[ 0.0435, 0.2030, -0.4447, -0.2659, 0.1547, 0.0580, 0.4254,
0.5559, 0.1740, 0.4254, 0.4592, 0.2369, -0.4496, -0.3336,
0.3046, 0.1354, -0.3626, -0.2659, -0.2079, -0.4641]]],
grad_fn=<UnsqueezeBackward0>))
QuantLSTM#
We now look at QuantLSTM:
[16]:
import inspect
from brevitas.nn import QuantLSTM
from IPython.display import Markdown, display
def pretty_print_source(source):
    display(Markdown('```python\n' + source + '\n```'))
source = inspect.getsource(QuantLSTM.__init__)
pretty_print_source(source)
def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bias: Optional[bool] = True,
        batch_first: bool = False,
        bidirectional: bool = False,
        weight_quant=Int8WeightPerTensorFloat,
        bias_quant=Int32Bias,
        io_quant=Int8ActPerTensorFloat,
        gate_acc_quant=Int8ActPerTensorFloat,
        sigmoid_quant=Uint8ActPerTensorFloat,
        tanh_quant=Int8ActPerTensorFloat,
        cell_state_quant=Int8ActPerTensorFloat,
        coupled_input_forget_gates: bool = False,
        cat_output_cell_states=True,
        shared_input_hidden_weights=False,
        shared_intra_layer_weight_quant=False,
        shared_intra_layer_gate_acc_quant=False,
        shared_cell_state_quant=True,
        return_quant_tensor: bool = False,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        **kwargs):
    super(QuantLSTM, self).__init__(
        layer_impl=_QuantLSTMLayer,
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        bias=bias,
        batch_first=batch_first,
        bidirectional=bidirectional,
        weight_quant=weight_quant,
        bias_quant=bias_quant,
        io_quant=io_quant,
        gate_acc_quant=gate_acc_quant,
        sigmoid_quant=sigmoid_quant,
        tanh_quant=tanh_quant,
        cell_state_quant=cell_state_quant,
        cifg=coupled_input_forget_gates,
        shared_input_hidden_weights=shared_input_hidden_weights,
        shared_intra_layer_weight_quant=shared_intra_layer_weight_quant,
        shared_intra_layer_gate_acc_quant=shared_intra_layer_gate_acc_quant,
        shared_cell_state_quant=shared_cell_state_quant,
        return_quant_tensor=return_quant_tensor,
        dtype=dtype,
        device=device,
        **kwargs)
    if cat_output_cell_states and cell_state_quant is not None and not shared_cell_state_quant:
        raise RuntimeError("Concatenating cell states requires shared cell quantizers.")
    if return_quant_tensor and cell_state_quant is None:
        raise RuntimeError("return_quant_tensor=True requires cell_state_quant != None.")
    self.cat_output_cell_states = cat_output_cell_states
As with QuantRNN, QuantLSTM supports all options of torch.nn.LSTM. Everything said so far about QuantRNN applies to QuantLSTM too, but there are a few more things to be aware of.
QuantLSTM accepts a few more quantizers: sigmoid_quant, tanh_quant and cell_state_quant. As with QuantRNN, setting bidirectional=True and/or num_layers > 1 triggers sharing the instance of certain quantizers, but not others. In particular, io_quant is shared among all layers and directions, as is the case for QuantRNN. cell_state_quant is shared by default, but setting shared_cell_state_quant=False can disable that. However, that requires setting cat_output_cell_states=False, as otherwise we would end up with a concatenation of cell states that have been quantized with different quantizers, which is considered illegal in Brevitas.
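For example, a minimal sketch of enabling per-layer/direction cell state quantizers (the flag combination follows the constraint just described; the layer sizes are arbitrary):

from brevitas.nn import QuantLSTM

# Per-layer/direction cell state quantizers require disabling the
# concatenation of the output cell states
quant_lstm_per_direction_cell = QuantLSTM(
    input_size=10,
    hidden_size=20,
    num_layers=2,
    bidirectional=True,
    shared_cell_state_quant=False,
    cat_output_cell_states=False)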
LSTMs have four gates, each with its own input-hidden and hidden-hidden weights. Brevitas takes in one weight_quant definition, but then four different instances of the weight quantizer are instantiated, and each gate is quantized separately, meaning it can have its own scale and zero-point. To force sharing the same weight quantizer across all gates, QuantLSTM supports setting shared_intra_layer_weight_quant=True. The same reasoning applies to the quantization of the output of each gate, before the activation functions, which is controlled by the gate_acc_quant quantizer. To force the same quantizer instance to be shared, shared_intra_layer_gate_acc_quant=True can be set. Different sigmoid and tanh functions, instead, are always allocated different quantizer instances.
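As a sketch, sharing both the weight quantizer and the gate accumulator quantizer across the four gates of each layer/direction looks like this (layer sizes are arbitrary):

from brevitas.nn import QuantLSTM

# One weight quantizer instance and one gate accumulator quantizer instance
# are shared by the four gates within each layer/direction
quant_lstm_shared_gates = QuantLSTM(
    input_size=10,
    hidden_size=20,
    shared_intra_layer_weight_quant=True,
    shared_intra_layer_gate_acc_quant=True)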
Finally, QuantLSTM also supports coupled input-forget gates (CIFG), where the forget gate is defined as forget_gate = 1 - input_gate, by setting coupled_input_forget_gates=True. This is an optimization to save on some compute and parameters, and it is orthogonal to all other settings, such as shared_input_hidden_weights.
Just-in-time compilation#
Custom recurrent layers can be quite slow at training time, and with quantization added in it only gets worse. To mitigate the issue, both QuantRNN and QuantLSTM support JIT compilation: setting the env variable BREVITAS_JIT=1 triggers end-to-end compilation of the quantized recurrent cell through the PyTorch TorchScript compiler.
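As a minimal sketch, the environment variable is typically set before Brevitas is imported, since it is read at import time; setting it from Python as below is one option, exporting it in the shell works as well:

import os

os.environ['BREVITAS_JIT'] = '1'  # must be set before brevitas is imported

import torch
from brevitas.nn import QuantLSTM  # noqa: E402

quant_lstm_jit = QuantLSTM(input_size=10, hidden_size=20)
# With BREVITAS_JIT=1 the quantized recurrent cell is expected to be compiled end-to-end
_ = quant_lstm_jit(torch.randn(5, 2, 10))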
Calibration#
As of version 0.8 of Brevitas, QuantRNN and QuantLSTM support neither quantized activation calibration through calibration_mode nor bias correction through bias_correction_mode. This will be added in a future version.
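For reference, this is a sketch of how calibration_mode is typically driven for other (non-recurrent) quantized layers, assuming brevitas.graph.calibrate.calibration_mode; per the note above, it does not cover the recurrent layers yet:

import torch
from brevitas.graph.calibrate import calibration_mode
from brevitas.nn import QuantIdentity

# Collect activation statistics for a quantized activation layer
quant_act = QuantIdentity()
calibration_batches = [torch.randn(4, 10) for _ in range(8)]
with calibration_mode(quant_act):
    for batch in calibration_batches:
        quant_act(batch)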
Export#
As of Brevitas 0.8, export of quantized recurrent layers is still a work in progress. As a proof of concept, there is partial support only for export of QuantLSTM to ONNX QCDQ, a way to represent quantization in ONNX with only standard ops (QuantizeLinear->Clip->DequantizeLinear), and to QONNX, a custom set of quantized operators introduced by Brevitas on top of ONNX. Two use cases are supported: (1) only weight_quant is set, supported by both QCDQ and QONNX, and (2) all quantizers are set, supported only by QONNX. In both cases, bidirectional=True and num_layers > 1 are supported. We first define a utility function to visualize the network through netron, which requires pip install netron.
[17]:
import time
from IPython.display import IFrame
def show_netron(model_path, port):
    try:
        import netron
        time.sleep(3.)
        netron.start(model_path, address=("localhost", port), browse=False)
        return IFrame(src=f"http://localhost:{port}/", width="100%", height=400)
    except Exception:
        # Skip the visualization if netron is not installed or the server cannot be started
        pass
QuantLSTM weight-only quantization export#
For use case (1), we leverage export to ONNX QCDQ. Weight quantization is represented with QCDQ nodes, while the standard ONNX LSTM operator is adopted for the recurrent cell. With this approach, we can represent any weight bit width >= 2. Opset 14 is required. For the purpose of this 1-layer, 1-direction example we keep the default weight_quant set, add weight_bit_width=4, and disable the other quantizers:
[18]:
import torch
from brevitas.nn import QuantLSTM
from brevitas.export import export_onnx_qcdq
quant_lstm_weight_only = QuantLSTM(input_size=10, hidden_size=20, weight_bit_width=4, io_quant=None, bias_quant=None, gate_acc_quant=None, sigmoid_quant=None, tanh_quant=None, cell_state_quant=None)
export_path = 'quant_lstm_weight_only_4b.onnx'
exported_model = export_onnx_qcdq(quant_lstm_weight_only, (torch.randn(5, 1, 10)), opset_version=14, export_path=export_path)
[19]:
show_netron(export_path, 8080)
Serving 'quant_lstm_weight_only_4b.onnx' at http://localhost:8080
[19]:
Note that the model can then be accelerated with onnxruntime:
[20]:
import onnxruntime as ort
import numpy as np
sess = ort.InferenceSession(export_path)
input_name = sess.get_inputs()[0].name
np_input = np.random.uniform(size=(5, 1, 10)).astype(np.float32) # (seq_len, batch_size, input_size)
pred_onnx = sess.run(None, {input_name: np_input})
2024-09-12 12:18:52.692518968 [W:onnxruntime:, graph.cc:1283 Graph] Initializer onnx::LSTM_93 appears in graph inputs and will not be treated as constant value/weight. This may prevent some of the graph optimizations, like const folding. Move it out of graph inputs if there is no need to override it, by either re-generating the model with latest exporter/converter or with the tool onnxruntime/tools/python/remove_initializer_from_input.py.
CIFG is also supported, in a way that follows the semantics of onnxruntime:
[21]:
import torch
from brevitas.nn import QuantLSTM
from brevitas.export import export_onnx_qcdq
quant_lstm_weight_only_cifg = QuantLSTM(
input_size=10, hidden_size=20, coupled_input_forget_gates=True, weight_bit_width=4,
io_quant=None, bias_quant=None, gate_acc_quant=None, sigmoid_quant=None, tanh_quant=None, cell_state_quant=None)
export_path = 'quant_lstm_weight_only_cifg_4b.onnx'
exported_model = export_onnx_qcdq(quant_lstm_weight_only_cifg, (torch.randn(5, 1, 10)), opset_version=14, export_path=export_path)
[22]:
show_netron(export_path, 8082)
Serving 'quant_lstm_weight_only_cifg_4b.onnx' at http://localhost:8082
[22]:
As before, we can run it with onnxruntime:
[23]:
import onnxruntime as ort
import numpy as np
sess = ort.InferenceSession(export_path)
input_name = sess.get_inputs()[0].name
np_input = np.random.uniform(size=(5, 1, 10)).astype(np.float32) # (seq_len, batch_size, input_size)
pred_onnx = sess.run(None, {input_name: np_input})
2024-09-12 12:18:53.086326293 [W:onnxruntime:, graph.cc:1283 Graph] Initializer onnx::LSTM_87 appears in graph inputs and will not be treated as constant value/weight. This may prevent some of the graph optimizations, like const folding. Move it out of graph inputs if there is no need to override it, by either re-generating the model with latest exporter/converter or with the tool onnxruntime/tools/python/remove_initializer_from_input.py.
For the 2-layer, 2-direction use case:
[24]:
import torch
from brevitas.nn import QuantLSTM
from brevitas.export import export_onnx_qcdq
quant_lstm_weight_only_bidirectional_2_layers = QuantLSTM(
input_size=10, hidden_size=20, bidirectional=True, num_layers=2, weight_bit_width=4,
io_quant=None, bias_quant=None, gate_acc_quant=None, sigmoid_quant=None, tanh_quant=None, cell_state_quant=None)
export_path = 'quant_lstm_weight_only_bidirectional_2_layers.onnx'
exported_model = export_onnx_qcdq(quant_lstm_weight_only_bidirectional_2_layers, (torch.randn(5, 1, 10)), opset_version=14, export_path=export_path)
/proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/nn/mixin/base.py:55: UserWarning: Keyword arguments are being passed but they not being used.
warn('Keyword arguments are being passed but they not being used.')
[25]:
show_netron(export_path, 8083)
Serving 'quant_lstm_weight_only_bidirectional_2_layers.onnx' at http://localhost:8083
[25]:
Shared input-hidden weights are also supported:
[26]:
import torch
from brevitas.nn import QuantLSTM
from brevitas.export import export_onnx_qcdq
quant_lstm_weight_only_bidirectional_2_layers_shared = QuantLSTM(
input_size=10, hidden_size=20, bidirectional=True, shared_input_hidden_weights=True, weight_bit_width=4,
io_quant=None, bias_quant=None, gate_acc_quant=None, sigmoid_quant=None, tanh_quant=None, cell_state_quant=None)
export_path = 'quant_lstm_weight_only_bidirectional_2_layers_shared_ih.onnx'
exported_model = export_onnx_qcdq(quant_lstm_weight_only_bidirectional_2_layers_shared, (torch.randn(5, 1, 10)), opset_version=14, export_path=export_path)
/proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/nn/mixin/base.py:55: UserWarning: Keyword arguments are being passed but they not being used.
warn('Keyword arguments are being passed but they not being used.')
[27]:
show_netron(export_path, 8085)
Serving 'quant_lstm_weight_only_bidirectional_2_layers_shared_ih.onnx' at http://localhost:8085
[27]:
We can observe how setting shared_intra_layer_weight_quant=True affects the network. Now, for each layer and for each direction within a layer, all weight quantizers share the same scale/zero-point/bit-width:
[28]:
import torch
from brevitas.nn import QuantLSTM
from brevitas.export import export_onnx_qcdq
quant_lstm_weight_only_bidirectional_2_layers = QuantLSTM(
input_size=10, hidden_size=20, bidirectional=True, num_layers=2, weight_bit_width=4, shared_intra_layer_weight_quant=True,
io_quant=None, bias_quant=None, gate_acc_quant=None, sigmoid_quant=None, tanh_quant=None, cell_state_quant=None)
export_path = 'quant_lstm_weight_only_bidirectional_2_layers_shared_q.onnx'
exported_model = export_onnx_qcdq(quant_lstm_weight_only_bidirectional_2_layers, (torch.randn(5, 1, 10)), opset_version=14, export_path=export_path)
/proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/nn/mixin/base.py:55: UserWarning: Keyword arguments are being passed but they not being used.
warn('Keyword arguments are being passed but they not being used.')
[29]:
show_netron(export_path, 8086)
Serving 'quant_lstm_weight_only_bidirectional_2_layers_shared_q.onnx' at http://localhost:8086
[29]:
Alternatively, if we set both shared_input_hidden_weights=True and shared_intra_layer_weight_quant=True, the side effect is that all weight quantizers across both directions of a given layer share the same scale/zero-point/bit-width.
[30]:
import torch
from brevitas.nn import QuantLSTM
from brevitas.export import export_onnx_qcdq
quant_lstm_weight_only_bidirectional_2_layers = QuantLSTM(
input_size=10, hidden_size=20, bidirectional=True, num_layers=2, weight_bit_width=4,
shared_input_hidden_weights=True, shared_intra_layer_weight_quant=True,
io_quant=None, bias_quant=None, gate_acc_quant=None, sigmoid_quant=None, tanh_quant=None, cell_state_quant=None)
export_path = 'quant_lstm_weight_only_bidirectional_2_layers_shared_q_ih.onnx'
exported_model = export_onnx_qcdq(quant_lstm_weight_only_bidirectional_2_layers, (torch.randn(5, 1, 10)), opset_version=14, export_path=export_path)
/proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/nn/mixin/base.py:55: UserWarning: Keyword arguments are being passed but they not being used.
warn('Keyword arguments are being passed but they not being used.')
[31]:
show_netron(export_path, 8087)
Serving 'quant_lstm_weight_only_bidirectional_2_layers_shared_q_ih.onnx' at http://localhost:8087
[31]:
QuantLSTM full quantization export#
For use case (2) we export to QONNX. Weight quantization is represented with Quant nodes, while a custom quantized LSTM operator, QuantLSTMCell, is generated for the recurrent cell. Note that QuantLSTMCell is not yet supported for execution in the qonnx library. In a future version of Brevitas, QuantLSTMCell will instead be lowered to a series of standard ops + Quant nodes. For the purpose of this example, we keep all quantizers at their defaults:
[32]:
import torch
from brevitas.nn import QuantLSTM
from brevitas.export import export_qonnx
quant_lstm = QuantLSTM(input_size=10, hidden_size=20)
export_path = 'quant_lstm.onnx'
exported_model = export_qonnx(quant_lstm, (torch.randn(5, 1, 10)), export_path=export_path)
[33]:
show_netron(export_path, 8088)
Serving 'quant_lstm.onnx' at http://localhost:8088
[33]:
QuantLSTMCell takes the following series of inputs:
quant_input,
quant_hidden_state,
quant_cell_state,
quant_weight_ii,
quant_weight_if,
quant_weight_ic,
quant_weight_io,
quant_weight_hi,
quant_weight_hf,
quant_weight_hc,
quant_weight_ho,
quant_bias_input,
quant_bias_forget,
quant_bias_cell,
quant_bias_output,
output_scale,
output_zero_point,
output_bit_width,
cell_state_scale,
cell_state_zero_point,
cell_state_bit_width,
input_acc_scale,
input_acc_zero_point,
input_acc_bit_width,
forget_acc_scale,
forget_acc_zero_point,
forget_acc_bit_width,
cell_acc_scale,
cell_acc_zero_point,
cell_acc_bit_width,
output_acc_scale,
output_acc_zero_point,
output_acc_bit_width,
input_sigmoid_scale,
input_sigmoid_zero_point,
input_sigmoid_bit_width,
forget_sigmoid_scale,
forget_sigmoid_zero_point,
forget_sigmoid_bit_width,
cell_tanh_scale,
cell_tanh_zero_point,
cell_tanh_bit_width,
output_sigmoid_scale,
output_sigmoid_zero_point,
output_sigmoid_bit_width,
hidden_state_tanh_scale,
hidden_state_tanh_zero_point,
hidden_state_tanh_bit_width
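As a sketch, the exported node and its inputs can be inspected programmatically, assuming the quant_lstm.onnx file produced above and the onnx package being installed:

import onnx

qonnx_model = onnx.load('quant_lstm.onnx')
for node in qonnx_model.graph.node:
    if node.op_type == 'QuantLSTMCell':
        # Print the name of every input wired into the custom cell operator
        for node_input in node.input:
            print(node_input)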
All previous use cases illustrated for the weight-only quantization scenario are also supported.