Quantized RNNs and LSTMs#

With version 0.8, Brevitas introduces support for quantized recurrent layers through QuantRNN and QuantLSTM. As with other Brevitas quantized layers, QuantRNN and QuantLSTM can be used as drop-in replacements for their floating-point variants, but they go further and support some additional structural options not found in upstream PyTorch. Similarly to other quantized layers, both QuantRNN and QuantLSTM can take in different quantizers for the different tensors involved in their computation.

QuantRNN#

We start by looking at QuantRNN:

[1]:
import inspect
from brevitas.nn import QuantRNN
from IPython.display import Markdown, display
import torch
torch.manual_seed(0)

# helpers
def assert_with_message(condition):
    assert condition
    print(condition)

def pretty_print_source(source):
    display(Markdown('```python\n' + source + '\n```'))

source = inspect.getsource(QuantRNN.__init__)
pretty_print_source(source)
def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        nonlinearity: str = 'tanh',
        bias: Optional[bool] = True,
        batch_first: bool = False,
        bidirectional: bool = False,
        weight_quant=Int8WeightPerTensorFloat,
        bias_quant=Int32Bias,
        io_quant=Int8ActPerTensorFloat,
        gate_acc_quant=Int8ActPerTensorFloat,
        shared_input_hidden_weights=False,
        return_quant_tensor: bool = False,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        **kwargs):
    super(QuantRNN, self).__init__(
        layer_impl=_QuantRNNLayer,
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        nonlinearity=nonlinearity,
        bias=bias,
        batch_first=batch_first,
        bidirectional=bidirectional,
        weight_quant=weight_quant,
        bias_quant=bias_quant,
        io_quant=io_quant,
        gate_acc_quant=gate_acc_quant,
        shared_input_hidden_weights=shared_input_hidden_weights,
        return_quant_tensor=return_quant_tensor,
        dtype=dtype,
        device=device,
        **kwargs)

QuantRNN supports all arguments of torch.nn.RNN, plus it exposes four different quantizers: weight_quant controls quantization of the weight tensor, bias_quant controls quantization of the bias, io_quant controls quantization of the input/output, and gate_acc_quant controls quantization of the output of the gate, before the nonlinearity is applied.

Compared to other layers like QuantLinear, a couple of things can be observed. First, input and output quantization are fused together into io_quant. This is because of the recurrent structure of RNN layers, where the output is fed back as input. Second, all quantizers are already set by default. This is different from a layer like QuantLinear, where only weight_quant has a default quantizer.
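
For example, a weight-only quantized QuantRNN can be obtained by keeping the default weight_quant and disabling the other quantizers (a minimal sketch):

from brevitas.nn import QuantRNN

# Keep the default weight quantizer, disable input/output, gate accumulator and bias quantization
quant_rnn_weight_only = QuantRNN(
    input_size=10, hidden_size=20, io_quant=None, gate_acc_quant=None, bias_quant=None)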

As with torch.nn.RNN, QuantRNN defines a stack of one or more layers, controlled by num_layers, which can be made bidirectional with bidirectional=True. Internally, QuantRNN is organized into a two-level nesting of ModuleList, one for the layer(s) and one for the direction(s):

[2]:
def rnn_sublayer(module, sublayer_number, right_to_left_direction):
    return module.layers[sublayer_number][1 if right_to_left_direction else 0]

quant_rnn = QuantRNN(input_size=10, hidden_size=20, num_layers=2, bidirectional=True)
quant_rnn_0_left_to_right = rnn_sublayer(quant_rnn, sublayer_number=0, right_to_left_direction=False)
quant_rnn_0_right_to_left = rnn_sublayer(quant_rnn, sublayer_number=0, right_to_left_direction=True)
quant_rnn_1_left_to_right = rnn_sublayer(quant_rnn, sublayer_number=1, right_to_left_direction=False)
quant_rnn_1_right_to_left = rnn_sublayer(quant_rnn, sublayer_number=1, right_to_left_direction=True)
/proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/nn/mixin/base.py:55: UserWarning: Keyword arguments are being passed but they not being used.
  warn('Keyword arguments are being passed but they not being used.')

Setting num_layers > 1 and/or bidirectional=True has different implications for different quantizers. For weight_quant, gate_acc_quant and bias_quant, the same quantizer definition is shared among different layers/directions, but each layer/direction is allocated its own instance of the quantizer.

[3]:
assert_with_message(not quant_rnn_0_left_to_right.gate_params.input_weight.weight_quant is quant_rnn_1_right_to_left.gate_params.input_weight.weight_quant)
True
[4]:
assert_with_message(not quant_rnn_0_left_to_right.cell.gate_acc_quant is quant_rnn_1_right_to_left.cell.gate_acc_quant)
True
[5]:
assert_with_message(not quant_rnn_0_left_to_right.gate_params.bias_quant is quant_rnn_1_right_to_left.gate_params.bias_quant)
True

Conversely, for io_quant the same instance is shared among all layers and directions. This ensures that input/output tensors that are internally concatenated together share the same quantization scale/zero-point/bit-width.

[6]:
assert_with_message(quant_rnn_0_left_to_right.io_quant is quant_rnn_1_right_to_left.io_quant)
True

Finally, QuantRNN supports an additional flag, shared_input_hidden_weights. Whenever bidirectional=True, this allows sharing the input-to-hidden weights between the two directions, an optimization first introduced in DeepSpeech to reduce the number of parameters with minimal impact on the quality of results.

[7]:
from brevitas.nn import QuantRNN

def count_weights(model):
    return sum(p.numel() for n, p in model.named_parameters() if 'weight' in n)

quant_rnn_single_direction = QuantRNN(input_size=10, hidden_size=20, bidirectional=False, shared_input_hidden_weights=False)
quant_rnn_bidirectional = QuantRNN(input_size=10, hidden_size=20, bidirectional=True, shared_input_hidden_weights=False)
quant_rnn_bidirectional_shared_input_hidden = QuantRNN(input_size=10, hidden_size=20, bidirectional=True, shared_input_hidden_weights=True)

print(f"Number of weights for single direction QuantRNN: {count_weights(quant_rnn_single_direction)}")
print(f"Number of weights for bidirectional QuantRNN: {count_weights(quant_rnn_bidirectional)}")
print(f"Number of weights for bidirectional QuantRNN with shared input-hidden weights: {count_weights(quant_rnn_bidirectional_shared_input_hidden)}")


Number of weights for single direction QuantRNN: 600
Number of weights for bidirectional QuantRNN: 1200
Number of weights for bidirectional QuantRNN with shared input-hidden weights: 1000

As with other Brevitas layers, it’s possible to directly modify a quantizer by passing keyword arguments with a matching prefix. For example, to set 4b per-channel weights and 6b io quantization:

[8]:
quant_rnn_4b = QuantRNN(input_size=10, hidden_size=20, weight_bit_width=4, weight_scaling_per_output_channel=True, io_bit_width=6)
quant_rnn_4b_0_left_to_right = rnn_sublayer(quant_rnn_4b, sublayer_number=0, right_to_left_direction=False)

input_hidden_weight = quant_rnn_4b_0_left_to_right.gate_params.input_weight.quant_weight()
hidden_hidden_weight = quant_rnn_4b_0_left_to_right.gate_params.hidden_weight.quant_weight()

print(f"Input-hidden weight bit-width: {input_hidden_weight.bit_width}")
print(f"Hidden-hidden weight bit-width: {hidden_hidden_weight.bit_width}")
print(f"I/O quant bit-width: {quant_rnn_4b_0_left_to_right.io_quant.bit_width()}")
print(f"Input-hidden weight scale: {input_hidden_weight.scale}")
print(f"Hidden-hidden weight scale: {hidden_hidden_weight.scale}")
Input-hidden weight bit-width: 4.0
Hidden-hidden weight bit-width: 4.0
I/O quant bit-width: 6.0
Input-hidden weight scale: tensor([[0.0297],
        [0.0311],
        [0.0298],
        [0.0295],
        [0.0316],
        [0.0311],
        [0.0318],
        [0.0309],
        [0.0317],
        [0.0309],
        [0.0316],
        [0.0319],
        [0.0319],
        [0.0318],
        [0.0315],
        [0.0310],
        [0.0319],
        [0.0319],
        [0.0318],
        [0.0312]], grad_fn=<DivBackward0>)
Hidden-hidden weight scale: tensor([[0.0297],
        [0.0311],
        [0.0298],
        [0.0295],
        [0.0316],
        [0.0311],
        [0.0318],
        [0.0309],
        [0.0317],
        [0.0309],
        [0.0316],
        [0.0319],
        [0.0319],
        [0.0318],
        [0.0315],
        [0.0310],
        [0.0319],
        [0.0319],
        [0.0318],
        [0.0312]], grad_fn=<DivBackward0>)

QuantRNN follows the same forward interface as torch.nn.RNN, with a couple of exceptions: packed variable-length inputs and unbatched inputs are currently not supported. Everything else is the same.

Inputs are expected to have shape (sequence, batch, features) for batch_first=False, or (batch, sequence, features) for batch_first=True. The layer returns a tuple (outputs, hidden_states), where outputs has shape (sequence, batch, hidden_size * num_directions) for batch_first=False, or (batch, sequence, hidden_size * num_directions) for batch_first=True, with num_directions=2 when bidirectional=True, while hidden_states has shape (num_directions * num_layers, batch, hidden_size).

[9]:
import torch
from brevitas.nn import QuantRNN

quant_rnn = QuantRNN(input_size=10, hidden_size=20, batch_first=True)
outputs, hidden_states = quant_rnn(torch.randn(2, 5, 10))
print(f"Output size: {outputs.shape}")
print(f"Hidden states size: {hidden_states.shape}")
Output size: torch.Size([2, 5, 20])
Hidden states size: torch.Size([1, 2, 20])
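
As a quick sketch of how these shapes change with multiple layers and directions (the expected sizes in the comments follow directly from the shape rules above):

import torch
from brevitas.nn import QuantRNN

quant_rnn_2l_bi = QuantRNN(input_size=10, hidden_size=20, num_layers=2, bidirectional=True, batch_first=True)
outputs, hidden_states = quant_rnn_2l_bi(torch.randn(2, 5, 10))
print(f"Output size: {outputs.shape}")               # expected: torch.Size([2, 5, 40]), i.e. hidden_size * num_directions last
print(f"Hidden states size: {hidden_states.shape}")  # expected: torch.Size([4, 2, 20]), i.e. num_directions * num_layers first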

As with other quantized layers, it’s possible to return a QuantTensor with return_quant_tensor=True. As a reminder, a QuantTensor is just a data structure that captures the quantization metadata associated with a quantized tensor:

[10]:
import torch
from brevitas.nn import QuantRNN

quant_rnn = QuantRNN(input_size=10, hidden_size=20, batch_first=True, return_quant_tensor=True)
quant_rnn(torch.randn(2, 5, 10))
/proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/nn/mixin/base.py:216: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1670525541990/work/torch/csrc/utils/python_arg_parser.cpp:350.)
  return torch.cat(outputs, dim=seq_dim)
[10]:
(IntQuantTensor(value=tensor([[[-0.0062, -0.2872,  0.7931,  0.4309,  0.5495, -0.4558,  0.2373,
            0.6807,  0.4621,  0.6120, -0.1124,  0.3872,  0.3060,  0.7681,
           -0.3684,  0.0437, -0.7369, -0.3247,  0.7743,  0.3372],
          [ 0.5450,  0.2962, -0.3969,  0.3555, -0.5628,  0.2429, -0.4976,
            0.1777, -0.1244,  0.0296, -0.2607,  0.0948,  0.5036, -0.3673,
            0.5213, -0.2962,  0.7524,  0.0770, -0.0948, -0.0948],
          [ 0.2691, -0.6624, -0.5434,  0.4968, -0.6624,  0.0983,  0.1345,
            0.1242, -0.0517, -0.3726,  0.3053,  0.1604,  0.3208,  0.0983,
            0.3105,  0.4243,  0.2794,  0.1604,  0.1035, -0.0724],
          [ 0.1284, -0.3337, -0.5263, -0.0449, -0.5263,  0.3081, -0.1733,
            0.5648,  0.4942, -0.1412,  0.1733,  0.3337,  0.6225,  0.3401,
            0.5070, -0.1412,  0.0642, -0.3722,  0.2888,  0.1155],
          [ 0.0579, -0.0058, -0.4054, -0.1564, -0.5560, -0.3301,  0.3533,
            0.0058, -0.1622, -0.3765,  0.1216,  0.0695, -0.4054,  0.0927,
            0.6139, -0.1390,  0.7066,  0.1274,  0.1622, -0.2896]],

         [[ 0.1374,  0.5745,  0.0624, -0.2373,  0.3060,  0.3310, -0.5183,
            0.1186,  0.1124,  0.2997,  0.0375,  0.6369, -0.5308,  0.6307,
           -0.5683,  0.7556,  0.2997, -0.4933,  0.3934, -0.4871],
          [ 0.1066, -0.1244, -0.1718,  0.4266,  0.5569,  0.0178,  0.1185,
           -0.3910,  0.2133,  0.0178, -0.1066, -0.2903,  0.1837, -0.2547,
           -0.2903,  0.0770,  0.3495,  0.2547,  0.2311, -0.6161],
          [-0.0880, -0.1966,  0.3001, -0.0569,  0.4140, -0.1552, -0.1345,
            0.4554,  0.5175,  0.1242, -0.2898,  0.1966, -0.0414,  0.3985,
           -0.1708, -0.0621, -0.1708,  0.0828,  0.2225,  0.0517],
          [ 0.2118,  0.5648, -0.2824, -0.0449,  0.5840,  0.3209, -0.5648,
            0.3530,  0.4043, -0.4942, -0.3786,  0.0257,  0.5327, -0.1990,
           -0.1348, -0.8215,  0.3016,  0.5327,  0.5648, -0.1155],
          [-0.0290, -0.1738,  0.0695,  0.3765,  0.1738,  0.0579, -0.4054,
           -0.2664,  0.4923,  0.2143, -0.4170,  0.4112,  0.5502,  0.7066,
           -0.6024,  0.7356,  0.0348,  0.1043, -0.1911, -0.4518]]],
        grad_fn=<CatBackward0>), scale=tensor(0.0059, grad_fn=<DivBackward0>), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)),
 IntQuantTensor(value=tensor([[[ 0.0579, -0.0058, -0.4054, -0.1564, -0.5560, -0.3301,  0.3533,
            0.0058, -0.1622, -0.3765,  0.1216,  0.0695, -0.4054,  0.0927,
            0.6139, -0.1390,  0.7066,  0.1274,  0.1622, -0.2896],
          [-0.0290, -0.1738,  0.0695,  0.3765,  0.1738,  0.0579, -0.4054,
           -0.2664,  0.4923,  0.2143, -0.4170,  0.4112,  0.5502,  0.7066,
           -0.6024,  0.7356,  0.0348,  0.1043, -0.1911, -0.4518]]],
        grad_fn=<UnsqueezeBackward0>), scale=tensor(0.0058, grad_fn=<DivBackward0>), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)))

Similarly, a QuantTensor can be passed in as input. However, whenever io_quant is set (which it is by default), the input will be re-quantized:

[11]:
from brevitas.nn import QuantIdentity

quant_identity = QuantIdentity(return_quant_tensor=True)
quant_rnn(quant_identity(torch.randn(2, 5, 10)))
/proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/torch/_tensor.py:1255: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /opt/conda/conda-bld/pytorch_1670525541990/work/c10/core/TensorImpl.h:1758.)
  return super(Tensor, self).rename(names)
[11]:
(IntQuantTensor(value=tensor([[[ 0.2111,  0.1267,  0.0060,  0.6153, -0.7721, -0.3740, -0.5188,
            0.6273,  0.4162,  0.2051,  0.2292,  0.7239,  0.6032,  0.2533,
            0.5067,  0.6635,  0.1206, -0.5730,  0.0483,  0.3318],
          [ 0.5742,  0.0194, -0.3807, -0.0710, -0.6000,  0.1807,  0.1355,
            0.4129,  0.3807,  0.3936, -0.0903,  0.1549,  0.1032,  0.0645,
            0.4775, -0.0645,  0.1161, -0.0065,  0.0194, -0.1097],
          [ 0.0453, -0.4533,  0.1036, -0.0194, -0.2979,  0.3432,  0.0777,
            0.6346, -0.0842,  0.3302,  0.4727,  0.4856, -0.4144,  0.7382,
           -0.0453,  0.5439,  0.2266, -0.4792,  0.4403, -0.1036],
          [ 0.3198,  0.2741, -0.6395,  0.0971, -0.6052, -0.5196,  0.1770,
           -0.5025, -0.1256,  0.2056,  0.2684, -0.6395, -0.0285, -0.7309,
            0.7194, -0.7194,  0.1542, -0.3426, -0.6509,  0.0343],
          [ 0.0000, -0.4004,  0.3151, -0.0263, -0.5842, -0.1641, -0.3939,
            0.0263, -0.2429,  0.6499, -0.5186,  0.1247, -0.2101,  0.8337,
           -0.1444,  0.6762, -0.1641, -0.5317, -0.1707, -0.0197]],

         [[ 0.2111, -0.2111, -0.3197, -0.0241, -0.5067, -0.0241, -0.2895,
            0.1749, -0.4283,  0.0000, -0.3680,  0.5308, -0.1267,  0.5248,
            0.1206,  0.2654,  0.6394, -0.1327,  0.2292, -0.3800],
          [ 0.6775, -0.3355, -0.1807,  0.2774, -0.8259, -0.2000, -0.0065,
            0.5678,  0.4000,  0.2258,  0.4387,  0.2710,  0.5355,  0.1290,
            0.6710, -0.0645, -0.2710, -0.3613,  0.6388,  0.5226],
          [-0.0065, -0.0777, -0.6475, -0.1684, -0.3820,  0.3885,  0.0065,
            0.1943, -0.3238, -0.2525, -0.1230, -0.0453, -0.0777,  0.3432,
            0.4921, -0.1101,  0.8224,  0.2396,  0.1554, -0.3885],
          [-0.0514, -0.4111, -0.4625, -0.1713, -0.3369,  0.2512, -0.2969,
           -0.4111, -0.2341,  0.3597, -0.1998,  0.0000,  0.2741,  0.7137,
           -0.1256,  0.1370, -0.0742, -0.5938, -0.5424, -0.4168],
          [ 0.3479,  0.5974, -0.3939,  0.1444, -0.6762,  0.1969, -0.6499,
            0.4136,  0.5383, -0.3085,  0.4070,  0.4070,  0.6630, -0.0263,
            0.2823, -0.1510,  0.1313, -0.5186,  0.4464, -0.0066]]],
        grad_fn=<CatBackward0>), scale=tensor(0.0062, grad_fn=<DivBackward0>), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)),
 IntQuantTensor(value=tensor([[[ 0.0000, -0.4004,  0.3151, -0.0263, -0.5842, -0.1641, -0.3939,
            0.0263, -0.2429,  0.6499, -0.5186,  0.1247, -0.2101,  0.8337,
           -0.1444,  0.6762, -0.1641, -0.5317, -0.1707, -0.0197],
          [ 0.3479,  0.5974, -0.3939,  0.1444, -0.6762,  0.1969, -0.6499,
            0.4136,  0.5383, -0.3085,  0.4070,  0.4070,  0.6630, -0.0263,
            0.2823, -0.1510,  0.1313, -0.5186,  0.4464, -0.0066]]],
        grad_fn=<UnsqueezeBackward0>), scale=tensor(0.0066, grad_fn=<DivBackward0>), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)))

As with torch.nn.RNN, by default the initial hidden state is initialized to 0, but a custom hidden state of shape (num_directions * num_layers, batch, hidden_size) can be passed in:

[12]:
quant_rnn(torch.randn(2, 5, 10), torch.randn(1, 2, 20))
[12]:
(IntQuantTensor(value=tensor([[[-0.3777, -0.2074,  0.7184,  0.9110,  0.0148, -0.1926, -0.7110,
            0.1926, -0.4222, -0.9480,  0.2592,  0.2222, -0.2370, -0.5407,
            0.5851, -0.2370,  0.3555,  0.1703,  0.4444, -0.2222],
          [ 0.4814, -0.7355, -0.1605,  0.3878, -0.5282,  0.2073,  0.0000,
            0.3677,  0.1805, -0.1204, -0.4614,  0.2474,  0.7021,  0.0401,
            0.4346,  0.4480, -0.3143,  0.0401,  0.6887,  0.6753],
          [ 0.5038, -0.3650, -0.6936,  0.0146, -0.9345,  0.0000,  0.1679,
           -0.3066,  0.1825,  0.4089,  0.0949, -0.2555,  0.3870, -0.2482,
            0.5914, -0.0803,  0.1314, -0.4235, -0.3797,  0.1168],
          [ 0.1795,  0.1795,  0.0449,  0.0449,  0.2308,  0.0898, -0.1282,
            0.5579,  0.1731, -0.1795,  0.1603,  0.3142,  0.1090,  0.5835,
           -0.1475,  0.0449,  0.1795, -0.0256,  0.8143, -0.2437],
          [-0.0066,  0.4804,  0.0066, -0.1184,  0.6843, -0.0197,  0.1448,
            0.1842,  0.6383, -0.1908, -0.0066, -0.1053, -0.1316,  0.0461,
           -0.0066, -0.2764,  0.3751,  0.3619,  0.5001, -0.1316]],

         [[ 0.5110, -0.3555,  0.6443, -0.8221,  0.4888, -0.2074,  0.0444,
            0.4888,  0.5999,  0.4370,  0.0000,  0.5036, -0.7628,  0.9332,
           -0.6147,  0.7332,  0.3629,  0.9184,  0.7702, -0.8887],
          [ 0.8492, -0.3410, -0.3878,  0.1404, -0.3410,  0.3143, -0.1204,
            0.5817,  0.4413,  0.5550,  0.6486, -0.1070,  0.6285, -0.4948,
            0.2006,  0.1605,  0.0535, -0.4079,  0.3811,  0.4948],
          [ 0.6060,  0.7666, -0.8688, -0.6863, -0.5111, -0.0803, -0.6425,
           -0.0146, -0.3577,  0.3431, -0.6571,  0.5622,  0.0000,  0.7374,
           -0.1314, -0.3650,  0.7520,  0.2336, -0.2847, -0.8250],
          [ 0.3014,  0.2950, -0.0898, -0.3142,  0.4040,  0.4681, -0.0705,
           -0.2052,  0.8143, -0.1603,  0.3334, -0.6733,  0.0834,  0.0898,
           -0.4937,  0.1924,  0.0064,  0.4104,  0.6348, -0.3527],
          [-0.6449,  0.5856, -0.0263, -0.0197,  0.8357, -0.5856,  0.0395,
           -0.3422,  0.8028,  0.0855, -0.7238, -0.6317,  0.2764, -0.0461,
           -0.4211, -0.5988,  0.2632,  0.4014, -0.7501, -0.5659]]],
        grad_fn=<CatBackward0>), scale=tensor(0.0069, grad_fn=<DivBackward0>), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)),
 IntQuantTensor(value=tensor([[[-0.0066,  0.4804,  0.0066, -0.1184,  0.6843, -0.0197,  0.1448,
            0.1842,  0.6383, -0.1908, -0.0066, -0.1053, -0.1316,  0.0461,
           -0.0066, -0.2764,  0.3751,  0.3619,  0.5001, -0.1316],
          [-0.6449,  0.5856, -0.0263, -0.0197,  0.8357, -0.5856,  0.0395,
           -0.3422,  0.8028,  0.0855, -0.7238, -0.6317,  0.2764, -0.0461,
           -0.4211, -0.5988,  0.2632,  0.4014, -0.7501, -0.5659]]],
        grad_fn=<UnsqueezeBackward0>), scale=tensor(0.0066, grad_fn=<DivBackward0>), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)))

As with other Brevitas layers, QuantRNN can be initialized from a pretrained floating-point torch.nn.RNN. For the purpose of this tutorial, we simulate that with an untrained torch.nn.RNN. As with other quantized layers, setting brevitas.config.IGNORE_MISSING_KEYS might be necessary, depending on which quantizers are set. With the default quantizers, an error on missing activation scale keys would be triggered, so we set it to True:

[13]:
from torch.nn import RNN
from brevitas.nn import QuantRNN
from brevitas import config

config.IGNORE_MISSING_KEYS = True

float_rnn = RNN(input_size=10, hidden_size=20)
quant_rnn = QuantRNN(input_size=10, hidden_size=20)
quant_rnn.load_state_dict(float_rnn.state_dict())
[13]:
<All keys matched successfully>

Similar to other quantized layers, quantization on a certain tensor can be disabled by setting a quantizer to None. Setting all quantizers to None recovers the same behaviour as the floating-point variant:

[14]:
from torch.nn import RNN
from brevitas.nn import QuantRNN
from brevitas import config
ATOL = 1e-6

config.IGNORE_MISSING_KEYS = True
torch.manual_seed(123456)

float_rnn = RNN(input_size=10, hidden_size=20)
quant_rnn = QuantRNN(input_size=10, hidden_size=20, weight_quant=None, io_quant=None, gate_acc_quant=None, bias_quant=None)

# Set both layers to the same state_dict
quant_rnn.load_state_dict(float_rnn.state_dict())

# Generate random input
inp = torch.randn(5, 2, 10)
# Check outputs are the same
assert_with_message(torch.allclose(quant_rnn(inp)[0], float_rnn(inp)[0], atol=ATOL))
# Check hidden states are the same
assert_with_message(torch.allclose(quant_rnn(inp)[1], float_rnn(inp)[1], atol=ATOL))
True
True

As with other quantized layers, we can leverage other prebuilt quantizers too. For example, to perform binary weight quantization:

[15]:
from brevitas.quant.binary import SignedBinaryWeightPerTensorConst

binary_rnn = QuantRNN(input_size=10, hidden_size=20, weight_quant=SignedBinaryWeightPerTensorConst)
binary_rnn(torch.randn(5, 2, 10))
[15]:
(tensor([[[-0.3684, -0.0946, -0.4480,  0.0050,  0.1543,  0.6322,  0.1643,
            0.1693,  0.2937,  0.5227,  0.2290, -0.3534, -0.3883,  0.4331,
            0.0000,  0.1693, -0.4331,  0.3634, -0.0050,  0.1941],
          [-0.2240, -0.0199, -0.3534,  0.0946,  0.3485,  0.3534,  0.1941,
            0.1643,  0.1145,  0.4082,  0.2987, -0.0647, -0.0946,  0.1543,
            0.1145, -0.0498,  0.0647,  0.1493,  0.0299, -0.1195]],

         [[ 0.0776, -0.0776, -0.5670,  0.4178, -0.0239,  0.4476,  0.2029,
           -0.0836,  0.3521,  0.7042,  0.6326,  0.4058, -0.4118, -0.0477,
           -0.2387, -0.0179, -0.4416, -0.4237, -0.3282, -0.1074],
          [-0.2626,  0.3581,  0.2328, -0.2268, -0.2686, -0.3103,  0.4536,
            0.3461,  0.3103,  0.3163,  0.3282, -0.3163, -0.7639,  0.0179,
            0.0060,  0.0776, -0.5849, -0.5252,  0.1790,  0.2984]],

         [[-0.5411,  0.3147,  0.6184, -0.3037, -0.1877, -0.3755,  0.1767,
           -0.1767, -0.1491, -0.1049,  0.2871, -0.0552, -0.0883,  0.0331,
            0.4749, -0.3147,  0.0331,  0.1767,  0.7013, -0.2264],
          [-0.0773, -0.1877,  0.4749, -0.2264, -0.4583,  0.0166, -0.3534,
           -0.5743,  0.5411,  0.1160, -0.0442, -0.0442,  0.3037,  0.0166,
           -0.1325, -0.1657, -0.0718,  0.1215,  0.6240,  0.3092]],

         [[-0.0627, -0.1882, -0.4642, -0.1443,  0.4705,  0.3137, -0.2447,
            0.0063, -0.1129,  0.3011,  0.1882,  0.2572,  0.2384, -0.0376,
            0.1129, -0.1380,  0.1380,  0.3011, -0.0251, -0.0063],
          [-0.6399,  0.5771,  0.2133,  0.2572,  0.7967,  0.1631, -0.2384,
           -0.4078, -0.3199,  0.0753,  0.6524,  0.0690, -0.1819, -0.2258,
            0.3889, -0.4078, -0.3764,  0.2258,  0.5458, -0.1756]],

         [[-0.5704,  0.6139, -0.1209, -0.5173,  0.4447,  0.0048,  0.3481,
           -0.5946, -0.5221,  0.1644, -0.2949, -0.1789, -0.1982,  0.2707,
            0.2900, -0.5124, -0.4399, -0.0725,  0.4351,  0.6091],
          [ 0.0435,  0.2030, -0.4447, -0.2659,  0.1547,  0.0580,  0.4254,
            0.5559,  0.1740,  0.4254,  0.4592,  0.2369, -0.4496, -0.3336,
            0.3046,  0.1354, -0.3626, -0.2659, -0.2079, -0.4641]]],
        grad_fn=<CatBackward0>),
 tensor([[[-0.5704,  0.6139, -0.1209, -0.5173,  0.4447,  0.0048,  0.3481,
           -0.5946, -0.5221,  0.1644, -0.2949, -0.1789, -0.1982,  0.2707,
            0.2900, -0.5124, -0.4399, -0.0725,  0.4351,  0.6091],
          [ 0.0435,  0.2030, -0.4447, -0.2659,  0.1547,  0.0580,  0.4254,
            0.5559,  0.1740,  0.4254,  0.4592,  0.2369, -0.4496, -0.3336,
            0.3046,  0.1354, -0.3626, -0.2659, -0.2079, -0.4641]]],
        grad_fn=<UnsqueezeBackward0>))

QuantLSTM#

We now look at QuantLSTM:

[16]:
import inspect
from brevitas.nn import QuantLSTM
from IPython.display import Markdown, display

def pretty_print_source(source):
    display(Markdown('```python\n' + source + '\n```'))

source = inspect.getsource(QuantLSTM.__init__)
pretty_print_source(source)
def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bias: Optional[bool] = True,
        batch_first: bool = False,
        bidirectional: bool = False,
        weight_quant=Int8WeightPerTensorFloat,
        bias_quant=Int32Bias,
        io_quant=Int8ActPerTensorFloat,
        gate_acc_quant=Int8ActPerTensorFloat,
        sigmoid_quant=Uint8ActPerTensorFloat,
        tanh_quant=Int8ActPerTensorFloat,
        cell_state_quant=Int8ActPerTensorFloat,
        coupled_input_forget_gates: bool = False,
        cat_output_cell_states=True,
        shared_input_hidden_weights=False,
        shared_intra_layer_weight_quant=False,
        shared_intra_layer_gate_acc_quant=False,
        shared_cell_state_quant=True,
        return_quant_tensor: bool = False,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        **kwargs):
    super(QuantLSTM, self).__init__(
        layer_impl=_QuantLSTMLayer,
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        bias=bias,
        batch_first=batch_first,
        bidirectional=bidirectional,
        weight_quant=weight_quant,
        bias_quant=bias_quant,
        io_quant=io_quant,
        gate_acc_quant=gate_acc_quant,
        sigmoid_quant=sigmoid_quant,
        tanh_quant=tanh_quant,
        cell_state_quant=cell_state_quant,
        cifg=coupled_input_forget_gates,
        shared_input_hidden_weights=shared_input_hidden_weights,
        shared_intra_layer_weight_quant=shared_intra_layer_weight_quant,
        shared_intra_layer_gate_acc_quant=shared_intra_layer_gate_acc_quant,
        shared_cell_state_quant=shared_cell_state_quant,
        return_quant_tensor=return_quant_tensor,
        dtype=dtype,
        device=device,
        **kwargs)
    if cat_output_cell_states and cell_state_quant is not None and not shared_cell_state_quant:
        raise RuntimeError("Concatenating cell states requires shared cell quantizers.")
    if return_quant_tensor and cell_state_quant is None:
        raise RuntimeError("return_quant_tensor=True requires cell_state_quant != None.")
    self.cat_output_cell_states = cat_output_cell_states

As with QuantRNN, QuantLSTM supports all options of torch.nn.LSTM. Everything said so far on QuantRNN applies to QuantLSTM too, but there are a few more things to be aware of.

QuantLSTM accepts a few more quantizers: sigmoid_quant, tanh_quant and cell_state_quant. As with QuantRNN, setting bidirectional=True and/or num_layers > 1 triggers sharing the instance of certain quantizers, but not of others. In particular, io_quant is shared among all layers and directions, as was the case for QuantRNN. cell_state_quant is shared by default, but this can be disabled by setting shared_cell_state_quant=False. Doing so, however, requires setting cat_output_cell_states=False, as otherwise the output would be a concatenation of cell states quantized with different quantizers, which is considered illegal in Brevitas.
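
For instance, per-layer and per-direction cell state quantizers can be requested as follows, remembering to also disable the concatenation of the output cell states (a minimal sketch based on the constructor arguments above):

from brevitas.nn import QuantLSTM

quant_lstm_separate_cell_quant = QuantLSTM(
    input_size=10, hidden_size=20, num_layers=2, bidirectional=True,
    shared_cell_state_quant=False, cat_output_cell_states=False)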

LSTMs have four gates, each with its own input-hidden and hidden-hidden weights. Brevitas takes a single weight_quant definition, but four separate instances of the weight quantizer are created, so each gate is quantized independently, meaning it can have its own scale and zero-point. To force sharing the same weight quantizer instance across all gates, QuantLSTM supports setting shared_intra_layer_weight_quant=True. The same reasoning applies to the quantization of the output of each gate before the activation functions, which is controlled by the gate_acc_quant quantizer: to force sharing the same quantizer instance, set shared_intra_layer_gate_acc_quant=True. The different sigmoid and tanh activations, instead, are always allocated separate quantizer instances.

Finally, QuantLSTM also supports coupled input-forget gates (CIFG), where the forget gate is defined as forget_gate = 1 - input_gate, by setting coupled_input_forget_gates=True. This is an optimization that saves compute and parameters, and it is orthogonal to all other settings, such as shared_input_hidden_weights.
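
To make these flags concrete, here is a minimal sketch combining CIFG with the intra-layer sharing options (construction and a forward pass only; the internal quantizer instances are not inspected here):

import torch
from brevitas.nn import QuantLSTM

quant_lstm_cifg_shared = QuantLSTM(
    input_size=10, hidden_size=20, bidirectional=True,
    coupled_input_forget_gates=True,
    shared_input_hidden_weights=True,
    shared_intra_layer_weight_quant=True,
    shared_intra_layer_gate_acc_quant=True)
out = quant_lstm_cifg_shared(torch.randn(5, 2, 10))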

Just-in-time compilation#

Custom recurrent layers can be quite slow at training time, and adding quantization only makes it worse. To mitigate the issue, both QuantRNN and QuantLSTM support JIT compilation: setting the environment variable BREVITAS_JIT=1 triggers end-to-end compilation of the quantized recurrent cell through the PyTorch TorchScript compiler.
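
As a sketch, assuming a fresh Python process (the flag is read when brevitas is first imported):

import os
os.environ['BREVITAS_JIT'] = '1'  # must be set before brevitas is imported

from brevitas.nn import QuantLSTM  # the quantized recurrent cell is now TorchScript-compiled

quant_lstm_jit = QuantLSTM(input_size=10, hidden_size=20)

Equivalently, the variable can be set in the shell when launching a training script, e.g. BREVITAS_JIT=1 python my_training_script.py (the script name here is just illustrative).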

Calibration#

As of version 0.8 of Brevitas, QuantRNN and QuantLSTM support neither quantized activation calibration through calibration_mode nor bias correction through bias_correction_mode. This will be added in a future version.

Export#

As of Brevitas 0.8, export of quantized recurrent layers is still a work in progress. As a proof of concept, there is partial support only for exporting QuantLSTM to ONNX QCDQ, a way to represent quantization in ONNX using only standard ops (QuantizeLinear->Clip->DequantizeLinear), and to QONNX, a custom set of quantized operators introduced by Brevitas on top of ONNX. Two use cases are supported: (1) only weight_quant is set, supported by both QCDQ and QONNX, and (2) all quantizers are set, supported only by QONNX. In both cases, bidirectional=True and num_layers > 1 are supported. We first define a utility function to visualize the network through Netron, which requires pip install netron.

[17]:
import time
from IPython.display import IFrame

def show_netron(model_path, port):
    try:
        import netron
        time.sleep(3.)
        netron.start(model_path, address=("localhost", port), browse=False)
        return IFrame(src=f"http://localhost:{port}/", width="100%", height=400)
    except:
        pass

QuantLSTM weight-only quantization export#

For use case (1), we leverage export to ONNX QCDQ. Weight quantization is represented with QCDQ nodes, while the standard ONNX LSTM operator is adopted for the recurrent cell. With this approach, we can represent any weight bit width >= 2. Opset 14 is required. For this 1-layer, 1-direction example we keep the default weight_quant, set weight_bit_width=4, and disable all other quantizers:

[18]:
import torch
from brevitas.nn import QuantLSTM
from brevitas.export import export_onnx_qcdq

quant_lstm_weight_only = QuantLSTM(input_size=10, hidden_size=20, weight_bit_width=4, io_quant=None, bias_quant=None, gate_acc_quant=None, sigmoid_quant=None, tanh_quant=None, cell_state_quant=None)
export_path = 'quant_lstm_weight_only_4b.onnx'
exported_model = export_onnx_qcdq(quant_lstm_weight_only, (torch.randn(5, 1, 10)), opset_version=14, export_path=export_path)

[19]:
show_netron(export_path, 8080)
Serving 'quant_lstm_weight_only_4b.onnx' at http://localhost:8080
[19]:

Note that the exported model can then be executed with onnxruntime:

[20]:
import onnxruntime as ort
import numpy as np

sess = ort.InferenceSession(export_path)
input_name = sess.get_inputs()[0].name
np_input = np.random.uniform(size=(5, 1, 10)).astype(np.float32)  # (seq_len, batch_size, input_size)
pred_onnx = sess.run(None, {input_name: np_input})
2024-09-12 12:18:52.692518968 [W:onnxruntime:, graph.cc:1283 Graph] Initializer onnx::LSTM_93 appears in graph inputs and will not be treated as constant value/weight. This may prevent some of the graph optimizations, like const folding. Move it out of graph inputs if there is no need to override it, by either re-generating the model with latest exporter/converter or with the tool onnxruntime/tools/python/remove_initializer_from_input.py.

CIFG is also supported in a way that follows the semantics of onnxruntime:

[21]:
import torch
from brevitas.nn import QuantLSTM
from brevitas.export import export_onnx_qcdq

quant_lstm_weight_only_cifg = QuantLSTM(
    input_size=10, hidden_size=20, coupled_input_forget_gates=True, weight_bit_width=4,
    io_quant=None, bias_quant=None, gate_acc_quant=None, sigmoid_quant=None, tanh_quant=None, cell_state_quant=None)
export_path = 'quant_lstm_weight_only_cifg_4b.onnx'
exported_model = export_onnx_qcdq(quant_lstm_weight_only_cifg, (torch.randn(5, 1, 10)), opset_version=14, export_path=export_path)
[22]:
show_netron(export_path, 8082)
Serving 'quant_lstm_weight_only_cifg_4b.onnx' at http://localhost:8082
[22]:

As before we can run it with onnxruntime:

[23]:
import onnxruntime as ort
import numpy as np

sess = ort.InferenceSession(export_path)
input_name = sess.get_inputs()[0].name
np_input = np.random.uniform(size=(5, 1, 10)).astype(np.float32)  # (seq_len, batch_size, input_size)
pred_onnx = sess.run(None, {input_name: np_input})
2024-09-12 12:18:53.086326293 [W:onnxruntime:, graph.cc:1283 Graph] Initializer onnx::LSTM_87 appears in graph inputs and will not be treated as constant value/weight. This may prevent some of the graph optimizations, like const folding. Move it out of graph inputs if there is no need to override it, by either re-generating the model with latest exporter/converter or with the tool onnxruntime/tools/python/remove_initializer_from_input.py.

For the 2-layer, 2-direction use case:

[24]:
import torch
from brevitas.nn import QuantLSTM
from brevitas.export import export_onnx_qcdq

quant_lstm_weight_only_bidirectional_2_layers = QuantLSTM(
    input_size=10, hidden_size=20, bidirectional=True, num_layers=2, weight_bit_width=4,
    io_quant=None, bias_quant=None, gate_acc_quant=None, sigmoid_quant=None, tanh_quant=None, cell_state_quant=None)
export_path = 'quant_lstm_weight_only_bidirectional_2_layers.onnx'
exported_model = export_onnx_qcdq(quant_lstm_weight_only_bidirectional_2_layers, (torch.randn(5, 1, 10)), opset_version=14, export_path=export_path)
/proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/nn/mixin/base.py:55: UserWarning: Keyword arguments are being passed but they not being used.
  warn('Keyword arguments are being passed but they not being used.')
[25]:
show_netron(export_path, 8083)
Serving 'quant_lstm_weight_only_bidirectional_2_layers.onnx' at http://localhost:8083
[25]:

Shared input-hidden weights are also supported:

[26]:
import torch
from brevitas.nn import QuantLSTM
from brevitas.export import export_onnx_qcdq

quant_lstm_weight_only_bidirectional_2_layers_shared = QuantLSTM(
    input_size=10, hidden_size=20, bidirectional=True, shared_input_hidden_weights=True, weight_bit_width=4,
    io_quant=None, bias_quant=None, gate_acc_quant=None, sigmoid_quant=None, tanh_quant=None, cell_state_quant=None)
export_path = 'quant_lstm_weight_only_bidirectional_2_layers_shared_ih.onnx'
exported_model = export_onnx_qcdq(quant_lstm_weight_only_bidirectional_2_layers_shared, (torch.randn(5, 1, 10)), opset_version=14, export_path=export_path)
/proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/nn/mixin/base.py:55: UserWarning: Keyword arguments are being passed but they not being used.
  warn('Keyword arguments are being passed but they not being used.')
[27]:
show_netron(export_path, 8085)
Serving 'quant_lstm_weight_only_bidirectional_2_layers_shared_ih.onnx' at http://localhost:8085
[27]:

We can observe how setting shared_intra_layer_weight_quant=True affects the network. Now, for each layer and for each direction within a layer, all weight quantizers share the same scale/zp/bit-width:

[28]:
import torch
from brevitas.nn import QuantLSTM
from brevitas.export import export_onnx_qcdq

quant_lstm_weight_only_bidirectional_2_layers = QuantLSTM(
    input_size=10, hidden_size=20, bidirectional=True, num_layers=2, weight_bit_width=4, shared_intra_layer_weight_quant=True,
    io_quant=None, bias_quant=None, gate_acc_quant=None, sigmoid_quant=None, tanh_quant=None, cell_state_quant=None)
export_path = 'quant_lstm_weight_only_bidirectional_2_layers_shared_q.onnx'
exported_model = export_onnx_qcdq(quant_lstm_weight_only_bidirectional_2_layers, (torch.randn(5, 1, 10)), opset_version=14, export_path=export_path)
/proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/nn/mixin/base.py:55: UserWarning: Keyword arguments are being passed but they not being used.
  warn('Keyword arguments are being passed but they not being used.')
[29]:
show_netron(export_path, 8086)
Serving 'quant_lstm_weight_only_bidirectional_2_layers_shared_q.onnx' at http://localhost:8086
[29]:

Alternatively, if we set both shared_input_hidden_weights=True and shared_intra_layer_weight_quant=True, the side effect is that all weight quantizers across both directions of a given layer will have the same scale/zero-point/bit-width.

[30]:
import torch
from brevitas.nn import QuantLSTM
from brevitas.export import export_onnx_qcdq

quant_lstm_weight_only_bidirectional_2_layers = QuantLSTM(
    input_size=10, hidden_size=20, bidirectional=True, num_layers=2, weight_bit_width=4,
    shared_input_hidden_weights=True, shared_intra_layer_weight_quant=True,
    io_quant=None, bias_quant=None, gate_acc_quant=None, sigmoid_quant=None, tanh_quant=None, cell_state_quant=None)
export_path = 'quant_lstm_weight_only_bidirectional_2_layers_shared_q_ih.onnx'
exported_model = export_onnx_qcdq(quant_lstm_weight_only_bidirectional_2_layers, (torch.randn(5, 1, 10)), opset_version=14, export_path=export_path)
/proj/xlabs/users/nfraser/opt/miniforge3/envs/20231115_brv_pt1.13.1/lib/python3.10/site-packages/brevitas/nn/mixin/base.py:55: UserWarning: Keyword arguments are being passed but they not being used.
  warn('Keyword arguments are being passed but they not being used.')
[31]:
show_netron(export_path, 8087)
Serving 'quant_lstm_weight_only_bidirectional_2_layers_shared_q_ih.onnx' at http://localhost:8087
[31]:

QuantLSTM full quantization export#

For use case (2), we export to QONNX. Weight quantization is represented with Quant nodes, while a custom quantized LSTM operator, QuantLSTMCell, is generated for the recurrent cell. Note that QuantLSTMCell is currently not supported for execution in the qonnx library. In a future version of Brevitas, QuantLSTMCell will instead be lowered to a series of standard ops + Quant nodes. For the purpose of this example, we keep all quantizers at their defaults:

[32]:
import torch
from brevitas.nn import QuantLSTM
from brevitas.export import export_qonnx

quant_lstm = QuantLSTM(input_size=10, hidden_size=20)
export_path = 'quant_lstm.onnx'
exported_model = export_qonnx(quant_lstm, (torch.randn(5, 1, 10)), export_path=export_path)
[33]:
show_netron(export_path, 8088)
Serving 'quant_lstm.onnx' at http://localhost:8088
[33]:

QuantLSTMCell takes the following series of inputs:

  • quant_input
  • quant_hidden_state
  • quant_cell_state
  • quant_weight_ii
  • quant_weight_if
  • quant_weight_ic
  • quant_weight_io
  • quant_weight_hi
  • quant_weight_hf
  • quant_weight_hc
  • quant_weight_ho
  • quant_bias_input
  • quant_bias_forget
  • quant_bias_cell
  • quant_bias_output
  • output_scale
  • output_zero_point
  • output_bit_width
  • cell_state_scale
  • cell_state_zero_point
  • cell_state_bit_width
  • input_acc_scale
  • input_acc_zero_point
  • input_acc_bit_width
  • forget_acc_scale
  • forget_acc_zero_point
  • forget_acc_bit_width
  • cell_acc_scale
  • cell_acc_zero_point
  • cell_acc_bit_width
  • output_acc_scale
  • output_acc_zero_point
  • output_acc_bit_width
  • input_sigmoid_scale
  • input_sigmoid_zero_point
  • input_sigmoid_bit_width
  • forget_sigmoid_scale
  • forget_sigmoid_zero_point
  • forget_sigmoid_bit_width
  • cell_tanh_scale
  • cell_tanh_zero_point
  • cell_tanh_bit_width
  • output_sigmoid_scale
  • output_sigmoid_zero_point
  • output_sigmoid_bit_width
  • hidden_state_tanh_scale
  • hidden_state_tanh_zero_point
  • hidden_state_tanh_bit_width

All the use cases illustrated above for the weight-only quantization scenario are also supported.
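
For example, a bidirectional, 2-layer QuantLSTM with all quantizers left at their defaults can be exported to QONNX in the same way (a sketch mirroring the weight-only cells above):

import torch
from brevitas.nn import QuantLSTM
from brevitas.export import export_qonnx

quant_lstm_bidirectional_2_layers = QuantLSTM(
    input_size=10, hidden_size=20, bidirectional=True, num_layers=2)
export_path = 'quant_lstm_bidirectional_2_layers.onnx'
exported_model = export_qonnx(
    quant_lstm_bidirectional_2_layers, (torch.randn(5, 1, 10)), export_path=export_path)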