brevitas.ops package

Submodules

brevitas.ops.autograd_ste_ops module

Implementations of various torch.autograd.Function subclasses with straight-through estimators.
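The general straight-through pattern can be sketched in plain PyTorch. The forward pass applies the non-differentiable op, and the backward pass returns the incoming gradient unchanged, as if the op were the identity. This is a minimal illustration, not the Brevitas implementation. `make_ste` and `floor_ste_sketch` are hypothetical names.

```python
import torch
from torch.autograd import Function


def make_ste(forward_op):
    """Hypothetical helper: return the .apply of a new autograd Function that
    computes forward_op(x) in the forward pass and uses an identity gradient
    in the backward pass (a straight-through estimator)."""

    class _SteFn(Function):
        @staticmethod
        def forward(ctx, x):
            return forward_op(x)

        @staticmethod
        def backward(ctx, grad_y):
            # Straight-through: treat d forward_op(x) / dx as 1 everywhere.
            return grad_y

    return _SteFn.apply


floor_ste_sketch = make_ste(torch.floor)

x = torch.tensor([0.2, 1.7, -2.4], requires_grad=True)
y = floor_ste_sketch(x)  # floor values: 0., 1., -3.
y.sum().backward()       # x.grad is all ones; torch.floor() alone would give zeros
```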

class brevitas.ops.autograd_ste_ops.AbsBinarySignGradFn(*args, **kwargs)

Bases: Function

Autograd function that implements torch.abs() with a binary-sign backward, so that the subgradient at 0 is 1. By contrast, the subgradient of torch.abs() at 0 is 0.

AbsBinarySignGradFn.apply(*args) is first aliased to abs_binary_sign_grad_impl(*args) and then wrapped by abs_binary_sign_grad() when env BREVITAS_JIT=0. See abs_binary_sign_grad() for details on the interface and examples.
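A minimal sketch of the backward rule described above, written as a standalone `torch.autograd.Function`. `AbsBinarySignGradSketch` is a hypothetical name, and this is not the Brevitas source. It assumes the binary sign maps 0 to +1, which is what produces the subgradient 1 at 0. The second half of the example computes the native torch.abs() gradient for comparison.

```python
import torch
from torch.autograd import Function


class AbsBinarySignGradSketch(Function):
    """Hypothetical sketch: |x| forward, binary sign of x (with 0 -> +1) backward."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.abs(x)

    @staticmethod
    def backward(ctx, grad_y):
        (x,) = ctx.saved_tensors
        # Binary sign: +1 where x >= 0, -1 otherwise, so the subgradient at 0 is 1.
        binary_sign = torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x))
        return grad_y * binary_sign


x = torch.tensor([-1.5, 0.0, 2.0], requires_grad=True)
AbsBinarySignGradSketch.apply(x).sum().backward()
sketch_grad = x.grad.clone()  # -1, 1, 1

x.grad = None
torch.abs(x).sum().backward()
native_grad = x.grad.clone()  # -1, 0, 1: torch.abs() uses subgradient 0 at 0
```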

class brevitas.ops.autograd_ste_ops.BinarySignSteFn(*args, **kwargs)

Bases: Function

Autograd function that implements binary_sign() with a straight-through gradient estimator.

BinarySignSteFn.apply(*args) is first aliased to binary_sign_ste_impl(*args) and then wrapped by binary_sign_ste() when env BREVITAS_JIT=0. See binary_sign_ste() for details on the interface and examples.
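A minimal sketch of a binary-sign STE in plain PyTorch, assuming binary_sign() maps x >= 0 to +1 and x < 0 to -1 (consistent with the subgradient described for AbsBinarySignGradFn). `BinarySignSteSketch` is a hypothetical name, not the Brevitas class. Without the STE, a sign function has zero gradient almost everywhere, so nothing upstream of it would learn.

```python
import torch
from torch.autograd import Function


class BinarySignSteSketch(Function):
    """Hypothetical sketch: binary sign forward, identity (STE) backward."""

    @staticmethod
    def forward(ctx, x):
        # Assumed binary_sign(): +1 where x >= 0, -1 where x < 0 (never 0).
        return torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x))

    @staticmethod
    def backward(ctx, grad_y):
        return grad_y


x = torch.tensor([-0.3, 0.0, 0.8], requires_grad=True)
y = BinarySignSteSketch.apply(x)  # -1., 1., 1.
(2.0 * y).sum().backward()        # the upstream gradient 2 reaches x unchanged
```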

class brevitas.ops.autograd_ste_ops.CeilSteFn(*args, **kwargs)

Bases: Function

Autograd function that implements torch.ceil() with a straight-through gradient estimator.

CeilSteFn.apply(*args) is first aliased to ceil_ste_impl(*args) and then wrapped by ceil_ste() when env BREVITAS_JIT=0. See ceil_ste() for details on the interface and examples.
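The difference from plain torch.ceil() shows up in the gradient. Autograd defines the derivative of torch.ceil() as zero, while a straight-through version passes the gradient through. `CeilSteSketch` is a hypothetical illustration, not the Brevitas source.

```python
import torch
from torch.autograd import Function


class CeilSteSketch(Function):
    """Hypothetical sketch: torch.ceil() forward, identity backward."""

    @staticmethod
    def forward(ctx, x):
        return torch.ceil(x)

    @staticmethod
    def backward(ctx, grad_y):
        return grad_y


x = torch.tensor([0.1, -1.9, 3.0], requires_grad=True)

y = CeilSteSketch.apply(x)  # 1., -1., 3.
y.sum().backward()
ste_grad = x.grad.clone()   # all ones

x.grad = None
torch.ceil(x).sum().backward()
native_grad = x.grad.clone()  # all zeros: torch.ceil() backpropagates nothing
```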

class brevitas.ops.autograd_ste_ops.DPURoundSteFn(*args, **kwargs)

Bases: Function

Autograd function that implements dpu_round() with a straight-through gradient estimator.

DPURoundSteFn.apply(*args) is first aliased to dpu_round_ste_impl(*args) and then wrapped by dpu_round_ste() when env BREVITAS_JIT=0. See dpu_round_ste() for details on the interface and examples.
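This page does not spell out the rounding rule of dpu_round(), so the sketch below uses a clearly hypothetical stand-in (round half up, floor(x + 0.5)) purely to show the structure. The STE backward is the same whatever rounding rule the forward uses. `hypothetical_dpu_round` and `DPURoundSteSketch` are illustrative names, not Brevitas APIs.

```python
import torch
from torch.autograd import Function


def hypothetical_dpu_round(x):
    # Stand-in rounding rule (round half up); NOT necessarily the real dpu_round().
    return torch.floor(x + 0.5)


class DPURoundSteSketch(Function):
    """Hypothetical sketch: custom rounding forward, identity backward."""

    @staticmethod
    def forward(ctx, x):
        return hypothetical_dpu_round(x)

    @staticmethod
    def backward(ctx, grad_y):
        # The straight-through backward does not depend on the rounding rule.
        return grad_y


x = torch.tensor([0.5, -0.5, 1.49], requires_grad=True)
y = DPURoundSteSketch.apply(x)  # 1., 0., 1. under the stand-in rule
y.sum().backward()
```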

class brevitas.ops.autograd_ste_ops.FloorSteFn(*args, **kwargs)

Bases: Function

Autograd function that implements torch.floor() with a straight-through gradient estimator.

FloorSteFn.apply(*args) is first aliased to floor_ste_impl(*args) and then wrapped by floor_ste() when env BREVITAS_JIT=0. See floor_ste() for details on the interface and examples.
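A floor STE matters most when it sits inside a larger differentiable expression. The sketch below uses a hypothetical fixed-point quantizer q = floor(x / s) * s (not a Brevitas API) and shows that gradients reach both the input x and the scale s. With the STE, dq/dx = 1 and dq/ds = floor(x / s) - x / s. `FloorSteSketch` and `floor_quantize` are illustrative names.

```python
import torch
from torch.autograd import Function


class FloorSteSketch(Function):
    """Hypothetical sketch: torch.floor() forward, identity backward."""

    @staticmethod
    def forward(ctx, x):
        return torch.floor(x)

    @staticmethod
    def backward(ctx, grad_y):
        return grad_y


def floor_quantize(x, scale):
    # Hypothetical quantizer: snap x down to a multiple of scale.
    return FloorSteSketch.apply(x / scale) * scale


x = torch.tensor([0.7, 2.3], requires_grad=True)
scale = torch.tensor(0.5, requires_grad=True)
q = floor_quantize(x, scale)  # 0.5, 2.0
q.sum().backward()
# x.grad is 1 for each element; scale.grad = sum(floor(x / s) - x / s) = 5 - 6 = -1
```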

class brevitas.ops.autograd_ste_ops.InplaceTensorClampSteFn(*args, **kwargs)

Bases: Function

Autograd function that implements tensor_clamp_() with a straight-through gradient estimator for the gradient of y w.r.t. x, while the gradient of y w.r.t. min_val and max_val is always None.

InplaceTensorClampSteFn.apply(*args) is first aliased to tensor_clamp_ste_impl_(*args) and then wrapped by tensor_clamp_() when env BREVITAS_JIT=0. See tensor_clamp_() for details on the interface and examples.
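An in-place autograd Function has to tell autograd which input it modified, which PyTorch does through ctx.mark_dirty(). A minimal sketch, assuming tensor-valued bounds that broadcast against x. `InplaceTensorClampSteSketch` is a hypothetical name, not the Brevitas source. Note that in-place ops on a leaf tensor that requires grad are rejected by autograd, so the example clamps a non-leaf copy.

```python
import torch
from torch.autograd import Function


class InplaceTensorClampSteSketch(Function):
    """Hypothetical sketch: in-place clamp forward, identity backward for x."""

    @staticmethod
    def forward(ctx, x, min_val, max_val):
        clamped = torch.where(x > max_val, max_val, x)
        clamped = torch.where(clamped < min_val, min_val, clamped)
        x.copy_(clamped)   # write the result into x's own storage
        ctx.mark_dirty(x)  # tell autograd that x was modified in place
        return x

    @staticmethod
    def backward(ctx, grad_y):
        # Identity for x; no gradient for the bounds.
        return grad_y, None, None


w = torch.tensor([-3.0, 0.5, 2.0], requires_grad=True)
x = w.clone()  # non-leaf copy, so the in-place update is allowed
y = InplaceTensorClampSteSketch.apply(x, torch.tensor(-1.0), torch.tensor(1.0))
y.sum().backward()  # w.grad is all ones despite the clamping
```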

class brevitas.ops.autograd_ste_ops.RoundSteFn(*args, **kwargs)

Bases: Function

Autograd function that implements torch.round() with a straight-through gradient estimator.

RoundSteFn.apply(*args) is first aliased to round_ste_impl(*args) and then wrapped by round_ste() when env BREVITAS_JIT=0. See round_ste() for details on the interface and examples.
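To see why the estimator is needed for training, the sketch below runs plain gradient descent on (round(w) - 3)^2 twice: once through a round STE and once through torch.round(). The STE lets w move until its rounded value hits the target. With torch.round() the gradient is always zero, so w never changes. `RoundSteSketch` and `fit` are hypothetical names, not Brevitas APIs.

```python
import torch
from torch.autograd import Function


class RoundSteSketch(Function):
    """Hypothetical sketch: torch.round() forward, identity backward."""

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_y):
        return grad_y


def fit(round_fn, steps=20, lr=0.1, target=3.0):
    """Gradient descent on (round_fn(w) - target) ** 2, starting from w = 0.2."""
    w = torch.tensor(0.2, requires_grad=True)
    for _ in range(steps):
        loss = (round_fn(w) - target) ** 2
        loss.backward()
        with torch.no_grad():
            w -= lr * w.grad
        w.grad = None
    return w.detach()


w_ste = fit(RoundSteSketch.apply)  # round(w_ste) reaches the target
w_native = fit(torch.round)        # gradient is always 0, so w stays at 0.2
```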

class brevitas.ops.autograd_ste_ops.RoundToZeroSteFn(*args, **kwargs)

Bases: Function

Autograd function that implements round_to_zero() with a straight-through gradient estimator.

RoundToZeroSteFn.apply(*args) is first aliased to round_to_zero_ste_impl(*args) and then wrapped by round_to_zero_ste() when env BREVITAS_JIT=0. See round_to_zero_ste() for details on the interface and examples.
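A minimal sketch, assuming round_to_zero() truncates toward zero (the same values as torch.trunc()). `RoundToZeroSteSketch` is a hypothetical name, not the Brevitas class. A non-uniform upstream gradient is used to show that it reaches x unchanged, element by element.

```python
import torch
from torch.autograd import Function


class RoundToZeroSteSketch(Function):
    """Hypothetical sketch: round-toward-zero forward, identity backward."""

    @staticmethod
    def forward(ctx, x):
        # Assumed semantics of round_to_zero(): drop the fractional part.
        return torch.sign(x) * torch.floor(torch.abs(x))

    @staticmethod
    def backward(ctx, grad_y):
        return grad_y


x = torch.tensor([-2.7, -0.4, 0.4, 2.7], requires_grad=True)
y = RoundToZeroSteSketch.apply(x)              # -2., -0., 0., 2.
y.backward(torch.tensor([1.0, 2.0, 3.0, 4.0]))  # x.grad is exactly this tensor
```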

class brevitas.ops.autograd_ste_ops.ScalarClampMinSteFn(*args, **kwargs)

Bases: Function

Autograd function that implements torch.clamp_min() with a straight-through gradient estimator for the gradient of y w.r.t. x, while the gradient of y w.r.t. min_val is always None.

ScalarClampMinSteFn.apply(*args) is first aliased to scalar_clamp_min_ste_impl(*args) and then wrapped by scalar_clamp_min_ste() when env BREVITAS_JIT=0. See scalar_clamp_min_ste() for details on the interface and examples.
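The native clamp_min gradient is zero wherever the input was clamped, while the straight-through version keeps it at one. A minimal sketch with a Python-float lower bound, which is why backward returns None for it. `ScalarClampMinSteSketch` is a hypothetical name, not the Brevitas source.

```python
import torch
from torch.autograd import Function


class ScalarClampMinSteSketch(Function):
    """Hypothetical sketch: clamp_min forward, identity backward for x."""

    @staticmethod
    def forward(ctx, x, min_val: float):
        return x.clamp_min(min_val)

    @staticmethod
    def backward(ctx, grad_y):
        # min_val is a Python float, so its gradient is None.
        return grad_y, None


x = torch.tensor([-1.0, 0.5], requires_grad=True)
ScalarClampMinSteSketch.apply(x, 0.0).sum().backward()
ste_grad = x.grad.clone()  # 1, 1

x.grad = None
x.clamp_min(0.0).sum().backward()
native_grad = x.grad.clone()  # 0 where x was clamped, 1 elsewhere
```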

class brevitas.ops.autograd_ste_ops.ScalarClampSteFn(*args, **kwargs)

Bases: Function

Autograd function that implements torch.clamp() with a straight-through gradient estimator for the gradient of y w.r.t. x, while the gradient of y w.r.t. min_val and max_val is always None.

ScalarClampSteFn.apply(*args) is first aliased to scalar_clamp_ste_impl(*args) and then wrapped by scalar_clamp_ste() when env BREVITAS_JIT=0. See scalar_clamp_ste() for details on the interface and examples.
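A minimal sketch with Python-float bounds, so backward returns one gradient for x and None for each bound. The upstream gradient here comes from a squared loss, so x.grad equals 2 * y evaluated at the clamped values, including the elements that were clamped. `ScalarClampSteSketch` is a hypothetical name, not the Brevitas class.

```python
import torch
from torch.autograd import Function


class ScalarClampSteSketch(Function):
    """Hypothetical sketch: torch.clamp() forward, identity backward for x."""

    @staticmethod
    def forward(ctx, x, min_val: float, max_val: float):
        return torch.clamp(x, min_val, max_val)

    @staticmethod
    def backward(ctx, grad_y):
        # Identity for x; the float bounds get None.
        return grad_y, None, None


x = torch.tensor([-2.0, 0.3, 5.0], requires_grad=True)
y = ScalarClampSteSketch.apply(x, -1.0, 1.0)  # -1., 0.3, 1.
(y ** 2).sum().backward()                      # x.grad = 2 * y = -2., 0.6, 2.
```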

class brevitas.ops.autograd_ste_ops.TensorClampSteFn(*args, **kwargs)

Bases: Function

Autograd function that implements tensor_clamp() with a straight-through gradient estimator for the gradient of y w.r.t. x, while the gradient of y w.r.t. min_val and max_val is always None.

TensorClampSteFn.apply(*args) is first aliased to tensor_clamp_ste_impl(*args) and then wrapped by tensor_clamp() when env BREVITAS_JIT=0. See tensor_clamp() for details on the interface and examples.
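With tensor-valued bounds, the clamp can use per-row (for example, per-channel) limits through broadcasting. The sketch also shows what "the gradient w.r.t. min_val and max_val is always None" means in practice: even a bound that requires grad receives no gradient. `TensorClampSteSketch` is a hypothetical name, and the torch.where-based clamp is an assumption, not the Brevitas source.

```python
import torch
from torch.autograd import Function


class TensorClampSteSketch(Function):
    """Hypothetical sketch: clamp to tensor bounds forward, identity backward for x."""

    @staticmethod
    def forward(ctx, x, min_val, max_val):
        y = torch.where(x > max_val, max_val, x)
        return torch.where(y < min_val, min_val, y)

    @staticmethod
    def backward(ctx, grad_y):
        # Identity for x; None for the bounds, even if they require grad.
        return grad_y, None, None


x = torch.tensor([[-3.0, 0.5, 3.0],
                  [-3.0, 0.5, 3.0]], requires_grad=True)
min_val = torch.tensor([[-1.0], [-2.0]])                    # per-row lower bounds
max_val = torch.tensor([[1.0], [2.0]], requires_grad=True)  # per-row upper bounds
y = TensorClampSteSketch.apply(x, min_val, max_val)
y.sum().backward()  # x.grad is all ones; max_val.grad stays None
```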

class brevitas.ops.autograd_ste_ops.TernarySignSteFn(*args, **kwargs)

Bases: Function

Autograd function that implements torch.sign() with a straight-through gradient estimator.

TernarySignSteFn.apply(*args) is first aliased to ternary_sign_ste_impl(*args) and then wrapped by ternary_sign_ste() when env BREVITAS_JIT=0. See ternary_sign_ste() for details on the interface and examples.
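Unlike the binary sign, torch.sign() returns one of three values (-1, 0 or +1), which is where the "ternary" name comes from. A minimal straight-through sketch. `TernarySignSteSketch` is a hypothetical name, not the Brevitas class.

```python
import torch
from torch.autograd import Function


class TernarySignSteSketch(Function):
    """Hypothetical sketch: torch.sign() forward, identity backward."""

    @staticmethod
    def forward(ctx, x):
        return torch.sign(x)  # -1, 0 or +1

    @staticmethod
    def backward(ctx, grad_y):
        return grad_y


x = torch.tensor([-0.7, 0.0, 0.2], requires_grad=True)
y = TernarySignSteSketch.apply(x)  # -1., 0., 1.
y.sum().backward()                 # all ones; torch.sign() alone would give zeros
```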

brevitas.ops.autograd_ste_ops.abs_binary_sign_grad_impl(*args, **kwargs)

Alias for AbsBinarySignGradFn.apply(*args)

brevitas.ops.autograd_ste_ops.binary_sign_ste_impl(*args, **kwargs)

Alias for BinarySignSteFn.apply(*args)

brevitas.ops.autograd_ste_ops.ceil_ste_impl(*args, **kwargs)

Alias for CeilSteFn.apply(*args)

brevitas.ops.autograd_ste_ops.dpu_round_ste_impl(*args, **kwargs)

Alias for DPURoundSteFn.apply(*args)

brevitas.ops.autograd_ste_ops.floor_ste_impl(*args, **kwargs)

Alias for FloorSteFn.apply(*args)

brevitas.ops.autograd_ste_ops.round_ste_impl(*args, **kwargs)

Alias for RoundSteFn.apply(*args)

brevitas.ops.autograd_ste_ops.round_to_zero_ste_impl(*args, **kwargs)

Alias for RoundToZeroSteFn.apply(*args)

brevitas.ops.autograd_ste_ops.scalar_clamp_min_ste_impl(*args, **kwargs)

Alias for ScalarClampMinSteFn.apply(*args)

brevitas.ops.autograd_ste_ops.scalar_clamp_ste_impl(*args, **kwargs)

Alias for ScalarClampSteFn.apply(*args)

brevitas.ops.autograd_ste_ops.tensor_clamp_ste_impl(*args, **kwargs)

Alias for TensorClampSteFn.apply(*args)

brevitas.ops.autograd_ste_ops.ternary_sign_ste_impl(*args, **kwargs)

Alias for TernarySignSteFn.apply(*args)
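Each *_impl entry above is simply a module-level name bound to a Function's apply, which a public wrapper then calls. A hypothetical sketch of that two-step pattern follows. All names ending in `_sketch` or `Sketch` are illustrative. Only the BREVITAS_JIT=0 path is modelled, and what the library does when the JIT is enabled is not shown.

```python
import torch
from torch.autograd import Function


class RoundSteFnSketch(Function):
    """Hypothetical stand-in for RoundSteFn (not the Brevitas source)."""

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_y):
        return grad_y


# Step 1: module-level alias, in the spirit of "round_ste_impl: alias for RoundSteFn.apply".
round_ste_impl_sketch = RoundSteFnSketch.apply


# Step 2: hypothetical public wrapper that callers use instead of the alias.
def round_ste_sketch(x: torch.Tensor) -> torch.Tensor:
    return round_ste_impl_sketch(x)


x = torch.tensor([0.4, 1.6], requires_grad=True)
y = round_ste_sketch(x)  # 0., 2.
y.sum().backward()       # x.grad is all ones
```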