# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
import math
from typing import Optional
from typing import Tuple

import torch
from torch import Tensor
from torch.nn import Module

import brevitas
from brevitas.core.function_wrapper import Abs
from brevitas.core.function_wrapper import Identity
from brevitas.core.function_wrapper import InplaceLogTwo
from brevitas.core.function_wrapper import LogTwo
from brevitas.core.function_wrapper import PowerOfTwo
from brevitas.core.function_wrapper import RoundSte
from brevitas.core.function_wrapper import ScalarSignedClampMinSte
from brevitas.inject.enum import FloatToIntImplType  # retrocompatibility
from brevitas.inject.enum import RestrictValueType  # retrocompatibility

assert RestrictValueType  # prevent removal of unused import
assert FloatToIntImplType  # prevent removal of unused import


class _RestrictClampValue(brevitas.jit.ScriptModule):

    def __init__(
            self,
            scaling_min_val: Optional[float] = None,
            restrict_value_impl: Optional[Module] = None,
            is_unsigned: bool = True):
        super(_RestrictClampValue, self).__init__()
        if scaling_min_val is not None and scaling_min_val != 0:
            self.clamp_min_ste = ScalarSignedClampMinSte(scaling_min_val)
        else:
            self.clamp_min_ste = Identity()
        if restrict_value_impl is not None:
            self.restrict_value_impl = restrict_value_impl
        else:
            self.restrict_value_impl = Identity()
        if is_unsigned:
            self.apply_abs = Abs()
        else:
            self.apply_abs = Identity()

    @brevitas.jit.script_method
    def forward(self, x: Tensor):
        x = self.restrict_value_impl(x)
        x = self.clamp_min_ste(x)
        x = self.apply_abs(x)
        return x
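
# Usage sketch (illustrative, not part of the library API): _RestrictClampValue
# chains restriction, minimum clamping, and an optional abs, in that order.
# Assuming the public PowerOfTwoRestrictValue defined below, the input is
# expected in the log2 domain:
#
#   import torch
#   restrict_clamp = _RestrictClampValue(
#       scaling_min_val=1e-4,
#       restrict_value_impl=PowerOfTwoRestrictValue(),
#       is_unsigned=True)
#   restrict_clamp(torch.tensor([-3.2]))  # 2 ** round(-3.2) = 0.125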


class _RestrictValue(brevitas.jit.ScriptModule):

    def __init__(self, restrict_value_impl: Optional[Module]):
        super(_RestrictValue, self).__init__()
        if restrict_value_impl is not None:
            self.restrict_value_impl = restrict_value_impl
        else:
            self.restrict_value_impl = Identity()

    @brevitas.jit.script_method
    def forward(self, x: Tensor):
        x = self.restrict_value_impl(x)
        return x


class _ClampValue(brevitas.jit.ScriptModule):

    def __init__(self, scaling_min_val: Optional[float]):
        super(_ClampValue, self).__init__()
        if scaling_min_val is not None and scaling_min_val != 0:
            self.clamp_min_ste = ScalarSignedClampMinSte(scaling_min_val)
        else:
            self.clamp_min_ste = Identity()
        self.min_val = scaling_min_val

    @brevitas.jit.script_method
    def forward(self, x: Tensor):
        x = self.clamp_min_ste(x)
        return x


class _AbsValue(brevitas.jit.ScriptModule):

    def __init__(self, is_unsigned: bool = True):
        super(_AbsValue, self).__init__()
        if is_unsigned:
            self.apply_abs = Abs()
        else:
            self.apply_abs = Identity()

    @brevitas.jit.script_method
    def forward(self, x: Tensor):
        x = self.apply_abs(x)
        return x


class FloatRestrictValue(brevitas.jit.ScriptModule):

    def __init__(self) -> None:
        super(FloatRestrictValue, self).__init__()

    def restrict_init_float(self, x: float) -> float:
        return x

    def restrict_init_tensor(self, x: Tensor) -> Tensor:
        return x

    def restrict_init_module(self):
        return Identity()

    def restrict_init_inplace_module(self):
        return Identity()

    @brevitas.jit.script_method
    def forward(self, x: Tensor) -> Tensor:
        return x
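
# Usage sketch (illustrative): FloatRestrictValue leaves scale factors
# unrestricted, so every hook and the forward pass are the identity.
#
#   import torch
#   restrict = FloatRestrictValue()
#   restrict.restrict_init_float(0.3)   # 0.3, unchanged
#   restrict(torch.tensor([0.3]))       # tensor([0.3000]), unchanged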


class LogFloatRestrictValue(brevitas.jit.ScriptModule):

    def __init__(self):
        super(LogFloatRestrictValue, self).__init__()
        self.power_of_two: Module = PowerOfTwo()

    def restrict_init_float(self, x: float):
        return math.log2(x)

    def restrict_init_tensor(self, x: Tensor):
        return torch.log2(x)

    def restrict_init_module(self):
        return LogTwo()

    def restrict_init_inplace_module(self):
        return InplaceLogTwo()

    @brevitas.jit.script_method
    def forward(self, x: Tensor):
        x = self.power_of_two(x)
        return x
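
# Usage sketch (illustrative): with LogFloatRestrictValue the learned parameter
# lives in the log2 domain and forward exponentiates it back, which keeps the
# effective scale strictly positive:
#
#   import torch
#   restrict = LogFloatRestrictValue()
#   log_scale = restrict.restrict_init_tensor(torch.tensor([0.3]))  # log2(0.3)
#   restrict(log_scale)  # 2 ** log2(0.3) ~= 0.3, no rounding involved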


class IntRestrictValue(brevitas.jit.ScriptModule):

    def __init__(self, restrict_value_float_to_int_impl: Module = RoundSte()):
        super(IntRestrictValue, self).__init__()
        self.float_to_int_impl = restrict_value_float_to_int_impl

    def restrict_init_float(self, x: float):
        return x

    def restrict_init_tensor(self, x: Tensor):
        return x

    def restrict_init_module(self):
        return Identity()

    def restrict_init_inplace_module(self):
        return Identity()

    @brevitas.jit.script_method
    def forward(self, x: Tensor):
        x = self.float_to_int_impl(x)
        return x
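
# Usage sketch (illustrative): IntRestrictValue restricts the scale to integer
# values through a straight-through-estimator round, so gradients still reach
# the underlying floating-point parameter:
#
#   import torch
#   restrict = IntRestrictValue()   # defaults to RoundSte()
#   restrict(torch.tensor([2.7]))   # tensor([3.])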


class PowerOfTwoRestrictValue(brevitas.jit.ScriptModule):

    def __init__(self, restrict_value_float_to_int_impl: Module = RoundSte()):
        super(PowerOfTwoRestrictValue, self).__init__()
        self.float_to_int_impl = restrict_value_float_to_int_impl
        self.power_of_two: Module = PowerOfTwo()

    def restrict_init_float(self, x: float):
        return math.log2(x)

    def restrict_init_tensor(self, x: Tensor):
        return torch.log2(x)

    def restrict_init_module(self):
        return LogTwo()

    def restrict_init_inplace_module(self):
        return InplaceLogTwo()

    @brevitas.jit.script_method
    def forward(self, x: Tensor):
        x = self.float_to_int_impl(x)
        x = self.power_of_two(x)
        return x
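
# Usage sketch (illustrative): PowerOfTwoRestrictValue combines log2
# initialization with rounding and exponentiation in forward, so the resulting
# scale is always an exact power of two:
#
#   import torch
#   restrict = PowerOfTwoRestrictValue()
#   log_scale = restrict.restrict_init_tensor(torch.tensor([0.3]))  # log2(0.3)
#   restrict(log_scale)  # 2 ** round(log2(0.3)) = 0.25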


class QuantRestrictValue(brevitas.jit.ScriptModule):

    def __init__(
            self,
            restrict_value_float_to_int_impl: Module,
            scaling_shape: Tuple[int, ...],
            scale_dequantized_shape: Optional[Tuple[int, ...]]):
        super(QuantRestrictValue, self).__init__()
        self.float_to_int_impl = restrict_value_float_to_int_impl
        self.scaling_shape = scaling_shape
        self.scale_dequantized_shape = scale_dequantized_shape

    # The initialization hooks are no-ops here: the wrapped quantizer performs
    # the value transformation itself, so each hook returns an Identity module.
    def restrict_init_float(self, x: float):
        return Identity()

    def restrict_init_tensor(self, x: torch.Tensor):
        return Identity()

    def restrict_init_module(self):
        return Identity()

    def restrict_init_inplace_module(self):
        return Identity()

    def retrocompatibility_op(self, x):
        return Identity()

    @brevitas.jit.script_method
    def forward(self, x: torch.Tensor):
        o, *_ = self.float_to_int_impl(x)
        # We need to go back to the dequantized shape, relevant for groupwise quantization
        if self.scale_dequantized_shape is not None:
            o = o.view(self.scale_dequantized_shape)
        return o
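
# Usage sketch (illustrative; _FakeQuant is a hypothetical stand-in for a
# quantizer module whose forward returns the quantized tensor first, followed
# by quantization metadata, which is what forward above unpacks). Assumes
# BREVITAS_JIT is disabled (the default), since starred unpacking is not
# scriptable:
#
#   import torch
#
#   class _FakeQuant(torch.nn.Module):
#       def forward(self, x):
#           return torch.round(x), None  # (quantized value, metadata)
#
#   restrict = QuantRestrictValue(
#       _FakeQuant(), scaling_shape=(4,), scale_dequantized_shape=None)
#   restrict(torch.randn(4))  # element-wise rounded scale, shape unchanged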