Skip to content
This repository was archived by the owner on Aug 7, 2024. It is now read-only.

fix nits from deletion of Float8DynamicLinear #308

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions float8_experimental/float8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,9 @@ def __init__(self, *args, **kwargs):
# Amax scales should always be kept as float32.
self.always_float32_buffers = set()
emulate = kwargs.pop("emulate", False)
scaling_type_x = kwargs.pop("scaling_type_x", TensorScalingType.DELAYED)
scaling_type_w = kwargs.pop("scaling_type_w", TensorScalingType.DELAYED)
scaling_type_dL_dY = kwargs.pop("scaling_type_dL_dY", TensorScalingType.DELAYED)
scaling_type_x = kwargs.pop("scaling_type_x", TensorScalingType.DYNAMIC)
scaling_type_w = kwargs.pop("scaling_type_w", TensorScalingType.DYNAMIC)
scaling_type_dL_dY = kwargs.pop("scaling_type_dL_dY", TensorScalingType.DYNAMIC)
super().__init__(*args, **kwargs)

# Defines the scaling behavior of x, w, dL_dY
Expand Down Expand Up @@ -402,8 +402,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:

def scaling_repr(self):
# add scaling settings without using too many characters
# example: "x_del_w_del_dldy_dyn"
return f"x_{self.scaling_type_x.short_str()}_w_{self.scaling_type_w.short_str()}_dldy_{self.scaling_type_dL_dY.short_str()}"
# example: "x:del,w:del,dldy:dyn"
return f"x:{self.scaling_type_x.short_str()},w:{self.scaling_type_w.short_str()},dldy:{self.scaling_type_dL_dY.short_str()}"

def extra_repr(self):
s = f'{super().extra_repr()}, scaling="{self.scaling_repr()}"'
Expand All @@ -414,9 +414,9 @@ def from_float(
cls,
mod,
emulate: bool = False,
scaling_type_x=TensorScalingType.DELAYED,
scaling_type_w=TensorScalingType.DELAYED,
scaling_type_dL_dY=TensorScalingType.DELAYED,
scaling_type_x=TensorScalingType.DYNAMIC,
scaling_type_w=TensorScalingType.DYNAMIC,
scaling_type_dL_dY=TensorScalingType.DYNAMIC,
):
"""
Create an nn.Linear with fp8 compute from a regular nn.Linear
Expand Down
3 changes: 1 addition & 2 deletions test/test_dtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,7 @@ def _test_fp8_mlp_tensor_parallelism_base(
mesh: DeviceMesh, size=16, compile: bool = False
):
device = mesh.device_type
# For now, just use Float8Linear with dynamic scaling, which is the
# same behavior as Float8Linear.
# For now, only supports dynamic scaling of `x` and `dL_dY`.
# TODO(future): add support for float8 all-gather with delayed scaling
# for activations and gradients.
extra_kwargs = {
Expand Down
Loading