# Source code for brevitas.core.bit_width.const
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
from typing import Optional
import torch
from torch import Tensor
from torch.nn import Module
import brevitas
import brevitas.config as config
from brevitas.core.utils import StatelessBuffer
from brevitas.function.ops_ste import tensor_clamp_ste
class BitWidthConst(brevitas.jit.ScriptModule):
    """
    Constant bit-width exposed as a scalar float tensor that is not kept in the module state.

    Args:
        bit_width (int): the constant bit-width value to return.
        dtype (Optional[torch.dtype]): dtype of the stored tensor; None lets torch pick its
            default floating dtype. Default: None.
        device (Optional[torch.device]): device of the stored tensor. Default: None.

    Examples:
        >>> bit_width = BitWidthConst(8)
        >>> bit_width()
        tensor(8.)

    Note:
        The value is held in a StatelessBuffer, so it is not part of the Module's state and is
        not saved as part of a checkpoint.

    Note:
        Maps to bit_width_impl_type == BitWidthImplType.CONST == 'CONST' == 'const' in
        higher-level APIs.
    """

    def __init__(
            self,
            bit_width: int,
            dtype: Optional[torch.dtype] = None,
            device: Optional[torch.device] = None) -> None:
        super().__init__()
        assert isinstance(bit_width, int)
        # Convert to float first so the resulting tensor is floating point.
        const_value = torch.tensor(float(bit_width), dtype=dtype, device=device)
        self.bit_width = StatelessBuffer(const_value)

    @brevitas.jit.script_method
    def forward(self) -> Tensor:
        # The StatelessBuffer submodule produces the stored constant when called.
        const_bit_width = self.bit_width()
        return const_bit_width
class BitWidthStatefulConst(brevitas.jit.ScriptModule):
    """
    Constant bit-width exposed as a scalar float tensor that IS kept in the module state.

    Args:
        bit_width (int): the constant bit-width value to return.
        dtype (Optional[torch.dtype]): dtype of the registered buffer; None lets torch pick its
            default floating dtype. Default: None.
        device (Optional[torch.device]): device of the registered buffer. Default: None.

    Examples:
        >>> bit_width = BitWidthStatefulConst(8)
        >>> bit_width()
        tensor(8.)

    Note:
        This is the stateful counterpart of BitWidthConst: the value is registered as a regular
        buffer named "bit_width", so it is saved as part of a checkpoint. When loading, a missing
        "bit_width" entry is tolerated if brevitas.config.IGNORE_MISSING_KEYS is set.

    Note:
        Maps to bit_width_impl_type == BitWidthImplType.STATEFUL_CONST == 'STATEFUL_CONST' ==
        'stateful_const' in higher-level APIs.
    """

    def __init__(
            self,
            bit_width: int,
            dtype: Optional[torch.dtype] = None,
            device: Optional[torch.device] = None) -> None:
        super().__init__()
        assert isinstance(bit_width, int)
        # Convert to float first so the registered buffer is floating point.
        initial_value = torch.tensor(float(bit_width), dtype=dtype, device=device)
        self.register_buffer("bit_width", initial_value)

    @brevitas.jit.script_method
    def forward(self) -> Tensor:
        # The registered buffer itself is the output.
        stored_bit_width = self.bit_width
        return stored_bit_width

    def _load_from_state_dict(
            self,
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys: list,
            unexpected_keys,
            error_msgs):
        """Load state as usual, then optionally drop this module's buffer from `missing_keys`."""
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
        bit_width_key = f"{prefix}bit_width"
        tolerate_missing = config.IGNORE_MISSING_KEYS
        # Only silence the report when the global flag allows it; the buffer keeps its
        # constructor value in that case.
        if tolerate_missing and bit_width_key in missing_keys:
            missing_keys.remove(bit_width_key)
class MsbClampBitWidth(brevitas.jit.ScriptModule):
    """
    ScriptModule that subtracts a bit-width amount from an input bit-width and clamps the result.

    The output is ``abs(input_bit_width - bit_width_to_remove)`` clamped to
    ``[min_overall_bit_width, max_overall_bit_width]`` via ``tensor_clamp_ste``. Here
    ``bit_width_to_remove`` is whatever ``bit_width_to_remove_impl()`` returns.

    Args:
        bit_width_to_remove_impl (Module): module called with no arguments; its output is the
            amount subtracted from the input bit-width.
        min_overall_bit_width (int): lower clamp bound for the output bit-width.
        max_overall_bit_width (int): upper clamp bound for the output bit-width.
        dtype (Optional[torch.dtype]): dtype of the clamp-bound tensors, forwarded to
            BitWidthConst. Default: None (torch's default dtype, as before).
        device (Optional[torch.device]): device of the clamp-bound tensors, forwarded to
            BitWidthConst. Default: None (as before).

    Note:
        The clamp bounds are BitWidthConst instances, so they are not saved in a checkpoint.
        NOTE(review): the name suggests bits are removed from the MSB side; this is inferred from
        the class name only.
    """

    def __init__(
            self,
            bit_width_to_remove_impl: Module,
            min_overall_bit_width: int,
            max_overall_bit_width: int,
            dtype: Optional[torch.dtype] = None,
            device: Optional[torch.device] = None) -> None:
        super(MsbClampBitWidth, self).__init__()
        # Forward dtype/device so the bounds can be created alongside the rest of the module,
        # mirroring the constructors of BitWidthConst / BitWidthStatefulConst.
        self.min_overall_bit_width = BitWidthConst(
            min_overall_bit_width, dtype=dtype, device=device)
        self.max_overall_bit_width = BitWidthConst(
            max_overall_bit_width, dtype=dtype, device=device)
        self.bit_width_to_remove_impl = bit_width_to_remove_impl

    @brevitas.jit.script_method
    def forward(self, input_bit_width: Tensor) -> Tensor:
        """Return the reduced bit-width, clamped to the configured [min, max] range."""
        bit_width_to_remove = self.bit_width_to_remove_impl()
        # abs() keeps the difference non-negative before clamping.
        output_bit_width = torch.abs(input_bit_width - bit_width_to_remove)
        output_bit_width = tensor_clamp_ste(
            output_bit_width, self.min_overall_bit_width(), self.max_overall_bit_width())
        return output_bit_width