brevitas.ops package
Submodules
brevitas.ops.autograd_ste_ops module
Implementations of various torch.autograd.Function subclasses with straight-through gradient estimators (STEs).
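As a rough illustration of the straight-through pattern, the sketch below builds a `torch.autograd.Function` whose forward applies a non-differentiable op and whose backward returns the incoming gradient unchanged. The factory `make_ste_fn` is hypothetical and not part of Brevitas; it only shows the general idea behind the classes documented here.

```python
import torch
from torch.autograd import Function


def make_ste_fn(op):
    """Hypothetical helper (not a Brevitas API): wrap `op` with an identity backward."""

    class _SteFn(Function):
        @staticmethod
        def forward(ctx, x):
            # Non-differentiable op in the forward pass.
            return op(x)

        @staticmethod
        def backward(ctx, grad_y):
            # Straight-through: treat d op(x) / dx as 1 everywhere.
            return grad_y

    return _SteFn


round_ste = make_ste_fn(torch.round).apply

x = torch.tensor([0.2, 1.7, -2.4], requires_grad=True)
y = round_ste(x)
y.sum().backward()
print(y)       # rounded values
print(x.grad)  # all ones, whereas torch.round alone would give zeros
```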
- class brevitas.ops.autograd_ste_ops.AbsBinarySignGradFn(*args, **kwargs)
Bases: torch.autograd.Function
Autograd function that implements torch.abs() with a binary-sign backward, so that the subgradient at 0 is 1. Compare with torch.abs(), whose subgradient at 0 is 0.
AbsBinarySignGradFn.apply(*args) is first aliased to abs_binary_sign_grad_impl(*args) and then wrapped by abs_binary_sign_grad() when the environment variable BREVITAS_JIT is set to 0.
See abs_binary_sign_grad() for details on the interface and examples.
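A minimal sketch of this behaviour, assuming the binary sign maps `x >= 0` to +1 and `x < 0` to -1 (which is what yields a subgradient of 1 at 0). `AbsWithBinarySignGrad` is a hypothetical name, not the Brevitas implementation.

```python
import torch
from torch.autograd import Function


class AbsWithBinarySignGrad(Function):
    """Hypothetical sketch: |x| forward, binary-sign backward."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.abs(x)

    @staticmethod
    def backward(ctx, grad_y):
        (x,) = ctx.saved_tensors
        # Assumed binary sign: +1 for x >= 0, -1 for x < 0, so the gradient at 0 is 1.
        binary_sign = torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x))
        return grad_y * binary_sign


x = torch.tensor([-2.0, 0.0, 3.0], requires_grad=True)
AbsWithBinarySignGrad.apply(x).sum().backward()
print(x.grad)  # gradient at 0 is 1

x_ref = torch.tensor([-2.0, 0.0, 3.0], requires_grad=True)
torch.abs(x_ref).sum().backward()
print(x_ref.grad)  # torch.abs gives 0 at 0
```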
- class brevitas.ops.autograd_ste_ops.BinarySignSteFn(*args, **kwargs)
Bases: torch.autograd.Function
Autograd function that implements binary_sign() with a straight-through gradient estimator.
BinarySignSteFn.apply(*args) is first aliased to binary_sign_ste_impl(*args) and then wrapped by binary_sign_ste() when the environment variable BREVITAS_JIT is set to 0.
See binary_sign_ste() for details on the interface and examples.
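A sketch assuming binary_sign() returns +1 for non-negative inputs and -1 for negative inputs, consistent with the subgradient of 1 at 0 described for AbsBinarySignGradFn. `BinarySignSTE` is a hypothetical class, not the library's code.

```python
import torch
from torch.autograd import Function


class BinarySignSTE(Function):
    """Hypothetical sketch: binary sign forward, identity backward."""

    @staticmethod
    def forward(ctx, x):
        # Assumed semantics: +1 where x >= 0, -1 where x < 0 (no zero output, unlike torch.sign).
        return torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x))

    @staticmethod
    def backward(ctx, grad_y):
        # Straight-through estimator: pass the gradient through unchanged.
        return grad_y


x = torch.tensor([-0.5, 0.0, 2.0], requires_grad=True)
y = BinarySignSTE.apply(x)
y.sum().backward()
print(y)       # 0 is mapped to +1
print(x.grad)  # all ones
```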
- class brevitas.ops.autograd_ste_ops.CeilSteFn(*args, **kwargs)
Bases: torch.autograd.Function
Autograd function that implements torch.ceil() with a straight-through gradient estimator.
CeilSteFn.apply(*args) is first aliased to ceil_ste_impl(*args) and then wrapped by ceil_ste() when the environment variable BREVITAS_JIT is set to 0.
See ceil_ste() for details on the interface and examples.
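A hypothetical `CeilSTE` sketch showing that the straight-through backward forwards whatever upstream gradient it receives (here a factor of 3.0), not just ones. It illustrates the technique and is not the Brevitas source.

```python
import torch
from torch.autograd import Function


class CeilSTE(Function):
    """Hypothetical sketch: torch.ceil forward, identity backward."""

    @staticmethod
    def forward(ctx, x):
        return torch.ceil(x)

    @staticmethod
    def backward(ctx, grad_y):
        # Identity: the upstream gradient is returned as-is.
        return grad_y


x = torch.tensor([-1.5, 0.2, 2.0], requires_grad=True)
y = CeilSTE.apply(x)
(3.0 * y).sum().backward()
print(y)       # ceil of each element
print(x.grad)  # upstream gradient (3.0) passed straight through
```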
- class brevitas.ops.autograd_ste_ops.DPURoundSteFn(*args, **kwargs)
Bases: torch.autograd.Function
Autograd function that implements dpu_round() with a straight-through gradient estimator.
DPURoundSteFn.apply(*args) is first aliased to dpu_round_ste_impl(*args) and then wrapped by dpu_round_ste() when the environment variable BREVITAS_JIT is set to 0.
See dpu_round_ste() for details on the interface and examples.
- class brevitas.ops.autograd_ste_ops.FloorSteFn(*args, **kwargs)
Bases: torch.autograd.Function
Autograd function that implements torch.floor() with a straight-through gradient estimator.
FloorSteFn.apply(*args) is first aliased to floor_ste_impl(*args) and then wrapped by floor_ste() when the environment variable BREVITAS_JIT is set to 0.
See floor_ste() for details on the interface and examples.
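To see why the estimator is needed, this sketch compares a hypothetical `FloorSTE` class with plain torch.floor(), whose own gradient is zero everywhere. The class name is illustrative only.

```python
import torch
from torch.autograd import Function


class FloorSTE(Function):
    """Hypothetical sketch: torch.floor forward, identity backward."""

    @staticmethod
    def forward(ctx, x):
        return torch.floor(x)

    @staticmethod
    def backward(ctx, grad_y):
        return grad_y


x_ste = torch.tensor([-1.5, 0.7, 2.0], requires_grad=True)
FloorSTE.apply(x_ste).sum().backward()

x_plain = torch.tensor([-1.5, 0.7, 2.0], requires_grad=True)
torch.floor(x_plain).sum().backward()

print(x_ste.grad)    # ones: gradient flows through the STE
print(x_plain.grad)  # zeros: torch.floor's own gradient
```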
- class brevitas.ops.autograd_ste_ops.InplaceTensorClampSteFn(*args, **kwargs)
Bases: torch.autograd.Function
Autograd function that implements tensor_clamp_() with a straight-through gradient estimator for the gradient of the output y w.r.t. the input x, while the gradient of y w.r.t. min_val and max_val is always None.
InplaceTensorClampSteFn.apply(*args) is first aliased to tensor_clamp_ste_impl_(*args) and then wrapped by tensor_clamp_() when the environment variable BREVITAS_JIT is set to 0.
See tensor_clamp_() for details on the interface and examples.
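A sketch of an in-place clamp with tensor bounds and an STE backward, assuming tensor_clamp_() clamps x element-wise between the tensors min_val and max_val. `InplaceClampSTE` is hypothetical. The in-place write is reported to autograd with `ctx.mark_dirty()`, and because in-place updates of a leaf that requires grad are not allowed, the usage clones the leaf first.

```python
import torch
from torch.autograd import Function


class InplaceClampSTE(Function):
    """Hypothetical sketch: in-place element-wise clamp, STE for x only."""

    @staticmethod
    def forward(ctx, x, min_val, max_val):
        # Tell autograd that x is modified in place and returned.
        ctx.mark_dirty(x)
        return x.copy_(torch.maximum(torch.minimum(x, max_val), min_val))

    @staticmethod
    def backward(ctx, grad_y):
        # Identity gradient for x; bounds get no gradient.
        return grad_y, None, None


leaf = torch.tensor([-3.0, 0.5, 3.0], requires_grad=True)
x = leaf.clone()  # non-leaf copy, so the in-place update is permitted
y = InplaceClampSTE.apply(x, torch.tensor(-1.0), torch.tensor(1.0))
y.sum().backward()
print(y)          # clamped to [-1, 1]
print(leaf.grad)  # ones, including the clamped elements
```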
- class brevitas.ops.autograd_ste_ops.RoundSteFn(*args, **kwargs)
Bases: torch.autograd.Function
Autograd function that implements torch.round() with a straight-through gradient estimator.
RoundSteFn.apply(*args) is first aliased to round_ste_impl(*args) and then wrapped by round_ste() when the environment variable BREVITAS_JIT is set to 0.
See round_ste() for details on the interface and examples.
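A typical use of a rounding STE is fake quantization, `q = round(x / scale) * scale`, where the estimator keeps x trainable. The sketch below assumes that use case; `RoundSTE` and `fake_quant` are hypothetical names, not Brevitas APIs.

```python
import torch
from torch.autograd import Function


class RoundSTE(Function):
    """Hypothetical sketch: torch.round forward, identity backward."""

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_y):
        return grad_y


def fake_quant(x, scale):
    # Hypothetical helper: snap x onto a grid with step `scale`.
    return RoundSTE.apply(x / scale) * scale


x = torch.tensor([0.26, -0.74, 1.1], requires_grad=True)
q = fake_quant(x, 0.5)
q.sum().backward()
print(q)       # values snapped to multiples of 0.5
print(x.grad)  # ones: (1 / scale) * scale under the STE
```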
- class brevitas.ops.autograd_ste_ops.RoundToZeroSteFn(*args, **kwargs)
Bases: torch.autograd.Function
Autograd function that implements round_to_zero() with a straight-through gradient estimator.
RoundToZeroSteFn.apply(*args) is first aliased to round_to_zero_ste_impl(*args) and then wrapped by round_to_zero_ste() when the environment variable BREVITAS_JIT is set to 0.
See round_to_zero_ste() for details on the interface and examples.
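Assuming round_to_zero() truncates toward zero (the same result as torch.trunc()), a sketch with a hypothetical `RoundToZeroSTE` class:

```python
import torch
from torch.autograd import Function


class RoundToZeroSTE(Function):
    """Hypothetical sketch: truncation toward zero forward, identity backward."""

    @staticmethod
    def forward(ctx, x):
        # sign(x) * floor(|x|) drops the fractional part, moving toward zero.
        return torch.sign(x) * torch.floor(torch.abs(x))

    @staticmethod
    def backward(ctx, grad_y):
        return grad_y


x = torch.tensor([-2.7, -0.3, 1.9], requires_grad=True)
y = RoundToZeroSTE.apply(x)
y.sum().backward()
print(y)       # matches torch.trunc(x)
print(x.grad)  # all ones
```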
- class brevitas.ops.autograd_ste_ops.ScalarClampMinSteFn(*args, **kwargs)
Bases: torch.autograd.Function
Autograd function that implements torch.clamp_min with a straight-through gradient estimator for the gradient of the output y w.r.t. the input x, while the gradient of y w.r.t. min_val is always None.
ScalarClampMinSteFn.apply(*args) is first aliased to scalar_clamp_min_ste_impl(*args) and then wrapped by scalar_clamp_min_ste() and invoked when the environment variable BREVITAS_JIT is set to 0.
See scalar_clamp_min_ste() for details on the interface and examples.
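A sketch with a hypothetical `ClampMinSTE` class, assuming min_val is a Python float. Note that the clamped element still receives a gradient of 1, which is the point of the estimator.

```python
import torch
from torch.autograd import Function


class ClampMinSTE(Function):
    """Hypothetical sketch: clamp-min forward, STE backward for x only."""

    @staticmethod
    def forward(ctx, x, min_val: float):
        # Same result as torch.clamp_min(x, min_val).
        return torch.clamp(x, min=min_val)

    @staticmethod
    def backward(ctx, grad_y):
        # Identity gradient for x; None for the non-tensor min_val.
        return grad_y, None


x = torch.tensor([-3.0, 0.5, 2.0], requires_grad=True)
y = ClampMinSTE.apply(x, 0.0)
y.sum().backward()
print(y)       # negatives clamped to 0
print(x.grad)  # ones, including the clamped element
```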
- class brevitas.ops.autograd_ste_ops.ScalarClampSteFn(*args, **kwargs)
Bases: torch.autograd.Function
Autograd function that implements torch.clamp with a straight-through gradient estimator for the gradient of the output y w.r.t. the input x, while the gradients of y w.r.t. min_val and max_val are always None.
ScalarClampSteFn.apply(*args) is first aliased to scalar_clamp_ste_impl(*args) and then wrapped by scalar_clamp_ste() and invoked when the environment variable BREVITAS_JIT is set to 0.
See scalar_clamp_ste() for details on the interface and examples.
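The difference from plain torch.clamp shows up outside the clamping range: torch.clamp's gradient is 0 there, while the STE passes 1 through. A sketch with a hypothetical `ClampSTE` class and float bounds:

```python
import torch
from torch.autograd import Function


class ClampSTE(Function):
    """Hypothetical sketch: torch.clamp forward, STE backward for x only."""

    @staticmethod
    def forward(ctx, x, min_val: float, max_val: float):
        return torch.clamp(x, min_val, max_val)

    @staticmethod
    def backward(ctx, grad_y):
        # Identity for x; no gradient for the scalar bounds.
        return grad_y, None, None


x = torch.tensor([-3.0, 0.5, 3.0], requires_grad=True)
ClampSTE.apply(x, -1.0, 1.0).sum().backward()

x_ref = torch.tensor([-3.0, 0.5, 3.0], requires_grad=True)
torch.clamp(x_ref, -1.0, 1.0).sum().backward()

print(x.grad)      # ones everywhere
print(x_ref.grad)  # zero where the input was clamped
```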
- class brevitas.ops.autograd_ste_ops.TensorClampSteFn(*args, **kwargs)
Bases: torch.autograd.Function
Autograd function that implements tensor_clamp() with a straight-through gradient estimator for the gradient of the output y w.r.t. the input x, while the gradient of y w.r.t. min_val and max_val is always None.
TensorClampSteFn.apply(*args) is first aliased to tensor_clamp_ste_impl(*args) and then wrapped by tensor_clamp() when the environment variable BREVITAS_JIT is set to 0.
See tensor_clamp() for details on the interface and examples.
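Assuming tensor_clamp() clamps element-wise against tensor bounds that broadcast with x (for example per-row limits), a sketch with a hypothetical `TensorClampSTE` class:

```python
import torch
from torch.autograd import Function


class TensorClampSTE(Function):
    """Hypothetical sketch: element-wise clamp with tensor bounds, STE for x."""

    @staticmethod
    def forward(ctx, x, min_val, max_val):
        # Bounds broadcast against x; output has the shape of x here.
        return torch.maximum(torch.minimum(x, max_val), min_val)

    @staticmethod
    def backward(ctx, grad_y):
        # Identity for x; no gradient for the bounds.
        return grad_y, None, None


x = torch.tensor([[-2.0, 0.5], [3.0, -0.1]], requires_grad=True)
min_val = torch.tensor([[-1.0], [0.0]])  # per-row lower bounds
max_val = torch.tensor([[1.0], [2.0]])   # per-row upper bounds
y = TensorClampSTE.apply(x, min_val, max_val)
y.sum().backward()
print(y)       # each row clamped to its own range
print(x.grad)  # all ones
```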
- class brevitas.ops.autograd_ste_ops.TernarySignSteFn(*args, **kwargs)
Bases: torch.autograd.Function
Autograd function that implements torch.sign() with a straight-through gradient estimator.
TernarySignSteFn.apply(*args) is first aliased to ternary_sign_ste_impl(*args) and then wrapped by ternary_sign_ste() when the environment variable BREVITAS_JIT is set to 0.
See ternary_sign_ste() for details on the interface and examples.
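torch.sign() returns one of three values (-1, 0, +1), hence "ternary"; unlike the binary sign, 0 maps to 0. A sketch with a hypothetical `TernarySignSTE` class:

```python
import torch
from torch.autograd import Function


class TernarySignSTE(Function):
    """Hypothetical sketch: torch.sign forward, identity backward."""

    @staticmethod
    def forward(ctx, x):
        return torch.sign(x)

    @staticmethod
    def backward(ctx, grad_y):
        # torch.sign's own gradient is zero; the STE passes grad_y through.
        return grad_y


x = torch.tensor([-0.4, 0.0, 1.3], requires_grad=True)
y = TernarySignSTE.apply(x)
y.sum().backward()
print(y)       # -1, 0, +1
print(x.grad)  # all ones
```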
- brevitas.ops.autograd_ste_ops.abs_binary_sign_grad_impl(*args, **kwargs)
Alias for AbsBinarySignGradFn.apply(*args).
- brevitas.ops.autograd_ste_ops.binary_sign_ste_impl(*args, **kwargs)
Alias for BinarySignSteFn.apply(*args).
- brevitas.ops.autograd_ste_ops.ceil_ste_impl(*args, **kwargs)
Alias for CeilSteFn.apply(*args).
- brevitas.ops.autograd_ste_ops.dpu_round_ste_impl(*args, **kwargs)
Alias for DPURoundSteFn.apply(*args).
- brevitas.ops.autograd_ste_ops.floor_ste_impl(*args, **kwargs)
Alias for FloorSteFn.apply(*args).
- brevitas.ops.autograd_ste_ops.round_ste_impl(*args, **kwargs)
Alias for RoundSteFn.apply(*args).
- brevitas.ops.autograd_ste_ops.round_to_zero_ste_impl(*args, **kwargs)
Alias for RoundToZeroSteFn.apply(*args).
- brevitas.ops.autograd_ste_ops.scalar_clamp_min_ste_impl(*args, **kwargs)
Alias for ScalarClampMinSteFn.apply(*args).
- brevitas.ops.autograd_ste_ops.scalar_clamp_ste_impl(*args, **kwargs)
Alias for ScalarClampSteFn.apply(*args).
- brevitas.ops.autograd_ste_ops.tensor_clamp_ste_impl(*args, **kwargs)
Alias for TensorClampSteFn.apply(*args).
- brevitas.ops.autograd_ste_ops.ternary_sign_ste_impl(*args, **kwargs)
Alias for TernarySignSteFn.apply(*args).