Source code for brevitas.core.function_wrapper.ops_ste
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
"""
ScriptModule wrappers of various functions defined in :obj:`~brevitas.function.ops_ste`.
"""
import torch
import brevitas
from brevitas.function.ops_ste import *
class RoundSte(brevitas.jit.ScriptModule):
    """
    ScriptModule wrapper for :func:`~brevitas.function.ops_ste.round_ste`.

    The module holds no state: ``forward`` hands its single tensor argument to
    ``round_ste`` and returns whatever that function produces.
    """

    def __init__(self) -> None:
        super().__init__()

    @brevitas.jit.script_method
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rounded = round_ste(x)
        return rounded
class FloorSte(brevitas.jit.ScriptModule):
    """
    ScriptModule wrapper for :func:`~brevitas.function.ops_ste.floor_ste`.

    The module holds no state: ``forward`` hands its single tensor argument to
    ``floor_ste`` and returns whatever that function produces.
    """

    def __init__(self) -> None:
        super().__init__()

    @brevitas.jit.script_method
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        floored = floor_ste(x)
        return floored
class RoundToZeroSte(brevitas.jit.ScriptModule):
    """
    ScriptModule wrapper for :func:`~brevitas.function.ops_ste.round_to_zero_ste`.

    The module holds no state: ``forward`` hands its single tensor argument to
    ``round_to_zero_ste`` and returns whatever that function produces.
    """

    def __init__(self) -> None:
        super().__init__()

    @brevitas.jit.script_method
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        truncated = round_to_zero_ste(x)
        return truncated
class DPURoundSte(brevitas.jit.ScriptModule):
    """
    ScriptModule wrapper for :func:`~brevitas.function.ops_ste.dpu_round_ste`.

    The module holds no state: ``forward`` hands its single tensor argument to
    ``dpu_round_ste`` and returns whatever that function produces.
    """

    def __init__(self) -> None:
        super().__init__()

    @brevitas.jit.script_method
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rounded = dpu_round_ste(x)
        return rounded
class CeilSte(brevitas.jit.ScriptModule):
    """
    ScriptModule wrapper for :func:`~brevitas.function.ops_ste.ceil_ste`.

    The module holds no state: ``forward`` hands its single tensor argument to
    ``ceil_ste`` and returns whatever that function produces.
    """

    def __init__(self) -> None:
        super().__init__()

    @brevitas.jit.script_method
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ceiled = ceil_ste(x)
        return ceiled
class ScalarClampMinSte(brevitas.jit.ScriptModule):
    """
    ScriptModule wrapper for :func:`~brevitas.function.ops_ste.scalar_clamp_min_ste`.

    Args:
        min_val (float): lower bound forwarded to ``scalar_clamp_min_ste`` on every
            call. It is listed in ``__constants__`` so TorchScript treats it as a
            constant rather than a mutable attribute.
    """
    __constants__ = ['min_val']

    def __init__(self, min_val: float) -> None:
        super().__init__()
        self.min_val = min_val

    @brevitas.jit.script_method
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        lower_bound = self.min_val
        return scalar_clamp_min_ste(x, lower_bound)
class ScalarSignedClampMinSte(brevitas.jit.ScriptModule):
    """
    Sign-preserving variant of :func:`~brevitas.function.ops_ste.scalar_clamp_min_ste`.

    The magnitude of the input (obtained via ``abs_binary_sign_grad``) is clamped
    from below by ``abs(min_val)`` with ``scalar_clamp_min_ste``. The sign of the
    input is then reapplied. Both ``+0.0`` and ``-0.0`` are treated as positive.

    Args:
        min_val (float): non-zero bound on the magnitude. Only its absolute value
            is used, so negative values are accepted.

    Raises:
        AssertionError: if ``abs(min_val)`` is not strictly positive, i.e. zero
            or NaN.
    """
    __constants__ = ['min_val']

    def __init__(self, min_val: float) -> None:
        super(ScalarSignedClampMinSte, self).__init__()
        # Reject min_val == 0.0. With a zero bound this module would implement the
        # identity, but the gradient returned at x = 0.0 would be zero instead of 1.
        # An explicit raise is used instead of ``assert`` so the check still runs
        # under ``python -O``. AssertionError is kept so existing callers that
        # expect it keep working. ``not (... > 0.0)`` also rejects NaN.
        if not abs(min_val) > 0.0:
            raise AssertionError("min_val has to be non-zero.")
        self.min_val = abs(min_val)

    @brevitas.jit.script_method
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE: The previous implementation of this operation was
        # torch.copysign(scalar_clamp_min_ste(abs_binary_sign_grad(x), self.min_val), x).
        # That form is more readable, but it produced a -1. gradient at x = -0.0,
        # because torch.copysign distinguishes positive from negative zero.
        # Using ``x >= 0`` maps both zeros to +1. ``type_as`` matches x's dtype.
        sign = torch.where(x >= 0, 1., -1.).type_as(x)
        return sign * scalar_clamp_min_ste(abs_binary_sign_grad(x), self.min_val)
class TensorClampSte(brevitas.jit.ScriptModule):
    """
    ScriptModule wrapper for :func:`~brevitas.function.ops_ste.tensor_clamp_ste`.

    The module stores no bounds. ``min_val`` and ``max_val`` are passed as tensors
    on every call and forwarded, together with ``x``, to ``tensor_clamp_ste``.
    """

    def __init__(self) -> None:
        super().__init__()

    @brevitas.jit.script_method
    def forward(
            self,
            x: torch.Tensor,
            min_val: torch.Tensor,
            max_val: torch.Tensor) -> torch.Tensor:
        clamped = tensor_clamp_ste(x, min_val, max_val)
        return clamped
class InplaceTensorClampSte(brevitas.jit.ScriptModule):
    """
    ScriptModule wrapper for :func:`~brevitas.function.ops_ste.tensor_clamp_ste_`.

    The module stores no bounds. ``min_val`` and ``max_val`` are passed as tensors
    on every call and forwarded, together with ``x``, to ``tensor_clamp_ste_``.
    The trailing underscore follows PyTorch's in-place naming convention, so ``x``
    is presumably modified in place. TODO confirm in ops_ste.
    """

    def __init__(self) -> None:
        super().__init__()

    @brevitas.jit.script_method
    def forward(
            self,
            x: torch.Tensor,
            min_val: torch.Tensor,
            max_val: torch.Tensor) -> torch.Tensor:
        clamped = tensor_clamp_ste_(x, min_val, max_val)
        return clamped