Skip to content
This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 14b00aa

Browse files
drisspg and facebook-github-bot
authored and committed
Add Dtensor compile test (#256)
Summary: Adds test. Pull Request resolved: #256. Reviewed By: wanchaol. Differential Revision: D56946306. Pulled By: drisspg. fbshipit-source-id: 6ce100884cbae93e3acfc176f90038e87d44a0fc
1 parent 605fc1d commit 14b00aa

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

test/test_dtensor.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,9 @@ def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
158158
loss.backward()
159159

160160

161-
def test_fp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=16):
161+
def test_fp8_mlp_tensor_parallelism_base(
162+
mesh: DeviceMesh, size=16, compile: bool = False
163+
):
162164
device = mesh.device_type
163165

164166
toy_model = ToyModel().to(device)
@@ -197,6 +199,9 @@ def test_fp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=16):
197199
},
198200
)
199201

202+
if compile:
203+
tp_model = torch.compile(tp_model)
204+
200205
x_fp32 = torch.rand(size * 2, size, device=device, requires_grad=False)
201206
x_fp32_tp_input = x_fp32.clone()
202207
x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)])
@@ -217,6 +222,10 @@ def test_fp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=16):
217222
)
218223

219224

225+
def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
226+
test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True)
227+
228+
220229
if __name__ == "__main__":
221230
# float8 only works on CUDA H100 so we only test cuda and we follow
222231
# other test files to not use TestCase but instead just add the test
@@ -227,7 +236,8 @@ def test_fp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=16):
227236
test_fp8_redistribute,
228237
test_dtensor_cast_to_fp8,
229238
test_dtensor_fp8_autograd,
230-
test_fp8_mlp_tensor_parallelism,
239+
test_fp8_mlp_tensor_parallelism_base,
240+
test_fp8_mlp_tensor_parallelism_compile,
231241
]
232242

233243
for test in tqdm(tests, desc="Running tests"):

0 commit comments

Comments (0)