brevitas.function package#

Submodules#

brevitas.function.ops module#

Implementation of various core operations often performed as part of quantization. The implemented functions adhere to the restrictions imposed by PyTorch 1.1.0’s TorchScript compiler.

brevitas.function.ops.binary_sign(x)[source]#

Computes the 2-valued sign of an input tensor; zero is mapped to 1.

Parameters:

x (Tensor) – input tensor.

Returns:

the 2-valued sign tensor of the input tensor.

Return type:

Tensor

Examples

>>> binary_sign(torch.tensor([2.1, -0.3, 0.0]))
tensor([ 1., -1.,  1.])
brevitas.function.ops.dpu_round(x)[source]#

Compute DPU rounding.

Parameters:

x (Tensor) – input tensor.

Returns:

rounded input tensor.

Return type:

Tensor

Examples

>>> dpu_round(torch.tensor([-1.5, -0.5, 0.5, 1.5]))
tensor([-1., -0.,  0.,  2.])
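
The rule behind the example above can be sketched as follows; this is an illustrative reconstruction based on the example outputs, not the library's source: values whose fractional part is exactly 0.5 are rounded towards +inf when negative, while everything else follows torch.round (round half to even).

>>> import torch
>>> def dpu_round_sketch(x):
...     # Negative values exactly on a .5 boundary round up (towards +inf);
...     # everything else uses torch.round, i.e. round half to even.
...     neg_half = (x < 0) & (x - torch.floor(x) == 0.5)
...     return torch.where(neg_half, torch.ceil(x), torch.round(x))
>>> dpu_round_sketch(torch.tensor([-1.5, -0.5, 0.5, 1.5]))
tensor([-1., -0.,  0.,  2.])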
brevitas.function.ops.get_upper_bound_on_l1_norm(accumulator_bit_width, input_bit_width, input_is_signed)[source]#

Calculate the upper bound on the l1-norm of the weights needed to guarantee overflow avoidance for a given accumulator bit width and input representation, using the derivations from A2Q: Accumulator-Aware Quantization with Guaranteed Overflow Avoidance by I. Colbert, A. Pappalardo, and J. Petri-Koenig. Note that this assumes integer quantization.

Return type:

Tensor
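
This entry has no example; the sketch below illustrates the idea behind the bound. It is a hedged approximation following the A2Q derivation (a signed P-bit accumulator holds magnitudes up to 2^(P-1) - 1, and an N-bit input is assumed to have maximum magnitude 2^(N-1) when signed and 2^N when unsigned), not a verbatim copy of the implementation.

>>> import torch
>>> def l1_norm_bound_sketch(acc_bit_width, input_bit_width, input_is_signed):
...     # Overflow is avoided whenever ||w||_1 * max|x| stays within the accumulator range.
...     max_acc_mag = 2. ** (acc_bit_width - 1) - 1
...     max_input_mag = 2. ** (input_bit_width - 1 if input_is_signed else input_bit_width)
...     return torch.tensor(max_acc_mag / max_input_mag)
>>> l1_norm_bound_sketch(16, 8, True)
tensor(255.9922)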

brevitas.function.ops.identity(x)[source]#

Identity function.

Parameters:

x (Tensor) – Input tensor.

Returns:

The input tensor x.

Return type:

Tensor

Examples

>>> identity(torch.tensor(1.7))
tensor(1.7)
brevitas.function.ops.max_float(exponent_bit_width, mantissa_bit_width, exponent_bias)[source]#
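
No description is attached to max_float(). As a hedged illustration only, the largest value representable by a minifloat format with the given exponent bit width, mantissa bit width and exponent bias can be sketched as below, assuming no encodings are reserved for inf/NaN; the actual implementation may treat special values differently.

>>> import torch
>>> def max_float_sketch(exponent_bit_width, mantissa_bit_width, exponent_bias):
...     # Largest unbiased exponent, assuming every exponent code encodes a finite value.
...     max_exponent = 2. ** exponent_bit_width - 1. - exponent_bias
...     # Mantissa with all bits set: 1 + 1/2 + ... + 2**-M = 2 - 2**-M.
...     max_mantissa = 2. - 2. ** (-mantissa_bit_width)
...     return torch.tensor(max_mantissa * 2. ** max_exponent)
>>> max_float_sketch(4, 3, 8)  # 4 exponent bits, 3 mantissa bits, bias 8
tensor(240.)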
brevitas.function.ops.max_int(signed, narrow_range, bit_width)[source]#

Compute the maximum integer representable by a given number of bits.

Parameters:
  • signed (bool) – Indicates whether the represented integer is signed or not.

  • narrow_range (bool) – Indicates whether to narrow the range, reducing the maximum unsigned value representable by 1.

  • bit_width (Tensor) – Number of bits available for the representation.

Returns:

Maximum integer that can be represented according to the input arguments.

Return type:

Tensor

Examples

>>> max_int(signed=True, narrow_range=True, bit_width=torch.tensor(8))
tensor(127)
>>> max_int(signed=False, narrow_range=True, bit_width=torch.tensor(8))
tensor(254)
>>> max_int(signed=True, narrow_range=False, bit_width=torch.tensor(8))
tensor(127)
>>> max_int(signed=False, narrow_range=False, bit_width=torch.tensor(8))
tensor(255)
brevitas.function.ops.min_int(signed, narrow_range, bit_width)[source]#

Compute the minimum integer representable by a given number of bits.

Parameters:
  • signed (bool) – Indicates whether the represented integer is signed or not.

  • narrow_range (bool) – Indicates whether to narrow the range, raising the minimum signed value by 1.

  • bit_width (Tensor) – Number of bits available for the representation.

Returns:

Minimum integer that can be represented according to the input arguments.

Return type:

Tensor

Examples

>>> min_int(signed=True, narrow_range=True, bit_width=torch.tensor(8))
tensor(-127)
>>> min_int(signed=False, narrow_range=True, bit_width=torch.tensor(8))
tensor(0)
>>> min_int(signed=True, narrow_range=False, bit_width=torch.tensor(8))
tensor(-128)
>>> min_int(signed=False, narrow_range=False, bit_width=torch.tensor(8))
tensor(0)
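
Taken together, the examples for max_int() and min_int() are consistent with the closed-form sketch below (derived from the documented examples, not necessarily the library's exact implementation):

>>> import torch
>>> def max_int_sketch(signed, narrow_range, bit_width):
...     if signed:
...         return 2 ** (bit_width - 1) - 1      # e.g. 127 for 8 bits
...     elif narrow_range:
...         return 2 ** bit_width - 2            # e.g. 254 for 8 bits
...     else:
...         return 2 ** bit_width - 1            # e.g. 255 for 8 bits
>>> def min_int_sketch(signed, narrow_range, bit_width):
...     if not signed:
...         return torch.zeros_like(bit_width)   # unsigned ranges start at 0
...     elif narrow_range:
...         return -(2 ** (bit_width - 1)) + 1   # e.g. -127 for 8 bits
...     else:
...         return -(2 ** (bit_width - 1))       # e.g. -128 for 8 bits
>>> max_int_sketch(signed=False, narrow_range=True, bit_width=torch.tensor(8))
tensor(254)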
brevitas.function.ops.round_to_zero(x)[source]#

Compute rounding towards zero.

Parameters:

x (Tensor) – input tensor.

Returns:

rounded input tensor.

Return type:

Tensor

Examples

>>> round_to_zero(torch.tensor([-1.5, -0.5, 0.5, 1.5]))
tensor([-1., -0.,  0.,  1.])
brevitas.function.ops.tensor_clamp(x, min_val, max_val)[source]#

Generalized clamp function with support for tensors as clamping values.

Parameters:
  • x (Tensor) – Input on which to apply the clamp operation

  • min_val (Tensor) – Minimum values for the clamp operation.

  • max_val (Tensor) – Maximum values for the clamp operation.

Returns:

Input x clamped between the provided minimum and maximum tensors.

Return type:

Tensor

Notes

x, min_val, max_val need to be broadcastable.

Differentiable w.r.t. x, min_val, max_val.

Examples

>>> tensor_clamp(torch.tensor([1.7, -0.5, 0.1]), torch.tensor(0.0), torch.tensor(1.0))
tensor([1.0000, 0.0000, 0.1000])
brevitas.function.ops.tensor_clamp_(x, min_val, max_val)[source]#

In-place variant of tensor_clamp(). Not differentiable w.r.t. any of the inputs.

Return type:

Tensor

brevitas.function.ops_ste module#

Implementation of various functions with straight-through gradient estimators, dispatched to either a native just-in-time compiled backend (when env BREVITAS_JIT=1) or to an autograd Function implemented in autograd_ste_ops (when env BREVITAS_JIT=0).
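
Since the dispatch depends on the BREVITAS_JIT environment variable, the flag is typically set before brevitas is imported. A minimal sketch, assuming the variable is read at import time:

>>> import os
>>> os.environ["BREVITAS_JIT"] = "1"  # "0" selects the autograd_ste_ops backend instead
>>> from brevitas.function.ops_ste import round_ste  # dispatch happens at import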

The native backend is used when BREVITAS_JIT is enabled to allow for end-to-end compilation of the built-in quantizers, since as of PyTorch 1.8.1 a torch.autograd.Function is not supported by the TorchScript compiler.
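
To make the straight-through estimator pattern concrete, here is a minimal sketch of how such an autograd Function is typically written; it illustrates the idea behind the autograd_ste_ops backend rather than reproducing the library's code:

>>> import torch
>>> class RoundSteSketch(torch.autograd.Function):
...     @staticmethod
...     def forward(ctx, x):
...         # Non-differentiable op in the forward pass.
...         return torch.round(x)
...     @staticmethod
...     def backward(ctx, grad_output):
...         # Identity ("straight-through") gradient w.r.t. the input.
...         return grad_output
>>> x = torch.tensor([1.7, -1.7], requires_grad=True)
>>> y = RoundSteSketch.apply(x)
>>> y.backward(torch.tensor([0.1, -0.1]))
>>> x.grad
tensor([ 0.1000, -0.1000])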

brevitas.function.ops_ste.abs_binary_sign_grad(x)[source]#

Function that implements torch.abs() with a binary-sign backward, so that the subgradient at 0 is 1. Compare with torch.abs(), whose subgradient at 0 is 0.

Return type:

Tensor

Notes

Wrapper for either abs_binary_sign_grad_impl() (with env BREVITAS_JIT=0) or its native just-in-time compiled variant (with BREVITAS_JIT=1).

Examples

>>> x = torch.tensor([0.0], requires_grad=True)
>>> y = abs_binary_sign_grad(x)
>>> y
tensor([0.], grad_fn=<AbsBinarySignGradFnBackward>)
>>> grad = torch.tensor([0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
brevitas.function.ops_ste.binary_sign_ste(x)[source]#

Function that implements binary_sign() with a straight-through gradient estimator.

Return type:

Tensor

Notes

Wrapper for either binary_sign_ste_impl() (with env BREVITAS_JIT=0) or its native just-in-time compiled variant (with BREVITAS_JIT=1).

Examples

>>> x = torch.tensor([1.7, 0.0, -0.5], requires_grad=True)
>>> y = binary_sign_ste(x)
>>> y
tensor([ 1.,  1., -1.], grad_fn=<BinarySignSteFnBackward>)
>>> grad = torch.tensor([0.1, 0.2, -0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
brevitas.function.ops_ste.ceil_ste(x)[source]#

Function that implements torch.ceil() with a straight-through gradient estimator.

Return type:

Tensor

Notes

Wrapper for either ceil_ste_impl() (with env BREVITAS_JIT=0) or its native just-in-time compiled variant (with BREVITAS_JIT=1).

Examples

>>> x = torch.tensor([1.7, -1.7], requires_grad=True)
>>> y = ceil_ste(x)
>>> y
tensor([ 2., -1.], grad_fn=<CeilSteFnBackward>)
>>> grad = torch.tensor([0.1, -0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
brevitas.function.ops_ste.dpu_round_ste(x)[source]#

Function that implements dpu_round() with a straight-through gradient estimator.

Return type:

Tensor

Notes

Wrapper for either dpu_round_ste_impl() (with env BREVITAS_JIT=0) or its native just-in-time compiled variant (with BREVITAS_JIT=1).

Examples

>>> x = torch.tensor([1.7, -1.7], requires_grad=True)
>>> y = dpu_round_ste(x)
>>> y
tensor([ 2., -2.], grad_fn=<DPURoundSteFnBackward>)
>>> grad = torch.tensor([0.1, -0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
brevitas.function.ops_ste.floor_ste(x)[source]#

Function that implements torch.floor() with a straight-through gradient estimator.

Return type:

Tensor

Notes

Wrapper for either floor_ste_impl() (with env BREVITAS_JIT=0) or its native just-in-time compiled variant (with BREVITAS_JIT=1).

Examples

>>> x = torch.tensor([1.7, -1.7], requires_grad=True)
>>> y = floor_ste(x)
>>> y
tensor([ 1., -2.], grad_fn=<FloorSteFnBackward>)
>>> grad = torch.tensor([0.1, -0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
brevitas.function.ops_ste.round_ste(x)[source]#

Function that implements torch.round() with a straight-through gradient estimator.

Return type:

Tensor

Notes

Wrapper for either round_ste_impl() (with env BREVITAS_JIT=0) or its native just-in-time compiled variant (with BREVITAS_JIT=1).

Examples

>>> x = torch.tensor([1.7, -1.7], requires_grad=True)
>>> y = round_ste(x)
>>> y
tensor([ 2., -2.], grad_fn=<RoundSteFnBackward>)
>>> grad = torch.tensor([0.1, -0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
brevitas.function.ops_ste.round_to_zero_ste(x)[source]#

Function that implements round_to_zero() with a straight-through gradient estimator.

Return type:

Tensor

Notes

Wrapper for either round_to_zero_ste_impl() (with env BREVITAS_JIT=0) or its native just-in-time compiled variant (with BREVITAS_JIT=1).

Examples

>>> x = torch.tensor([1.7, -1.7], requires_grad=True)
>>> y = round_to_zero_ste(x)
>>> y
tensor([ 1., -1.], grad_fn=<RoundToZeroSteFnBackward>)
>>> grad = torch.tensor([0.1, -0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
brevitas.function.ops_ste.scalar_clamp_min_ste(x, min_val)[source]#

Function that implements torch.clamp_min() with a straight-through gradient estimator for the gradient of the output y w.r.t. x, while the gradient of y w.r.t. min_val is always None.

Parameters:
  • x (Tensor) – input tensor to clamp.

  • min_val (float) – scalar value to use as lower bound for the input tensor.

Returns:

clamped output tensor.

Return type:

Tensor

Notes

Wrapper for either scalar_clamp_min_ste_impl() (with env BREVITAS_JIT=0) or its C++ just-in-time compiled variant (with BREVITAS_JIT=1).

Examples

>>> x = torch.tensor([1.5, 0.4, -1.5], requires_grad=True)
>>> y = scalar_clamp_min_ste(x, -1.0)
>>> y
tensor([ 1.5000,  0.4000, -1.0000], grad_fn=<ScalarClampMinSteFnBackward>)
>>> grad = torch.tensor([0.1, -0.1, 0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
brevitas.function.ops_ste.scalar_clamp_ste(x, min_val, max_val)[source]#

Function that implements torch.clamp() with a straight-through gradient estimator for the gradient of the output y w.r.t. x, while the gradient of y w.r.t. min_val and max_val is always None.

Parameters:
  • x (Tensor) – input tensor to clamp.

  • min_val (float) – scalar value to use as lower bound for the input tensor.

  • max_val (float) – scalar value to use as upper bound for the input tensor.

Returns:

clamped output tensor.

Return type:

Tensor

Notes

Wrapper for either scalar_clamp_ste_impl() (with env BREVITAS_JIT=0) or its C++ just-in-time compiled variant (with BREVITAS_JIT=1).

Examples

>>> x = torch.tensor([1.5, 0.4, -1.5], requires_grad=True)
>>> y = scalar_clamp_ste(x, -1.0, 1.0)
>>> y
tensor([ 1.0000,  0.4000, -1.0000], grad_fn=<ScalarClampSteFnBackward>)
>>> grad = torch.tensor([0.1, -0.1, 0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
brevitas.function.ops_ste.tensor_clamp_ste(x, min_val, max_val)[source]#

Function that implements tensor_clamp() with a straight-through gradient estimator for the gradient of y w.r.t. x, while the gradient of y w.r.t. min_val and max_val is always None.

Return type:

Tensor

Notes

Wrapper for either tensor_clamp_ste_impl() (with env BREVITAS_JIT=0) or its native just-in-time compiled variant (with BREVITAS_JIT=1).

Examples

>>> x = torch.tensor([1.5, 0.4, -1.5], requires_grad=True)
>>> y = tensor_clamp_ste(x, torch.tensor([-1.0, -0.5, -0.5]), torch.tensor([1.0, 0.5, 0.5]))
>>> y
tensor([ 1.0000,  0.4000, -0.5000], grad_fn=<TensorClampSteFnBackward>)
>>> grad = torch.tensor([0.1, -0.1, 0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
brevitas.function.ops_ste.tensor_clamp_ste_(x, min_val, max_val)[source]#

Function that implements tensor_clamp_() with a straight-through gradient estimator for the gradient of y w.r.t. x, while the gradient of y w.r.t. min_val and max_val is always None.

Return type:

Tensor

Notes

Wrapper for either tensor_clamp_ste_impl_() (with env BREVITAS_JIT=0) or its C++ just-in-time compiled variant (with BREVITAS_JIT=1).

Examples

>>> x = torch.tensor([1.5, 0.4, -1.5], requires_grad=True)
>>> y = tensor_clamp_ste_(x, torch.tensor([-1.0, -0.5, -0.5]), torch.tensor([1.0, 0.5, 0.5]))
>>> y
tensor([ 1.0000,  0.4000, -0.5000], grad_fn=<InplaceTensorClampSteFnBackward>)
>>> (y == x).all().item()
True
>>> grad = torch.tensor([0.1, -0.1, 0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
brevitas.function.ops_ste.ternary_sign_ste(x)[source]#

Function that implements torch.sign() with a straight-through gradient estimator.

Return type:

Tensor

Notes

Wrapper for either ternary_sign_ste_impl() (with env BREVITAS_JIT=0) or its native just-in-time compiled variant (with BREVITAS_JIT=1).

Examples

>>> x = torch.tensor([1.7, 0.0, -0.5], requires_grad=True)
>>> y = ternary_sign_ste(x)
>>> y
tensor([ 1.,  0., -1.], grad_fn=<TernarySignSteFnBackward>)
>>> grad = torch.tensor([0.1, 0.2, -0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True

brevitas.function.shape module#

Implementation of various functions to compute shapes that induce flattening along certain dimensions of a tensor.
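
As a usage sketch (an illustrative scenario, not taken from the library), these helpers pair naturally with Tensor.view() to compute per-channel or per-tensor statistics:

>>> import torch
>>> w = torch.randn(8, 16, 3, 3)                      # output channels at dimension 0
>>> w_flat = w.view(over_output_channels(w))          # shape (8, 144)
>>> per_channel_max = w_flat.abs().max(dim=1).values  # one value per output channel
>>> per_channel_max.shape
torch.Size([8])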

brevitas.function.shape.over_batch_over_output_channels(x)[source]#

Returns a shape s such that x.view(s) is a 3-dim tensor with batches at dimension 0, output channels at dimension 1, and any other feature at dimension 2.

Parameters:

x (Tensor) – Input tensor with batches at dimension 0 and output channels at dimension 1.

Returns:

A tuple containing the 3-dim shape.

Examples

>>> over_batch_over_output_channels(torch.randn([2, 3, 4, 3]))
(2, 3, -1)
brevitas.function.shape.over_batch_over_tensor(x)[source]#

Computes the shape s such that x.view(s) is a 2-dim tensor with batches at dimension 0 and any other feature at dimension 1.

Parameters:

x (Tensor) – Input tensor with batches at dimension 0.

Return type:

Tuple[int, int]

Returns:

A tuple containing the 2-dim shape.

Examples

>>> over_batch_over_tensor(torch.randn([2, 3, 4, 3]))
(2, -1)
brevitas.function.shape.over_output_channels(x)[source]#

Computes the shape s such that x.view(s) is a 2-dim tensor with output channels at dimension 0 and any other feature at dimension 1.

Parameters:

x (Tensor) – Input tensor with output channels at dimension 0.

Return type:

Tuple[int, int]

Returns:

A tuple containing the 2-dim shape.

Examples

>>> over_output_channels(torch.randn([2, 3, 4, 3]))
(2, -1)
brevitas.function.shape.over_output_features(x)[source]#

Returns a shape s such that x.view(s) is a 2-dim tensor with all dimensions except the last one flattened into dimension 0 and the last dimension at dimension 1.

Parameters:

x (Tensor) – Input tensor with batches at dimension 0 and output channels at dimension 1.

Returns:

A tuple containing the 2-dim shape.

Examples

>>> over_output_features(torch.randn([2, 3, 4, 3]))
(24, 3)
brevitas.function.shape.over_tensor(x)[source]#

Computes the shape s such that x.view(s) is a flat tensor.

Parameters:

x (Tensor) – Input tensor.

Return type:

int

Returns:

The number -1 corresponding to a flat shape.

Examples

>>> over_tensor(torch.randn([2, 3, 4, 3]))
-1