    Float8DynamicLinear,
    NoopFwToFloat8E5M2Bw,
)
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
from float8_experimental.float8_tensor_parallel import (
@@ -169,23 +170,37 @@ def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
    loss.backward()


-def test_fp8_mlp_tensor_parallelism_base(
-    mesh: DeviceMesh, size=16, compile: bool = False
+def _test_fp8_mlp_tensor_parallelism_base(
+    mesh: DeviceMesh, size=16, compile: bool = False, use_float8_linear: bool = False
):
    device = mesh.device_type
+    # TODO(future): delete Float8DynamicLinear from this test once all the
+    # code is unified
+    float8_cls = Float8Linear if use_float8_linear else Float8DynamicLinear
+    extra_kwargs = {}
+    if use_float8_linear:
+        # For now, just use Float8Linear with dynamic scaling, which is the
+        # same behavior as Float8DynamicLinear.
+        # TODO(future): add support for float8 all-gather with delayed scaling
+        # for activations and gradients.
+        extra_kwargs = {
+            "scaling_type_x": TensorScalingType.DYNAMIC,
+            "scaling_type_w": TensorScalingType.DYNAMIC,
+            "scaling_type_dL_dY": TensorScalingType.DYNAMIC,
+        }

    toy_model = ToyModel().to(device)
    toy_model_fp8 = swap_linear_with_float8_linear(
-        toy_model, Float8DynamicLinear, emulate=True
+        toy_model, float8_cls, emulate=True, **extra_kwargs
    )

    tp_model = copy.deepcopy(toy_model)
    tp_model = swap_linear_with_float8_linear(
-        tp_model, Float8DynamicLinear, emulate=True
+        tp_model, float8_cls, emulate=True, **extra_kwargs
    )
    sp_model = copy.deepcopy(toy_model)
    sp_model = swap_linear_with_float8_linear(
-        sp_model, Float8DynamicLinear, emulate=True
+        sp_model, float8_cls, emulate=True, **extra_kwargs
    )

    # vanilla TP
@@ -218,7 +233,7 @@ def test_fp8_mlp_tensor_parallelism_base(
    # PrepareFloat8ModuleInput with specific submodule fqn
    sp_model2 = copy.deepcopy(toy_model)
    sp_model2 = swap_linear_with_float8_linear(
-        sp_model2, Float8DynamicLinear, emulate=True
+        sp_model2, float8_cls, emulate=True, **extra_kwargs
    )

    sp_model2 = parallelize_module(
@@ -271,8 +286,28 @@ def test_fp8_mlp_tensor_parallelism_base(
    )


+def test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
+    _test_fp8_mlp_tensor_parallelism_base(
+        mesh, size, compile=False, use_float8_linear=False
+    )
+
+
+def test_fp8_mlp_tensor_parallelism_eager_float8_linear(mesh: DeviceMesh, size=16):
+    _test_fp8_mlp_tensor_parallelism_base(
+        mesh, size, compile=False, use_float8_linear=True
+    )
+
+
def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
-    test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True)
+    _test_fp8_mlp_tensor_parallelism_base(
+        mesh, size, compile=True, use_float8_linear=False
+    )
+
+
+def test_fp8_mlp_tensor_parallelism_compile_float8_linear(mesh: DeviceMesh, size=16):
+    _test_fp8_mlp_tensor_parallelism_base(
+        mesh, size, compile=True, use_float8_linear=True
+    )


if __name__ == "__main__":
@@ -285,8 +320,10 @@ def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
        test_fp8_redistribute,
        test_dtensor_cast_to_fp8,
        test_dtensor_fp8_autograd,
-        test_fp8_mlp_tensor_parallelism_base,
+        test_fp8_mlp_tensor_parallelism_eager,
+        test_fp8_mlp_tensor_parallelism_eager_float8_linear,
        test_fp8_mlp_tensor_parallelism_compile,
+        test_fp8_mlp_tensor_parallelism_compile_float8_linear,
    ]

    for test in tqdm(tests, desc="Running tests"):
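
For readers unfamiliar with the API this diff exercises: swap_linear_with_float8_linear replaces every nn.Linear in a module with the given float8 linear class, forwarding extra keyword arguments to it. Below is a minimal sketch of the Float8Linear variant of the swap, using the same all-dynamic scaling kwargs the test passes via extra_kwargs; SimpleMLP is a hypothetical stand-in for the test's ToyModel, and emulate=True mirrors the test's use of emulated float8 compute.

# A minimal sketch, assuming the float8_experimental package is importable.
# SimpleMLP is a hypothetical toy model, analogous to the ToyModel in the test.
import torch
import torch.nn as nn

from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear


class SimpleMLP(nn.Module):
    def __init__(self, size: int = 16):
        super().__init__()
        self.w1 = nn.Linear(size, size)
        self.w2 = nn.Linear(size, size)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))


model = SimpleMLP()
# Swap every nn.Linear for Float8Linear with dynamic scaling for activations
# (x), weights (w), and gradients (dL_dY); these are the same kwargs the test
# builds into extra_kwargs. emulate=True runs the float8 math in emulation,
# so no float8-capable GPU is required.
model_fp8 = swap_linear_with_float8_linear(
    model,
    Float8Linear,
    emulate=True,
    scaling_type_x=TensorScalingType.DYNAMIC,
    scaling_type_w=TensorScalingType.DYNAMIC,
    scaling_type_dL_dY=TensorScalingType.DYNAMIC,
)
out = model_fp8(torch.randn(4, 16))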