This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 54364a8

[4/x] add tests for DTensor TP/SP + Float8Linear
Summary:

Makes the DTensor TP/SP tests also test `Float8Linear` with all scaling types
configured to be dynamic. We can add support for delayed scaling with float8
all-gather for `x` and `dL_dY` in a future PR, as needed.

Test Plan:

```
./test/test_dtensor.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 8e121ec
Pull Request resolved: #294
1 parent a0ad964 commit 54364a8
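
For context, the sketch below mirrors the `Float8Linear` configuration these new tests exercise: every scaling type set to dynamic, which is what the float8 all-gather path currently accepts. This is a minimal illustrative sketch, not code from this commit; `model` is assumed to be any `nn.Module` containing `nn.Linear` submodules, and the process-group / device-mesh setup is omitted.

```python
# Minimal sketch (not part of this commit) of the configuration the new tests
# exercise: Float8Linear with dynamic scaling for x, w, and dL_dY.
# Assumption: `model` is an nn.Module containing nn.Linear submodules.
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

model_fp8 = swap_linear_with_float8_linear(
    model,
    Float8Linear,
    emulate=True,  # the DTensor tests run in emulation mode
    scaling_type_x=TensorScalingType.DYNAMIC,
    scaling_type_w=TensorScalingType.DYNAMIC,
    scaling_type_dL_dY=TensorScalingType.DYNAMIC,
)
```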

File tree

2 files changed: +73 / -15 lines

float8_experimental/float8_tensor_parallel.py

Lines changed: 28 additions & 7 deletions
```diff
@@ -4,6 +4,7 @@
     cast_to_float8_e4m3_dynamic,
     cast_to_float8_e5m2_dynamic_bw,
 )
+from float8_experimental.float8_linear import TensorScalingType
 from torch.distributed._tensor import DTensor
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor.parallel import (
@@ -22,6 +23,15 @@
 # NOTE: This only works and tested with the DynamicLinear


+def _float8_linear_supports_float8_allgather(m):
+    # TODO(future PR): add support for delayed scaling for activations
+    # and gradients
+    return (
+        m.scaling_type_x == TensorScalingType.DYNAMIC
+        and m.scaling_type_dL_dY == TensorScalingType.DYNAMIC
+    )
+
+
 class Float8ColwiseParallel(ColwiseParallel):
     @staticmethod
     def _prepare_input_fn(
@@ -61,11 +71,16 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me

     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+        from float8_experimental.float8_linear import Float8Linear

-        if not isinstance(module, Float8DynamicLinear):
+        if not isinstance(module, (Float8DynamicLinear, Float8Linear)):
             raise ValueError(
-                f"Expecting module to be Float8DynamicLinear but found {type(module)}"
+                f"Expecting module to be Float8DynamicLinear or Float8Linear but found {type(module)}"
             )
+        elif isinstance(
+            module, Float8Linear
+        ) and not _float8_linear_supports_float8_allgather(module):
+            raise AssertionError("unsupported")

         return super()._apply(module, device_mesh)

@@ -107,11 +122,16 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me

     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+        from float8_experimental.float8_linear import Float8Linear

-        if not isinstance(module, Float8DynamicLinear):
+        if not isinstance(module, (Float8DynamicLinear, Float8Linear)):
             raise ValueError(
-                f"Expecting module to be Float8DynamicLinear but found {type(module)}"
+                f"Expecting module to be Float8DynamicLinear or Float8Linear but found {type(module)}"
             )
+        elif isinstance(
+            module, Float8Linear
+        ) and not _float8_linear_supports_float8_allgather(module):
+            raise AssertionError("unsupported")

         return super()._apply(module, device_mesh)

@@ -184,22 +204,23 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout):

     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+        from float8_experimental.float8_linear import Float8Linear

         fwd_linear_config = None
         if self.fwd_config_submodule_fqn is not None:
             fwd_linear = module.get_submodule(self.fwd_config_submodule_fqn)
-            assert isinstance(fwd_linear, Float8DynamicLinear)
+            assert isinstance(fwd_linear, (Float8DynamicLinear, Float8Linear))
             fwd_linear_config = fwd_linear.forward_config
         else:
             # search for ScaledMM configs for all the submodules and make sure they are the same
             for mod in module.modules():
-                if isinstance(mod, Float8DynamicLinear):
+                if isinstance(mod, (Float8DynamicLinear, Float8Linear)):
                     if fwd_linear_config is None:
                         fwd_linear_config = mod.forward_config
                     else:
                         assert (
                             fwd_linear_config == mod.forward_config
-                        ), "All the Float8DynamicLinear modules should have same forward config!"
+                        ), "All the Float8DynamicLinear and Float8Linear modules should have same forward config!"

         self.fwd_linear_config = fwd_linear_config
         super()._apply(module, device_mesh)
```
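
Note that `_float8_linear_supports_float8_allgather` only inspects the activation (`x`) and gradient-output (`dL_dY`) scaling types, matching the commit message's note that float8 all-gather applies to `x` and `dL_dY`; the weight scaling type is not checked. The sketch below illustrates the guard with a stand-in object instead of a real `Float8Linear`, whose constructor is outside this diff.

```python
# Sketch of the new guard's behavior using a stand-in object; only the two
# attributes the guard actually reads are provided.
from types import SimpleNamespace

from float8_experimental.float8_linear import TensorScalingType
from float8_experimental.float8_tensor_parallel import (
    _float8_linear_supports_float8_allgather,
)

dynamic_only = SimpleNamespace(
    scaling_type_x=TensorScalingType.DYNAMIC,
    scaling_type_dL_dY=TensorScalingType.DYNAMIC,
)
# Dynamic scaling for both x and dL_dY -> float8 all-gather is allowed;
# anything else trips the AssertionError("unsupported") in _apply above.
assert _float8_linear_supports_float8_allgather(dynamic_only)
```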

test/test_dtensor.py

Lines changed: 45 additions & 8 deletions
```diff
@@ -18,6 +18,7 @@
     Float8DynamicLinear,
     NoopFwToFloat8E5M2Bw,
 )
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
 from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
 from float8_experimental.float8_tensor_parallel import (
@@ -169,23 +170,37 @@ def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
     loss.backward()


-def test_fp8_mlp_tensor_parallelism_base(
-    mesh: DeviceMesh, size=16, compile: bool = False
+def _test_fp8_mlp_tensor_parallelism_base(
+    mesh: DeviceMesh, size=16, compile: bool = False, use_float8_linear: bool = False
 ):
     device = mesh.device_type
+    # TODO(future): delete Float8DynamicLinear from this test once all the
+    # code is unified
+    float8_cls = Float8Linear if use_float8_linear else Float8DynamicLinear
+    extra_kwargs = {}
+    if use_float8_linear:
+        # For now, just use Float8Linear with dynamic scaling, which is the
+        # same behavior as Float8Linear.
+        # TODO(future): add support for float8 all-gather with delayed scaling
+        # for activations and gradients.
+        extra_kwargs = {
+            "scaling_type_x": TensorScalingType.DYNAMIC,
+            "scaling_type_w": TensorScalingType.DYNAMIC,
+            "scaling_type_dL_dY": TensorScalingType.DYNAMIC,
+        }

     toy_model = ToyModel().to(device)
     toy_model_fp8 = swap_linear_with_float8_linear(
-        toy_model, Float8DynamicLinear, emulate=True
+        toy_model, float8_cls, emulate=True, **extra_kwargs
     )

     tp_model = copy.deepcopy(toy_model)
     tp_model = swap_linear_with_float8_linear(
-        tp_model, Float8DynamicLinear, emulate=True
+        tp_model, float8_cls, emulate=True, **extra_kwargs
     )
     sp_model = copy.deepcopy(toy_model)
     sp_model = swap_linear_with_float8_linear(
-        sp_model, Float8DynamicLinear, emulate=True
+        sp_model, float8_cls, emulate=True, **extra_kwargs
     )

     # vanilla TP
@@ -218,7 +233,7 @@ def test_fp8_mlp_tensor_parallelism_base(
     # PrepareFloat8ModuleInput with specific submodule fqn
     sp_model2 = copy.deepcopy(toy_model)
     sp_model2 = swap_linear_with_float8_linear(
-        sp_model2, Float8DynamicLinear, emulate=True
+        sp_model2, Float8DynamicLinear, emulate=True, **extra_kwargs
     )

     sp_model2 = parallelize_module(
@@ -271,8 +286,28 @@ def test_fp8_mlp_tensor_parallelism_base(
     )


+def test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
+    _test_fp8_mlp_tensor_parallelism_base(
+        mesh, size, compile=False, use_float8_linear=False
+    )
+
+
+def test_fp8_mlp_tensor_parallelism_eager_float8_linear(mesh: DeviceMesh, size=16):
+    _test_fp8_mlp_tensor_parallelism_base(
+        mesh, size, compile=False, use_float8_linear=True
+    )
+
+
 def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
-    test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True)
+    _test_fp8_mlp_tensor_parallelism_base(
+        mesh, size, compile=True, use_float8_linear=False
+    )
+
+
+def test_fp8_mlp_tensor_parallelism_compile_float8_linear(mesh: DeviceMesh, size=16):
+    _test_fp8_mlp_tensor_parallelism_base(
+        mesh, size, compile=True, use_float8_linear=True
+    )


 if __name__ == "__main__":
@@ -285,8 +320,10 @@ def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
         test_fp8_redistribute,
         test_dtensor_cast_to_fp8,
         test_dtensor_fp8_autograd,
-        test_fp8_mlp_tensor_parallelism_base,
+        test_fp8_mlp_tensor_parallelism_eager,
+        test_fp8_mlp_tensor_parallelism_eager_float8_linear,
         test_fp8_mlp_tensor_parallelism_compile,
+        test_fp8_mlp_tensor_parallelism_compile_float8_linear,
     ]

     for test in tqdm(tests, desc="Running tests"):
```
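
For readers unfamiliar with the test structure, the new test wrappers feed into the same `parallelize_module` pattern the existing tests use, roughly as sketched below. The submodule name `"w1"` and the `tp_model` / `mesh` variables are placeholders, not names taken from this diff.

```python
# Rough sketch of applying a float8-aware TP style to a swapped model.
# Assumptions: "w1" is a hypothetical nn.Linear submodule name; `tp_model`
# has already been through swap_linear_with_float8_linear, and `mesh` is an
# initialized DeviceMesh.
from torch.distributed.tensor.parallel import parallelize_module
from float8_experimental.float8_tensor_parallel import Float8ColwiseParallel

tp_model = parallelize_module(
    tp_model,
    mesh,
    {
        # Float8ColwiseParallel._apply now accepts Float8Linear as long as
        # x and dL_dY use dynamic scaling, per the guard added in this commit.
        "w1": Float8ColwiseParallel(),
    },
)
```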
