@@ -165,9 +165,9 @@ def __init__(self, *args, **kwargs):
165
165
# Amax scales should always be kept as float32.
166
166
self .always_float32_buffers = set ()
167
167
emulate = kwargs .pop ("emulate" , False )
168
- scaling_type_x = kwargs .pop ("scaling_type_x" , TensorScalingType .DELAYED )
169
- scaling_type_w = kwargs .pop ("scaling_type_w" , TensorScalingType .DELAYED )
170
- scaling_type_dL_dY = kwargs .pop ("scaling_type_dL_dY" , TensorScalingType .DELAYED )
168
+ scaling_type_x = kwargs .pop ("scaling_type_x" , TensorScalingType .DYNAMIC )
169
+ scaling_type_w = kwargs .pop ("scaling_type_w" , TensorScalingType .DYNAMIC )
170
+ scaling_type_dL_dY = kwargs .pop ("scaling_type_dL_dY" , TensorScalingType .DYNAMIC )
171
171
super ().__init__ (* args , ** kwargs )
172
172
173
173
# Defines the scaling behavior of x, w, dL_dY
@@ -402,8 +402,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
402
402
403
403
def scaling_repr (self ):
404
404
# add scaling settings without using too many characters
405
- # example: "x_del_w_del_dldy_dyn "
406
- return f"x_ { self .scaling_type_x .short_str ()} _w_ { self .scaling_type_w .short_str ()} _dldy_ { self .scaling_type_dL_dY .short_str ()} "
405
+ # example: "x:del,w:del,dldy:dyn "
406
+ return f"x: { self .scaling_type_x .short_str ()} ,w: { self .scaling_type_w .short_str ()} ,dldy: { self .scaling_type_dL_dY .short_str ()} "
407
407
408
408
def extra_repr (self ):
409
409
s = f'{ super ().extra_repr ()} , scaling="{ self .scaling_repr ()} "'
@@ -414,9 +414,9 @@ def from_float(
414
414
cls ,
415
415
mod ,
416
416
emulate : bool = False ,
417
- scaling_type_x = TensorScalingType .DELAYED ,
418
- scaling_type_w = TensorScalingType .DELAYED ,
419
- scaling_type_dL_dY = TensorScalingType .DELAYED ,
417
+ scaling_type_x = TensorScalingType .DYNAMIC ,
418
+ scaling_type_w = TensorScalingType .DYNAMIC ,
419
+ scaling_type_dL_dY = TensorScalingType .DYNAMIC ,
420
420
):
421
421
"""
422
422
Create an nn.Linear with fp8 compute from a regular nn.Linear
0 commit comments