
Commit 96825bb

fix nits from deletion of Float8DynamicLinear
Summary:

Addressing a couple of nits that slipped in #304
* more defaults to dynamic
* undo repr change
* fix comment

Test Plan:

```
./test/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
1 parent 8e9623a commit 96825bb
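
A minimal sketch of what the new defaults mean in practice (not part of this commit); it assumes `Float8Linear` and `TensorScalingType` are importable from `float8_experimental.float8_linear` as used in the diff below, and uses a plain `torch.nn.Linear` as a stand-in module:

```python
# Minimal sketch, assuming this import path matches the repository layout.
import torch
from float8_experimental.float8_linear import Float8Linear, TensorScalingType

m = torch.nn.Linear(16, 16)

# With this commit, omitting the scaling_type_* kwargs defaults all three
# of x, w, and dL_dY to dynamic scaling.
m_fp8_dynamic = Float8Linear.from_float(m)

# Delayed scaling remains available by passing it explicitly.
m_fp8_delayed = Float8Linear.from_float(
    torch.nn.Linear(16, 16),
    scaling_type_x=TensorScalingType.DELAYED,
    scaling_type_w=TensorScalingType.DELAYED,
    scaling_type_dL_dY=TensorScalingType.DELAYED,
)
```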

File tree

2 files changed: 9 additions & 10 deletions


float8_experimental/float8_linear.py

Lines changed: 8 additions & 8 deletions
@@ -165,9 +165,9 @@ def __init__(self, *args, **kwargs):
         # Amax scales should always be kept as float32.
         self.always_float32_buffers = set()
         emulate = kwargs.pop("emulate", False)
-        scaling_type_x = kwargs.pop("scaling_type_x", TensorScalingType.DELAYED)
-        scaling_type_w = kwargs.pop("scaling_type_w", TensorScalingType.DELAYED)
-        scaling_type_dL_dY = kwargs.pop("scaling_type_dL_dY", TensorScalingType.DELAYED)
+        scaling_type_x = kwargs.pop("scaling_type_x", TensorScalingType.DYNAMIC)
+        scaling_type_w = kwargs.pop("scaling_type_w", TensorScalingType.DYNAMIC)
+        scaling_type_dL_dY = kwargs.pop("scaling_type_dL_dY", TensorScalingType.DYNAMIC)
         super().__init__(*args, **kwargs)
 
         # Defines the scaling behavior of x, w, dL_dY
@@ -402,8 +402,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 
     def scaling_repr(self):
         # add scaling settings without using too many characters
-        # example: "x_del_w_del_dldy_dyn"
-        return f"x_{self.scaling_type_x.short_str()}_w_{self.scaling_type_w.short_str()}_dldy_{self.scaling_type_dL_dY.short_str()}"
+        # example: "x:del,w:del,dldy:dyn"
+        return f"x:{self.scaling_type_x.short_str()},w:{self.scaling_type_w.short_str()},dldy:{self.scaling_type_dL_dY.short_str()}"
 
     def extra_repr(self):
         s = f'{super().extra_repr()}, scaling="{self.scaling_repr()}"'
@@ -414,9 +414,9 @@ def from_float(
         cls,
         mod,
         emulate: bool = False,
-        scaling_type_x=TensorScalingType.DELAYED,
-        scaling_type_w=TensorScalingType.DELAYED,
-        scaling_type_dL_dY=TensorScalingType.DELAYED,
+        scaling_type_x=TensorScalingType.DYNAMIC,
+        scaling_type_w=TensorScalingType.DYNAMIC,
+        scaling_type_dL_dY=TensorScalingType.DYNAMIC,
     ):
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
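
To illustrate the repr change, here is a hypothetical standalone re-implementation of `scaling_repr` (not the library's code): the settings now render as `x:dyn,w:dyn,dldy:dyn` instead of `x_dyn_w_dyn_dldy_dyn`.

```python
from enum import Enum

# Hypothetical stand-in for the library's TensorScalingType; only short_str()
# matters for this illustration.
class TensorScalingType(Enum):
    DELAYED = "delayed"
    DYNAMIC = "dynamic"

    def short_str(self) -> str:
        return self.value[:3]  # "del" or "dyn"

def scaling_repr(x, w, dL_dY) -> str:
    # Format after this commit: "x:del,w:del,dldy:dyn"-style
    # (previously "x_del_w_del_dldy_dyn"-style).
    return f"x:{x.short_str()},w:{w.short_str()},dldy:{dL_dY.short_str()}"

print(scaling_repr(TensorScalingType.DYNAMIC,
                   TensorScalingType.DYNAMIC,
                   TensorScalingType.DELAYED))  # x:dyn,w:dyn,dldy:del
```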

test/test_dtensor.py

Lines changed: 1 addition & 2 deletions
@@ -171,8 +171,7 @@ def _test_fp8_mlp_tensor_parallelism_base(
     mesh: DeviceMesh, size=16, compile: bool = False
 ):
     device = mesh.device_type
-    # For now, just use Float8Linear with dynamic scaling, which is the
-    # same behavior as Float8Linear.
+    # For now, only supports dynamic scaling of `x` and `dL_dY`.
     # TODO(future): add support for float8 all-gather with delayed scaling
     # for activations and gradients.
     extra_kwargs = {
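
The rest of the test body is truncated in this hunk. As a purely hypothetical illustration of the updated comment (dynamic scaling for `x` and `dL_dY` only), such kwargs might look like the following; the keys and import path are assumptions, not the test's actual code:

```python
# Hypothetical illustration only; the real extra_kwargs are not shown in this diff.
from float8_experimental.float8_linear import TensorScalingType  # assumed import path

extra_kwargs = {
    "scaling_type_x": TensorScalingType.DYNAMIC,
    "scaling_type_dL_dY": TensorScalingType.DYNAMIC,
}
```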
