This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 1f0166b

[4/x] add tests for DTensor TP/SP + Float8Linear
Summary:

Makes the DTensor TP/SP tests also test `Float8Linear` with all scaling types configured to be dynamic. We can add support for delayed scaling with float8 all-gather for `x` and `dL_dY` in a future PR, as needed.

Test Plan:

```
./test/test_dtensor.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: d018546
Pull Request resolved: #294
1 parent a0ad964 commit 1f0166b
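For context, here is a minimal sketch of the workflow the new test path exercises: swap a toy MLP's linears for `Float8Linear` with all scaling types set to dynamic, then apply the float8-aware TP styles with `parallelize_module`. The `SmallMLP` module and its `w1`/`w2` submodule names are illustrative stand-ins for the test's `ToyModel`, `Float8RowwiseParallel` is assumed to exist alongside the colwise style shown in this diff, and the sketch assumes `torch.distributed` is already initialized with one GPU per rank.

```python
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from float8_experimental.float8_tensor_parallel import (
    Float8ColwiseParallel,
    Float8RowwiseParallel,  # assumed counterpart to the colwise style in this commit
)


class SmallMLP(nn.Module):
    """Illustrative stand-in for the test's ToyModel."""

    def __init__(self, dim: int = 16):
        super().__init__()
        self.w1 = nn.Linear(dim, 4 * dim, bias=False)
        self.w2 = nn.Linear(4 * dim, dim, bias=False)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))


def build_tp_float8_mlp(world_size: int, device_type: str = "cuda") -> nn.Module:
    # 1-D mesh over all ranks; assumes torch.distributed is already initialized
    mesh = init_device_mesh(device_type, (world_size,))

    model = SmallMLP().to(device_type)
    # all scaling types dynamic, mirroring the configuration the new test uses;
    # emulate=True matches the test and avoids requiring native fp8 hardware
    model = swap_linear_with_float8_linear(
        model,
        Float8Linear,
        emulate=True,
        scaling_type_x=TensorScalingType.DYNAMIC,
        scaling_type_w=TensorScalingType.DYNAMIC,
        scaling_type_dL_dY=TensorScalingType.DYNAMIC,
    )

    # colwise shards the up-projection, rowwise the down-projection, with the
    # float8 styles handling the fp8 casts around the TP communication
    return parallelize_module(
        model,
        mesh,
        {"w1": Float8ColwiseParallel(), "w2": Float8RowwiseParallel()},
    )
```

If `x` or `dL_dY` were instead configured for delayed scaling, the styles would reject the module via the new `_float8_linear_supports_float8_allgather` check in the diff below.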

File tree

2 files changed: +75 -16 lines

float8_experimental/float8_tensor_parallel.py
test/test_dtensor.py

float8_experimental/float8_tensor_parallel.py (30 additions, 8 deletions)
```diff
@@ -4,6 +4,7 @@
     cast_to_float8_e4m3_dynamic,
     cast_to_float8_e5m2_dynamic_bw,
 )
+from float8_experimental.float8_linear import TensorScalingType
 from torch.distributed._tensor import DTensor
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor.parallel import (
@@ -19,7 +20,17 @@
 # here is that in input/output handling we do casting after
 # creating the DTensor.
 
-# NOTE: This only works and tested with the DynamicLinear
+# NOTE: this only works with and is tested for dynamic scaling
+# (Float8DynamicLinear and Float8Linear with dynamic scaling for all tensors)
+
+
+def _float8_linear_supports_float8_allgather(m):
+    # TODO(future): add support for delayed scaling for activations
+    # and gradients
+    return (
+        m.scaling_type_x == TensorScalingType.DYNAMIC
+        and m.scaling_type_dL_dY == TensorScalingType.DYNAMIC
+    )
 
 
 class Float8ColwiseParallel(ColwiseParallel):
@@ -61,11 +72,16 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
 
     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+        from float8_experimental.float8_linear import Float8Linear
 
-        if not isinstance(module, Float8DynamicLinear):
+        if not isinstance(module, (Float8DynamicLinear, Float8Linear)):
             raise ValueError(
-                f"Expecting module to be Float8DynamicLinear but found {type(module)}"
+                f"Expecting module to be Float8DynamicLinear or Float8Linear but found {type(module)}"
             )
+        elif isinstance(
+            module, Float8Linear
+        ) and not _float8_linear_supports_float8_allgather(module):
+            raise AssertionError("unsupported")
 
         return super()._apply(module, device_mesh)
 
@@ -107,11 +123,16 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
 
     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+        from float8_experimental.float8_linear import Float8Linear
 
-        if not isinstance(module, Float8DynamicLinear):
+        if not isinstance(module, (Float8DynamicLinear, Float8Linear)):
             raise ValueError(
-                f"Expecting module to be Float8DynamicLinear but found {type(module)}"
+                f"Expecting module to be Float8DynamicLinear or Float8Linear but found {type(module)}"
             )
+        elif isinstance(
+            module, Float8Linear
+        ) and not _float8_linear_supports_float8_allgather(module):
+            raise AssertionError("unsupported")
 
         return super()._apply(module, device_mesh)
 
@@ -184,22 +205,23 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout):
 
     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+        from float8_experimental.float8_linear import Float8Linear
 
         fwd_linear_config = None
         if self.fwd_config_submodule_fqn is not None:
             fwd_linear = module.get_submodule(self.fwd_config_submodule_fqn)
-            assert isinstance(fwd_linear, Float8DynamicLinear)
+            assert isinstance(fwd_linear, (Float8DynamicLinear, Float8Linear))
             fwd_linear_config = fwd_linear.forward_config
         else:
             # search for ScaledMM configs for all the submodules and make sure they are the same
             for mod in module.modules():
-                if isinstance(mod, Float8DynamicLinear):
+                if isinstance(mod, (Float8DynamicLinear, Float8Linear)):
                     if fwd_linear_config is None:
                         fwd_linear_config = mod.forward_config
                     else:
                         assert (
                             fwd_linear_config == mod.forward_config
-                        ), "All the Float8DynamicLinear modules should have same forward config!"
+                        ), "All the Float8DynamicLinear and Float8Linear modules should have same forward config!"
 
         self.fwd_linear_config = fwd_linear_config
         super()._apply(module, device_mesh)
```

test/test_dtensor.py (45 additions, 8 deletions)
```diff
@@ -18,6 +18,7 @@
     Float8DynamicLinear,
     NoopFwToFloat8E5M2Bw,
 )
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
 from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
 from float8_experimental.float8_tensor_parallel import (
@@ -169,23 +170,37 @@ def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
     loss.backward()
 
 
-def test_fp8_mlp_tensor_parallelism_base(
-    mesh: DeviceMesh, size=16, compile: bool = False
+def _test_fp8_mlp_tensor_parallelism_base(
+    mesh: DeviceMesh, size=16, compile: bool = False, use_float8_linear: bool = False
 ):
     device = mesh.device_type
+    # TODO(future): delete Float8DynamicLinear from this test once all the
+    # code is unified
+    float8_cls = Float8Linear if use_float8_linear else Float8DynamicLinear
+    extra_kwargs = {}
+    if use_float8_linear:
+        # For now, just use Float8Linear with dynamic scaling, which is the
+        # same behavior as Float8DynamicLinear.
+        # TODO(future): add support for float8 all-gather with delayed scaling
+        # for activations and gradients.
+        extra_kwargs = {
+            "scaling_type_x": TensorScalingType.DYNAMIC,
+            "scaling_type_w": TensorScalingType.DYNAMIC,
+            "scaling_type_dL_dY": TensorScalingType.DYNAMIC,
+        }
 
     toy_model = ToyModel().to(device)
     toy_model_fp8 = swap_linear_with_float8_linear(
-        toy_model, Float8DynamicLinear, emulate=True
+        toy_model, float8_cls, emulate=True, **extra_kwargs
    )
 
     tp_model = copy.deepcopy(toy_model)
     tp_model = swap_linear_with_float8_linear(
-        tp_model, Float8DynamicLinear, emulate=True
+        tp_model, float8_cls, emulate=True, **extra_kwargs
     )
     sp_model = copy.deepcopy(toy_model)
     sp_model = swap_linear_with_float8_linear(
-        sp_model, Float8DynamicLinear, emulate=True
+        sp_model, float8_cls, emulate=True, **extra_kwargs
     )
 
     # vanilla TP
@@ -218,7 +233,7 @@ def test_fp8_mlp_tensor_parallelism_base(
     # PrepareFloat8ModuleInput with specific submodule fqn
     sp_model2 = copy.deepcopy(toy_model)
     sp_model2 = swap_linear_with_float8_linear(
-        sp_model2, Float8DynamicLinear, emulate=True
+        sp_model2, Float8DynamicLinear, emulate=True, **extra_kwargs
     )
 
     sp_model2 = parallelize_module(
@@ -271,8 +286,28 @@ def test_fp8_mlp_tensor_parallelism_base(
     )
 
 
+def test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
+    _test_fp8_mlp_tensor_parallelism_base(
+        mesh, size, compile=False, use_float8_linear=False
+    )
+
+
+def test_fp8_mlp_tensor_parallelism_eager_float8_linear(mesh: DeviceMesh, size=16):
+    _test_fp8_mlp_tensor_parallelism_base(
+        mesh, size, compile=False, use_float8_linear=True
+    )
+
+
 def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
-    test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True)
+    _test_fp8_mlp_tensor_parallelism_base(
+        mesh, size, compile=True, use_float8_linear=False
+    )
+
+
+def test_fp8_mlp_tensor_parallelism_compile_float8_linear(mesh: DeviceMesh, size=16):
+    _test_fp8_mlp_tensor_parallelism_base(
+        mesh, size, compile=True, use_float8_linear=True
+    )
 
 
 if __name__ == "__main__":
@@ -285,8 +320,10 @@ def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
         test_fp8_redistribute,
         test_dtensor_cast_to_fp8,
         test_dtensor_fp8_autograd,
-        test_fp8_mlp_tensor_parallelism_base,
+        test_fp8_mlp_tensor_parallelism_eager,
+        test_fp8_mlp_tensor_parallelism_eager_float8_linear,
         test_fp8_mlp_tensor_parallelism_compile,
+        test_fp8_mlp_tensor_parallelism_compile_float8_linear,
     ]
 
     for test in tqdm(tests, desc="Running tests"):
```
