brevitas.function package#
Submodules#
brevitas.function.ops module#
Implementation of various core operations often performed as part of quantization. The implemented functions adhere to the restrictions imposed by PyTorch 1.1.0's TorchScript compiler.
- brevitas.function.ops.binary_sign(x)[source]#
Computes the 2-valued sign of an input tensor.
- Parameters:
x (Tensor) – input tensor.
- Returns:
the 2-valued sign tensor of the input tensor.
- Return type:
Tensor
Examples
>>> binary_sign(torch.tensor([2.1, -0.3, 0.0]))
tensor([ 1., -1., 1.])
- brevitas.function.ops.dpu_round(x)[source]#
Compute DPU rounding.
- Parameters:
x (Tensor) – input tensor.
- Returns:
rounded input tensor.
- Return type:
Tensor
Examples
>>> dpu_round(torch.tensor([-1.5, -0.5, 0.5, 1.5]))
tensor([-1., -0., 0., 2.])
- brevitas.function.ops.get_upper_bound_on_l1_norm(accumulator_bit_width, input_bit_width, input_is_signed)[source]#
Calculate the upper bound on the l1-norm of the weights needed to guarantee overflow avoidance for a given accumulator bit width and input representation, using the derivations from A2Q: Accumulator-Aware Quantization with Guaranteed Overflow Avoidance by I. Colbert, A. Pappalardo, and J. Petri-Koenig. Note that this assumes integer quantization.
- Return type:
Tensor
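For reference, a minimal sketch of the computation, assuming the bound takes the form (2^(A-1) - 1) / 2^(N - signed) for accumulator bit width A and input bit width N, following the A2Q derivation; l1_norm_bound_sketch is an illustrative name, not part of the brevitas API:
>>> def l1_norm_bound_sketch(acc_bit_width, input_bit_width, input_is_signed):
...     # Largest magnitude of a signed A-bit accumulator: 2^(A-1) - 1.
...     max_acc_mag = 2. ** (acc_bit_width - 1.) - 1.
...     # Assumed largest input magnitude: 2^(N-1) if signed, 2^N if unsigned.
...     max_input_mag = 2. ** (input_bit_width - float(input_is_signed))
...     return max_acc_mag / max_input_mag
>>> l1_norm_bound_sketch(torch.tensor(16.), torch.tensor(8.), True)
tensor(255.9922)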
- brevitas.function.ops.identity(x)[source]#
Identity function.
- Parameters:
x (Tensor) – Input Tensor
- Returns:
The input tensor x.
- Return type:
Tensor
Examples
>>> identity(torch.tensor(1.7))
tensor(1.7000)
- brevitas.function.ops.max_int(signed, narrow_range, bit_width)[source]#
Compute the maximum integer representable by a given number of bits.
- Parameters:
signed (bool) – Indicates whether the represented integer is signed.
narrow_range (bool) – Indicates whether the narrow range is used, which excludes one value from the representable range (the most negative value when signed, the maximum value when unsigned).
bit_width (Tensor) – Number of bits available for the representation.
- Returns:
Maximum integer that can be represented according to the input arguments.
- Return type:
Tensor
Examples
>>> max_int(signed=True, narrow_range=True, bit_width=torch.tensor(8))
tensor(127)
>>> max_int(signed=False, narrow_range=True, bit_width=torch.tensor(8))
tensor(254)
>>> max_int(signed=True, narrow_range=False, bit_width=torch.tensor(8))
tensor(127)
>>> max_int(signed=False, narrow_range=False, bit_width=torch.tensor(8))
tensor(255)
- brevitas.function.ops.min_int(signed, narrow_range, bit_width)[source]#
Compute the minimum integer representable by a given number of bits.
- Parameters:
signed (bool) – Indicates whether the represented integer is signed.
narrow_range (bool) – Indicates whether the narrow range is used, which excludes one value from the representable range (the most negative value when signed, the maximum value when unsigned).
bit_width (Tensor) – Number of bits available for the representation.
- Returns:
Minimum integer that can be represented according to the input arguments.
- Return type:
Tensor
Examples
>>> min_int(signed=True, narrow_range=True, bit_width=torch.tensor(8))
tensor(-127)
>>> min_int(signed=False, narrow_range=True, bit_width=torch.tensor(8))
tensor(0)
>>> min_int(signed=True, narrow_range=False, bit_width=torch.tensor(8))
tensor(-128)
>>> min_int(signed=False, narrow_range=False, bit_width=torch.tensor(8))
tensor(0)
- brevitas.function.ops.round_to_zero(x)[source]#
Compute rounding towards zero.
- Parameters:
x (Tensor) – input tensor.
- Returns:
rounded input tensor.
- Return type:
Tensor
Examples
>>> round_to_zero(torch.tensor([-1.5, -0.5, 0.5, 1.5]))
tensor([-1., -0., 0., 1.])
- brevitas.function.ops.tensor_clamp(x, min_val, max_val)[source]#
Generalized clamp function with support for tensors as clamping values.
- Parameters:
x (Tensor) – Input on which to apply the clamp operation.
min_val (Tensor) – Minimum values for the clamp operation.
max_val (Tensor) – Maximum values for the clamp operation.
Notes
x, min_val, max_val need to be broadcastable.
Differentiable w.r.t. x, min_val, max_val.
- Returns:
Input x clamped between the provided minimum and maximum tensors.
- Return type:
Tensor
Examples
>>> tensor_clamp(torch.tensor([1.7, -0.5, 0.1]), torch.tensor(0.0), torch.tensor(1.0))
tensor([1.0000, 0.0000, 0.1000])
- brevitas.function.ops.tensor_clamp_(x, min_val, max_val)[source]#
In-place variant of tensor_clamp(). Not differentiable w.r.t. any of the inputs.
- Return type:
Tensor
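An illustrative usage sketch (output formatting assumed, mirroring the tensor_clamp() example above):
>>> x = torch.tensor([1.7, -0.5, 0.1])
>>> tensor_clamp_(x, torch.tensor(0.0), torch.tensor(1.0))
tensor([1.0000, 0.0000, 0.1000])
>>> x  # x itself was modified in place
tensor([1.0000, 0.0000, 0.1000])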
brevitas.function.ops_ste module#
Implementation of various functions with straight-through gradient estimators, dispatched either to a native just-in-time compiled backend (when env BREVITAS_JIT=1) or to an autograd Function implemented in autograd_ste_ops (when env BREVITAS_JIT=0).
The native backend is enabled when BREVITAS_JIT is enabled to allow for end-to-end compilation of the built-in quantizers, since as of PyTorch 1.8.1 a torch.autograd.Function is not supported by the TorchScript compiler.
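As an illustrative sketch, assuming the environment variable is read when brevitas is first imported (so it must be set beforehand):
>>> import os
>>> os.environ['BREVITAS_JIT'] = '1'  # select the native JIT-compiled backend
>>> import brevitas.function.ops_ste  # dispatch is resolved at import time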
- brevitas.function.ops_ste.abs_binary_sign_grad(x)[source]#
Function that implements torch.abs() with a binary-sign backward, in order to have a subgradient of 1 at 0, as opposed to torch.abs()'s subgradient of 0 at 0.
- Return type:
Tensor
Notes
Wrapper for either abs_binary_sign_grad_impl() (with env BREVITAS_JIT=0) or its native just-in-time compiled variant (with BREVITAS_JIT=1).
Examples
>>> x = torch.tensor([0.0], requires_grad=True)
>>> y = abs_binary_sign_grad(x)
>>> y
tensor([0.], grad_fn=<AbsBinarySignGradFnBackward>)
>>> grad = torch.tensor([0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
- brevitas.function.ops_ste.binary_sign_ste(x)[source]#
Function that implements binary_sign() with a straight-through gradient estimator.
- Return type:
Tensor
Notes
Wrapper for either binary_sign_ste_impl() (with env BREVITAS_JIT=0) or its native just-in-time compiled variant (with BREVITAS_JIT=1).
Examples
>>> x = torch.tensor([1.7, 0.0, -0.5], requires_grad=True)
>>> y = binary_sign_ste(x)
>>> y
tensor([ 1., 1., -1.], grad_fn=<BinarySignSteFnBackward>)
>>> grad = torch.tensor([0.1, 0.2, -0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
- brevitas.function.ops_ste.ceil_ste(x)[source]#
Function that implements torch.ceil() with a straight-through gradient estimator.
- Return type:
Tensor
Notes
Wrapper for either ceil_ste_impl() (with env BREVITAS_JIT=0) or its native just-in-time compiled variant (with BREVITAS_JIT=1).
Examples
>>> x = torch.tensor([1.7, -1.7], requires_grad=True)
>>> y = ceil_ste(x)
>>> y
tensor([ 2., -1.], grad_fn=<CeilSteFnBackward>)
>>> grad = torch.tensor([0.1, -0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
- brevitas.function.ops_ste.dpu_round_ste(x)[source]#
Function that implements dpu_round() with a straight-through gradient estimator.
- Return type:
Tensor
Notes
Wrapper for either dpu_round_ste_impl() (with env BREVITAS_JIT=0) or its native just-in-time compiled variant (with BREVITAS_JIT=1).
Examples
>>> x = torch.tensor([1.7, -1.7], requires_grad=True)
>>> y = dpu_round_ste(x)
>>> y
tensor([ 2., -2.], grad_fn=<DPURoundSteFnBackward>)
>>> grad = torch.tensor([0.1, -0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
- brevitas.function.ops_ste.floor_ste(x)[source]#
Function that implements torch.floor() with a straight-through gradient estimator.
- Return type:
Tensor
Notes
Wrapper for either floor_ste_impl() (with env BREVITAS_JIT=0) or its native just-in-time compiled variant (with BREVITAS_JIT=1).
Examples
>>> x = torch.tensor([1.7, -1.7], requires_grad=True)
>>> y = floor_ste(x)
>>> y
tensor([ 1., -2.], grad_fn=<FloorSteFnBackward>)
>>> grad = torch.tensor([0.1, -0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
- brevitas.function.ops_ste.round_ste(x)[source]#
Function that implements torch.round() with a straight-through gradient estimator.
- Return type:
Tensor
Notes
Wrapper for either round_ste_impl() (with env BREVITAS_JIT=0) or its native just-in-time compiled variant (with BREVITAS_JIT=1).
Examples
>>> x = torch.tensor([1.7, -1.7], requires_grad=True)
>>> y = round_ste(x)
>>> y
tensor([ 2., -2.], grad_fn=<RoundSteFnBackward>)
>>> grad = torch.tensor([0.1, -0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
- brevitas.function.ops_ste.round_to_zero_ste(x)[source]#
Function that implements round_to_zero() with a straight-through gradient estimator.
- Return type:
Tensor
Notes
Wrapper for either round_to_zero_ste_impl() (with env BREVITAS_JIT=0) or its native just-in-time compiled variant (with BREVITAS_JIT=1).
Examples
>>> x = torch.tensor([1.7, -1.7], requires_grad=True)
>>> y = round_to_zero_ste(x)
>>> y
tensor([ 1., -1.], grad_fn=<RoundToZeroSteFnBackward>)
>>> grad = torch.tensor([0.1, -0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
- brevitas.function.ops_ste.scalar_clamp_min_ste(x, min_val)[source]#
Function that implements torch.clamp_min() with a straight-through gradient estimator for the gradient of the output y w.r.t. x, while the gradient of y w.r.t. min_val is always None.
- Parameters:
x (Tensor) – input tensor to clamp.
min_val (float) – scalar value to use as lower bound for the input tensor.
- Returns:
clamped output tensor.
- Return type:
Tensor
Notes
Wrapper for either scalar_clamp_min_ste_impl() (with env BREVITAS_JIT=0) or its C++ just-in-time compiled variant (with BREVITAS_JIT=1).
Examples
>>> x = torch.tensor([1.5, 0.4, -1.5], requires_grad=True)
>>> y = scalar_clamp_min_ste(x, -1.0)
>>> y
tensor([ 1.5000, 0.4000, -1.0000], grad_fn=<ScalarClampMinSteFnBackward>)
>>> grad = torch.tensor([0.1, -0.1, 0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
- brevitas.function.ops_ste.scalar_clamp_ste(x, min_val, max_val)[source]#
Function that implements torch.clamp() with a straight-through gradient estimator for the gradient of the output y w.r.t. x, while the gradient of y w.r.t. min_val and max_val is always None.
- Parameters:
x (Tensor) – input tensor to clamp.
min_val (float) – scalar value to use as lower bound for the input tensor.
max_val (float) – scalar value to use as upper bound for the input tensor.
- Returns:
clamped output tensor.
- Return type:
Tensor
Notes
Wrapper for either scalar_clamp_ste_impl() (with env BREVITAS_JIT=0) or its C++ just-in-time compiled variant (with BREVITAS_JIT=1).
Examples
>>> x = torch.tensor([1.5, 0.4, -1.5], requires_grad=True)
>>> y = scalar_clamp_ste(x, -1.0, 1.0)
>>> y
tensor([ 1.0000, 0.4000, -1.0000], grad_fn=<ScalarClampSteFnBackward>)
>>> grad = torch.tensor([0.1, -0.1, 0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
- brevitas.function.ops_ste.tensor_clamp_ste(x, min_val, max_val)[source]#
Function that implements tensor_clamp() with a straight-through gradient estimator for the gradient of y w.r.t. x, while the gradient of y w.r.t. min_val and max_val is always None.
- Return type:
Tensor
Notes
Wrapper for either tensor_clamp_ste_impl() (with env BREVITAS_JIT=0) or its native just-in-time compiled variant (with BREVITAS_JIT=1).
Examples
>>> x = torch.tensor([1.5, 0.4, -1.5], requires_grad=True)
>>> y = tensor_clamp_ste(x, torch.tensor([-1.0, -0.5, -0.5]), torch.tensor([1.0, 0.5, 0.5]))
>>> y
tensor([ 1.0000, 0.4000, -0.5000], grad_fn=<TensorClampSteFnBackward>)
>>> grad = torch.tensor([0.1, -0.1, 0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
- brevitas.function.ops_ste.tensor_clamp_ste_(x, min_val, max_val)[source]#
Function that implements tensor_clamp_() with a straight-through gradient estimator for the gradient of y w.r.t. x, while the gradient of y w.r.t. min_val and max_val is always None.
- Return type:
Tensor
Notes
Wrapper for either tensor_clamp_ste_impl_() (with env BREVITAS_JIT=0) or its C++ just-in-time compiled variant (with BREVITAS_JIT=1).
Examples
>>> x = torch.tensor([1.5, 0.4, -1.5], requires_grad=True)
>>> y = tensor_clamp_ste_(x, torch.tensor([-1.0, -0.5, -0.5]), torch.tensor([1.0, 0.5, 0.5]))
>>> y
tensor([ 1.0000, 0.4000, -0.5000], grad_fn=<InplaceTensorClampSteFnBackward>)
>>> (y == x).all().item()
True
>>> grad = torch.tensor([0.1, -0.1, 0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
- brevitas.function.ops_ste.ternary_sign_ste(x)[source]#
Function that implements torch.sign() with a straight-through gradient estimator.
- Return type:
Tensor
Notes
Wrapper for either ternary_sign_ste_impl() (with env BREVITAS_JIT=0) or its native just-in-time compiled variant (with BREVITAS_JIT=1).
Examples
>>> x = torch.tensor([1.7, 0.0, -0.5], requires_grad=True)
>>> y = ternary_sign_ste(x)
>>> y
tensor([ 1., 0., -1.], grad_fn=<TernarySignSteFnBackward>)
>>> grad = torch.tensor([0.1, 0.2, -0.1])
>>> y.backward(grad)
>>> (x.grad == grad).all().item()
True
brevitas.function.shape module#
Implementation of various functions to compute shapes that induce flattening along certain dimensions of a tensor.
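For instance, the computed shape can be passed directly to view; a usage sketch based on over_batch_over_tensor(), documented below:
>>> x = torch.randn([2, 3, 4, 3])
>>> x.view(over_batch_over_tensor(x)).shape
torch.Size([2, 36])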
- brevitas.function.shape.over_batch_over_output_channels(x)[source]#
Returns a shape s such that x.view(s) is a 3-dim tensor with batches at dimension 0, output channels at dimension 1, and any other feature at dimension 2.
- Parameters:
x (Tensor) – Input tensor with batches at dimension 0 and output channels at dimension 1.
- Returns:
A tuple containing the 3-dim shape.
Examples
>>> over_batch_over_output_channels(torch.randn([2, 3, 4, 3]))
(2, 3, -1)
- brevitas.function.shape.over_batch_over_tensor(x)[source]#
Computes the shape s such that x.view(s) is a 2-dim tensor with batches at dimension 0 and any other feature at dimension 1.
- Parameters:
x (Tensor) – Input tensor with batches at dimension 0.
- Return type:
Tuple[int, int]
- Returns:
A tuple containing the 2-dim shape.
Examples
>>> over_batch_over_tensor(torch.randn([2, 3, 4, 3]))
(2, -1)
- brevitas.function.shape.over_output_channels(x)[source]#
Computes the shape s such that x.view(s) is a 2-dim tensor with output channels at dimension 0 and any other feature at dimension 1.
- Parameters:
x (Tensor) – Input tensor with output channels at dimension 0.
- Returns:
A tuple containing the 2-dim shape.
Examples
>>> over_output_channels(torch.randn([2, 3, 4, 3]))
(2, -1)
- brevitas.function.shape.over_output_features(x)[source]#
Returns a shape s such that x.view(s) is a 2-dim tensor with all the features except the last one at dimension 0, and the last feature at dimension 1.
- Parameters:
x (Tensor) – Input tensor with output features at the last dimension.
- Returns:
A tuple containing the 2-dim shape.
Examples
>>> over_output_features(torch.randn([2, 3, 4, 3]))
(24, 3)
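As a usage sketch along the same lines (not taken from the brevitas docs), the resulting shape flattens everything except the last dimension:
>>> x = torch.randn([2, 3, 4, 3])
>>> x.view(over_output_features(x)).shape
torch.Size([24, 3])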