Source code for brevitas.core.scaling.runtime

# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from typing import List
from typing import Optional
from typing import Tuple

import torch
from torch.nn import Module
from torch.nn import Parameter

import brevitas
import brevitas.config as config
from brevitas.core.function_wrapper import Identity
from brevitas.core.restrict_val import _ClampValue
from brevitas.core.restrict_val import _RestrictClampValue
from brevitas.core.restrict_val import FloatRestrictValue
from brevitas.core.stats import _ParameterListStats
from brevitas.core.stats import _RuntimeStats
from brevitas.core.stats import DEFAULT_MOMENTUM
from brevitas.core.utils import ParameterWrapper
from brevitas.core.utils import StatelessBuffer


class StatsFromParameterScaling(brevitas.jit.ScriptModule):

    def __init__(
            self,
            scaling_stats_impl: Module,
            scaling_stats_input_view_shape_impl: Module,
            scaling_stats_input_concat_dim: int,
            tracked_parameter_list: List[torch.nn.Parameter],
            scaling_shape: Tuple[int, ...],
            force_parameter: bool = False,
            restrict_scaling_impl: Module = FloatRestrictValue(),
            restrict_threshold_impl: Optional[Module] = None,
            scaling_affine_rescaling_init: Optional[float] = None,
            scaling_affine_shifting_init: Optional[float] = None,
            scaling_min_val: Optional[float] = None,
            dtype: Optional[torch.dtype] = None,
            device: Optional[torch.device] = None) -> None:
        super(StatsFromParameterScaling, self).__init__()
        # Ensure backwards compatibility with a shared threshold/scaling restrict impl
        if restrict_threshold_impl is None:
            restrict_threshold_impl = restrict_scaling_impl
        self.parameter_list_stats = _ParameterListStats(
            scaling_stats_impl,
            scaling_shape,
            scaling_stats_input_view_shape_impl,
            scaling_stats_input_concat_dim,
            tracked_parameter_list,
            force_parameter)
        self.stats_scaling_impl = _StatsScaling(
            restrict_scaling_impl,
            restrict_threshold_impl,
            scaling_min_val,
            scaling_shape,
            scaling_affine_rescaling_init,
            scaling_affine_shifting_init,
            dtype,
            device)
    @brevitas.jit.script_method
    def forward(
            self,
            x: Optional[torch.Tensor],
            threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
        stats = self.parameter_list_stats(x)
        if threshold is None:
            threshold = torch.ones(1).type_as(stats)
        return self.stats_scaling_impl(stats, threshold)
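# A minimal usage sketch (illustrative, not part of the original module):
# computing a single per-tensor scale from a layer's weight statistics. It
# assumes AbsMax and OverTensorView are importable as in recent Brevitas
# releases; the scalar scaling_shape () means one scale for the whole tensor.
def _example_stats_from_parameter_scaling() -> torch.Tensor:
    from brevitas.core.function_wrapper import OverTensorView
    from brevitas.core.stats import AbsMax

    linear = torch.nn.Linear(16, 8)
    scaling = StatsFromParameterScaling(
        scaling_stats_impl=AbsMax(),
        scaling_stats_input_view_shape_impl=OverTensorView(),
        scaling_stats_input_concat_dim=0,
        tracked_parameter_list=[linear.weight],
        scaling_shape=())
    # Returns the abs-max of the weights, restricted/clamped as configured
    return scaling(linear.weight)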
class _StatsScaling(brevitas.jit.ScriptModule):

    def __init__(
            self,
            restrict_scaling_impl: Module,
            restrict_threshold_impl: Module,
            scaling_min_val: Optional[float],
            scaling_shape: Tuple[int, ...],
            scaling_affine_rescaling_init: Optional[float],
            scaling_affine_shifting_init: Optional[float],
            dtype: Optional[torch.dtype],
            device: Optional[torch.device]) -> None:
        super(_StatsScaling, self).__init__()
        _affine_rescaling = scaling_affine_rescaling_init is not None
        _affine_shift_scale = scaling_affine_shifting_init is not None
        if _affine_shift_scale and not _affine_rescaling:
            raise RuntimeError(
                "Enabling shifting of the scale requires enabling affine rescaling first.")
        if _affine_rescaling:
            self.affine_rescaling = _AffineRescaling(
                scaling_shape,
                scaling_affine_rescaling_init,
                scaling_affine_shifting_init,
                dtype,
                device)
        else:
            self.affine_rescaling = Identity()
        self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
        self.restrict_clamp_threshold = _RestrictClampValue(
            restrict_value_impl=restrict_threshold_impl)
        self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module()
        self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module()
        self.clamp_scaling = _ClampValue(scaling_min_val)

    @brevitas.jit.script_method
    def forward(
            self,
            stats: torch.Tensor,
            threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
        if threshold is None:
            threshold = torch.ones(1).type_as(stats)
        threshold = self.restrict_threshold_pre(threshold)
        threshold = self.restrict_clamp_threshold(threshold)
        # Clamping avoids a possible log(0) with log-domain restrict_val impls
        stats = self.clamp_scaling(stats)
        stats = self.restrict_scaling_pre(stats)
        stats = self.affine_rescaling(stats)
        stats = self.restrict_clamp_scaling(stats)
        stats = stats / threshold
        return stats
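# A worked sketch (illustrative, not part of the original module) of the
# private _StatsScaling pipeline above. With the default FloatRestrictValue
# both restriction steps are identities, so the forward pass reduces to
# clamp(stats) / threshold.
def _example_stats_scaling_pipeline() -> None:
    impl = _StatsScaling(
        restrict_scaling_impl=FloatRestrictValue(),
        restrict_threshold_impl=FloatRestrictValue(),
        scaling_min_val=1e-8,
        scaling_shape=(),
        scaling_affine_rescaling_init=None,  # affine rescaling disabled
        scaling_affine_shifting_init=None,
        dtype=None,
        device=None)
    scale = impl(torch.tensor(0.5), torch.tensor(2.0))
    assert torch.allclose(scale, torch.tensor(0.25))  # 0.5 / 2.0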
class RuntimeStatsScaling(brevitas.jit.ScriptModule):

    def __init__(
            self,
            scaling_stats_impl: Module,
            scaling_stats_input_view_shape_impl: Module,
            scaling_shape: Tuple[int, ...],
            scaling_affine_rescaling_init: Optional[float] = None,
            scaling_affine_shifting_init: Optional[float] = None,
            restrict_scaling_impl: Module = FloatRestrictValue(),
            restrict_threshold_impl: Optional[Module] = None,
            scaling_stats_momentum: float = DEFAULT_MOMENTUM,
            scaling_min_val: Optional[float] = None,
            dtype: Optional[torch.dtype] = None,
            device: Optional[torch.device] = None) -> None:
        super(RuntimeStatsScaling, self).__init__()
        # Ensure backwards compatibility with a shared threshold/scaling restrict impl
        if restrict_threshold_impl is None:
            restrict_threshold_impl = restrict_scaling_impl
        self.runtime_stats = _RuntimeStats(
            scaling_stats_impl,
            scaling_shape,
            scaling_stats_input_view_shape_impl,
            scaling_stats_momentum,
            dtype,
            device)
        self.stats_scaling_impl = _StatsScaling(
            restrict_scaling_impl,
            restrict_threshold_impl,
            scaling_min_val,
            scaling_shape,
            scaling_affine_rescaling_init,
            scaling_affine_shifting_init,
            dtype,
            device)
    @brevitas.jit.script_method
    def forward(self, x: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
        stats = self.runtime_stats(x)
        return self.stats_scaling_impl(stats, threshold)
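# A minimal usage sketch (illustrative, not part of the original module):
# per-tensor activation scaling from running abs-max statistics. As above,
# AbsMax and OverTensorView are assumed importable as in recent releases.
def _example_runtime_stats_scaling() -> torch.Tensor:
    from brevitas.core.function_wrapper import OverTensorView
    from brevitas.core.stats import AbsMax

    scaling = RuntimeStatsScaling(
        scaling_stats_impl=AbsMax(),
        scaling_stats_input_view_shape_impl=OverTensorView(),
        scaling_shape=())
    x = torch.randn(4, 16)
    # In training mode this also updates an exponential moving average of the
    # batch statistics; in eval mode the tracked average is used instead.
    return scaling(x)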
class _AffineRescaling(brevitas.jit.ScriptModule):

    def __init__(
            self,
            scaling_shape,
            affine_weight_init: float,
            affine_bias_init: Optional[float],
            dtype: Optional[torch.dtype],
            device: Optional[torch.device]):
        super(_AffineRescaling, self).__init__()
        self.affine_weight = Parameter(
            torch.full(scaling_shape, affine_weight_init, dtype=dtype, device=device))
        if affine_bias_init is not None:
            self.affine_bias = ParameterWrapper(
                torch.full(scaling_shape, affine_bias_init, dtype=dtype, device=device))
        else:
            self.affine_bias = StatelessBuffer(torch.tensor(0., dtype=dtype, device=device))

    @brevitas.jit.script_method
    def forward(self, x):
        out = x * self.affine_weight + self.affine_bias()
        return out

    def _load_from_state_dict(
            self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
            error_msgs):
        super(_AffineRescaling, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
        affine_weight_key = prefix + 'affine_weight'
        affine_bias_key = prefix + 'affine_bias'
        if config.IGNORE_MISSING_KEYS and affine_weight_key in missing_keys:
            missing_keys.remove(affine_weight_key)
        if config.IGNORE_MISSING_KEYS and affine_bias_key in missing_keys:
            missing_keys.remove(affine_bias_key)
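# A tiny sketch (illustrative, not part of the original module) of the learned
# elementwise affine correction applied to the statistics,
# out = x * affine_weight + affine_bias. With weight 1.0 and bias 0.0 the
# module starts out as an identity.
def _example_affine_rescaling() -> None:
    affine = _AffineRescaling((), 1.0, 0.0, None, None)
    x = torch.tensor(0.5)
    assert torch.allclose(affine(x), x)  # identity at initialization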
class RuntimeDynamicGroupStatsScaling(brevitas.jit.ScriptModule):

    def __init__(
            self,
            group_size: int,
            group_dim: int,
            input_view_impl: Module,
            scaling_stats_impl: Module,
            scaling_min_val: Optional[float],
            restrict_scaling_impl: Module = FloatRestrictValue(),
            restrict_threshold_impl: Optional[Module] = None) -> None:
        super(RuntimeDynamicGroupStatsScaling, self).__init__()
        # Ensure backwards compatibility with a shared threshold/scaling restrict impl
        if restrict_threshold_impl is None:
            restrict_threshold_impl = restrict_scaling_impl
        self.group_size = group_size
        self.group_dim = group_dim
        self.scaling_stats_impl = scaling_stats_impl
        self.scaling_min_val = scaling_min_val
        self.input_view_impl = input_view_impl
        self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
        self.restrict_clamp_threshold = _RestrictClampValue(
            restrict_value_impl=restrict_threshold_impl)
        self.restrict_scaling_pre = self.restrict_clamp_scaling.restrict_value_impl.restrict_init_module()
        self.restrict_threshold_pre = self.restrict_clamp_threshold.restrict_value_impl.restrict_init_module()
        self.clamp_scaling = _ClampValue(scaling_min_val)
    @brevitas.jit.script_method
    def forward(
            self,
            stats_input: torch.Tensor,
            threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
        if threshold is None:
            threshold = torch.ones(1).type_as(stats_input)
        stats_input_reshaped = self.input_view_impl(stats_input)
        threshold = self.restrict_clamp_threshold(self.restrict_threshold_pre(threshold))
        out = self.scaling_stats_impl(stats_input_reshaped)
        # Clamping avoids a possible log(0) with log-domain restrict_val impls
        out = self.clamp_scaling(out)
        # Apply the restrict_val preprocessing step
        out = self.restrict_scaling_pre(out)
        # Apply restrict_val and clamping, then rescale by the threshold
        out = self.restrict_clamp_scaling(out) / threshold
        return out
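# A minimal usage sketch (illustrative, not part of the original module):
# dynamic group-wise scaling, where one scale per group of `group_size`
# elements is recomputed at every forward pass. `_GroupView` is a hypothetical
# stand-in for a Brevitas group view impl, and AbsMax is assumed to accept
# stats_reduce_dim as in recent releases.
class _GroupView(torch.nn.Module):

    def __init__(self, group_size: int):
        super().__init__()
        self.group_size = group_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Split the last dim into contiguous groups of group_size elements
        return x.view(*x.shape[:-1], x.shape[-1] // self.group_size, self.group_size)


def _example_dynamic_group_scaling() -> torch.Tensor:
    from brevitas.core.stats import AbsMax

    scaling = RuntimeDynamicGroupStatsScaling(
        group_size=4,
        group_dim=1,
        input_view_impl=_GroupView(4),
        scaling_stats_impl=AbsMax(stats_reduce_dim=-1),  # one abs-max per group
        scaling_min_val=1e-8)
    x = torch.randn(2, 16)
    return scaling(x)  # one scale per group of 4 elements along the last dim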