Rotations in Brevitas#
Why are rotations important?#
Large Language Models exhibit computational invariance [1] meaning that applying an invertible linear operation to the output of certain modules (sources), and its inverse to the input of others (sinks), leaves the model’s output unchanged (assuming sufficient precision). This property allows for the selective application of random orthogonal transformations, which effectively mitigate weight and activation outliers, enhancing their quantization amenability [2]. Moreover, some rotations can be fused into the module weights, thus preserving floating-point inference performance.
Although random orthogonal rotations generally improve quantization amenability in low-bit regimes, the quantized network performance exhibits a large variance under different random rotations, as observed in [3]. Consequently, these authors propose to further optimize the rotations to improve quantized performance. In order to do so, they leverage the Cailey-SGD optimizer to ensure that the optimized rotations stay within the Stiefel manifold during optimization [4].
Rotations in Brevitas#
In Brevitas, we support different flavours of rotations:
Layerwise Rotations
Fused Rotations
Fused Optimized Rotations
Layerwise Rotations#
In order to merge rotations into subsequent and preceding layers, it is often required to apply certain transformations to the network, such as converting LayerNorm to RMSNorm, as well as identifying the suitable layers (or regions) in the network where such operation is doable.
This comes with some drawbacks, such as the impact on quantization of merging RMSNorm affine parameters, and with some limitations. In order to detect the regions of the model where it is possible to fuse rotations, we rely on FX graph, in particular the one that can be extracted with torch.dynamo._export.
This API allows to get a modular FX graph, which means that calls to modules (i.e., Linear module) are not flattened into their respective functional calls.
The limitation of this API resides in the fact that graph breaks are not allowed, which means that a the entire network must contain only operators that can be traced through torch.dynamo. Although compatibility has greatly improved in the past months, custom networks and operators might still not be supported.
For all these reasons, we offer the possibility of applying rotations without having to worry about merging the rotations in the preceding layers. From our experiments, this provides generally better accuracy compared to the case where rotations are fused, although the computational price to pay is generally higher.
In its most simple form, this can be done in the following way:
from brevitas.graph.equalize import LayerwiseActivationRotation
eq = LayerwiseActivationRotation()
model = eq.apply(model)
Fused Rotations#
Following more closely the approach used in [1] and [2], it is possible to fuse rotations into preceding layers if the network structure allows it.
As mentioned, there are few assumptions in this case:
The class accepts a FX-traced module (e.g., through torch.dynamo._export)
LayerNorm have been transformed to RMSNorm if necessary
RMSNorm affine parameters are merged into adjecent layers
Brevitas offers utilities to facilitate these transformations, as well as a standalone class for FX-based rotation, called GraphRotationEqualization.
The most important options for the class are:
orphan_sink: Whether to add standalone, unmerged rotations for the regions that do not support fusion
full_rotation_method: Allows to select between Hadamard (‘had’) or orthogonal (‘ort’) rotations
sdpa_regions: detect where torch.nn.functional.scaled_dot_product is placed in the network and appropriately applies rotations in the layers around it.
use_parametrized_rotations: Register rotations are parametrization, allowing for a subsequent optimization, for example through Caley SGD.
It is important to note that sdpa_regions makes some assumptions about the structure of the attention block, with the presence of QKV linear layers before, and an Output linea layer afterwards.
from brevitas.graph.equalize import LayerNormToRMS
from brevitas.graph.equalize import MergeLnAffine
from brevitas.graph.equalize import GraphRotationEqualization
import torch
with torch.no_grad():
fx_model, guards = torch._dynamo.export(model)(example_input)
# Merge the learned (affine) LayerNorm parameters
eq = MergeLnAffine()
fx_model = eq.apply(fx_model)
# Convert LayerNorm to RMSNorm
eq = LayerNormToRMS()
fx_model = eq.apply(fx_model)
eq = GraphRotationEqualization(
orphan_sink=True
full_rotation_method='had',
sdpa_regions=True,
use_parametrized_rotations=False)
fx_model = eq.apply(fx_model)
If use_parametrized_rotations is False, this resembles closely the optimization structure presented in QuaRot [2].
Optimized Fused Rotations#
If use_parametrized_rotations is True, the rotations are registered as parametrizations ([5]), and as a parameters. After adding the quantization blocks to the network, it is thus necessary to optimize the rotations and then fuse them into the corresponding weights.
Brevitas offers an example of how to accomplish that in our LLM entrypoint (brevitas_ptq_llm).
In particular, there are two ways to apply optimized fused rotations through this entrypoint, FX and FUSED_NO_FX.
Although an FX representation of the model is fundamental for this process, it is possible to discard the FX representation once the correct regions have been identified, and the transformations can be then applied on the non-FX model.
The advantage of this approach is that we do not lose the hierarchical structure of the original model, which is a side effect of torch.dynamo._export. Maintaining this hierarchical structure allows us to optimize the execution of certain PTQ algorithms like GPTQ or Qronos, by optimizing one block at a time and caching intermediate representations.
Moreover, similarly to [3], Brevitas can leverage the Cailey-SGD optimizer to further optimize the rotations.
The rotation training procedure relies on the
HF Trainer class, and, therefore,
can be configured by passing arguments accepted by the dataclass
TrainingArguments.
Moreover, the number of samples used for rotation calibration can be configured through the parameter --nsamples-rot-calibration.
Following, we provide a minimal example configuration for optimizing, in a single GPU, the rotations
of a HuggingfaceTB/SmolLM2-135M model, with its weights quantized to 4 bits:
dataset: wikitext2
eval: true
model: HuggingfaceTB/SmolLM2-135M
rotation: fused_no_fx
optimize_rotations: true
nsamples_rot_calibration: 800
replace_rmsnorm: true
weight_bit_width: 4
dtype: float32
learning_rate: 1.5
weight_decay: 0.0
lr_scheduler_type: cosine
max_steps: 100
per_device_train_batch_size: 2
gradient_accumulation_steps: 4
save_safetensors: false
logging_steps: 10
log_on_each_node: false
Note that the training parameters used in the SpinQuant paper [3] can be found in their repository.
References