@@ -158,7 +158,9 @@ def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
     loss.backward()


-def test_fp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=16):
+def test_fp8_mlp_tensor_parallelism_base(
+    mesh: DeviceMesh, size=16, compile: bool = False
+):
     device = mesh.device_type

     toy_model = ToyModel().to(device)
@@ -197,6 +199,9 @@ def test_fp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=16):
         },
     )

+    if compile:
+        tp_model = torch.compile(tp_model)
+
     x_fp32 = torch.rand(size * 2, size, device=device, requires_grad=False)
     x_fp32_tp_input = x_fp32.clone()
     x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)])
@@ -217,6 +222,10 @@ def test_fp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=16):
     )


+def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
+    test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True)
+
+
 if __name__ == "__main__":
     # float8 only works on CUDA H100 so we only test cuda and we follow
     # other test files to not use TestCase but instead just add the test
@@ -227,7 +236,8 @@ def test_fp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=16):
         test_fp8_redistribute,
         test_dtensor_cast_to_fp8,
         test_dtensor_fp8_autograd,
-        test_fp8_mlp_tensor_parallelism,
+        test_fp8_mlp_tensor_parallelism_base,
+        test_fp8_mlp_tensor_parallelism_compile,
     ]

     for test in tqdm(tests, desc="Running tests"):
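
A minimal sketch of the pattern this diff introduces: a base test parameterized by a compile flag, with a thin wrapper that runs the compiled variant. This is an illustrative stand-in only; the model, function names, and single-device setup below are assumptions and not the PR's ToyModel/DeviceMesh test code.

import torch
import torch.nn as nn

def run_mlp_test_base(size: int = 16, compile: bool = False):
    # Hypothetical stand-in model; the real test builds a tensor-parallel
    # model on a DeviceMesh, which requires a distributed setup.
    model = nn.Sequential(nn.Linear(size, size), nn.ReLU(), nn.Linear(size, size))
    if compile:
        # Same pattern as the added lines: wrap the model with torch.compile
        # before running the forward/backward pass.
        model = torch.compile(model)
    x = torch.rand(size * 2, size, requires_grad=False)
    model(x).sum().backward()

def run_mlp_test_compile(size: int = 16):
    # Thin wrapper mirroring test_fp8_mlp_tensor_parallelism_compile.
    run_mlp_test_base(size, compile=True)

if __name__ == "__main__":
    run_mlp_test_base()
    run_mlp_test_compile()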