
Commit 5d4df85

adding test cases
1 parent 4f2f93f commit 5d4df85

File tree: 3 files changed (+99 −3 lines)


py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py

Lines changed: 2 additions & 3 deletions
@@ -3,12 +3,11 @@
 from typing import Optional, Tuple, Union
 
 import numpy as np
+import tensorrt as trt
 from torch.fx.node import Argument, Target
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.fx.converters.converter_utils import SourceIR, set_layer_name
 
-import tensorrt as trt
-
 
 # class for AllReduce
 class AllReduceStrategy(IntEnum):
@@ -94,7 +93,7 @@ def nccl_reduce_scatter(
         "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
     )
 
-    p_dtype = trt.float16
+    p_dtype = trt.float32
     pf_dtype = trt.PluginField(
         "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
     )
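Note on the dtype fix above: p_dtype is serialized into the plugin's type_id field, so it must match the dtype of the tensor handed to the TRT-LLM collective plugin. A minimal sketch of how these fields are typically assembled into a plugin layer (the plugin name, version, and namespace are illustrative assumptions, not taken from this diff):

import numpy as np
import tensorrt as trt


def build_reduce_scatter_plugin_layer(ctx, tensor, group):
    # Illustrative plugin identity; the real converter may look it up differently.
    creator = trt.get_plugin_registry().get_plugin_creator(
        "ReduceScatter", "1", "tensorrt_llm"
    )
    p_dtype = trt.float32  # must agree with tensor.dtype, hence the float16 -> float32 change
    fields = trt.PluginFieldCollection(
        [
            trt.PluginField(
                "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
            ),
            trt.PluginField(
                "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
            ),
        ]
    )
    plugin = creator.create_plugin("reduce_scatter", fields)
    # ctx is a ConversionContext; ctx.net is the TensorRT INetworkDefinition.
    return ctx.net.add_plugin_v2([tensor], plugin)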

tests/py/dynamo/conversion/harness.py

Lines changed: 13 additions & 0 deletions
@@ -353,6 +353,7 @@ def generate_graph(
         enable_passes: bool,
         propagate_shapes: bool = False,
         settings: CompilationSettings = CompilationSettings(),
+        fuse_distributed_ops: bool = False,
         torch_export_dynamic_shapes: Optional[Any] = None,
     ):
         mod = mod.eval()
@@ -368,6 +369,16 @@ def generate_graph(
                 tuple(torch_export_inputs),
                 dynamic_shapes=torch_export_dynamic_shapes,
             )
+            if fuse_distributed_ops:
+                exported_program = exported_program.run_decompositions(
+                    get_decompositions(False)
+                )
+                from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
+                    fuse_distributed_ops,
+                )
+
+                gm = exported_program.graph_module
+                gm = fuse_distributed_ops(gm, settings)
             if enable_passes:
                 exported_program = pre_export_lowering(exported_program, settings)
                 exported_program = exported_program.run_decompositions(
@@ -406,6 +417,7 @@ def run_test(
         propagate_shapes=False,
         int32_reqd=False,
         immutable_weights=True,
+        fuse_distributed_ops=False,
     ):
         # TODO: lan to remove this and set use_dynamo_traccer to True by default
         # once all the converter test files are moved to use_dynamo_tracer
@@ -426,6 +438,7 @@ def run_test(
             enable_passes=enable_passes,
             propagate_shapes=propagate_shapes,
             settings=compilation_settings,
+            fuse_distributed_ops=fuse_distributed_ops,
         )
 
         num_inputs = len(inputs)
Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
+import os
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+
+def set_environment_variables():
+    os.environ["WORLD_SIZE"] = str(1)
+    os.environ["RANK"] = str(0)
+    os.environ["MASTER_ADDR"] = "127.0.0.1"
+    os.environ["MASTER_PORT"] = str(29500)
+    os.environ["USE_TRTLLM_PLUGINS"] = "1"
+
+
+set_environment_variables()
+dist.init_process_group(backend="nccl", init_method="env://")
+group = dist.new_group(ranks=[0])
+group_name = group.group_name
+world_size = 1
+
+from conversion.harness import DispatchTestCase
+
+
+class TestGatherNcclOpsConverter(DispatchTestCase):
+    @parameterized.expand([(8)])
+    def test_nccl_ops(self, linear_layer_dim):
+        class DistributedGatherModel(nn.Module):
+            def __init__(self, input_dim):
+                super().__init__()
+                self.fc = torch.nn.Linear(input_dim, input_dim)
+
+            def forward(self, x):
+                x = self.fc(x)
+                gathered_tensor = torch.ops._c10d_functional.all_gather_into_tensor(
+                    x, world_size, group_name
+                )
+                gathered_tensor = torch.ops._c10d_functional.wait_tensor(
+                    gathered_tensor
+                )
+                return gathered_tensor
+
+        inputs = [torch.randn(1, linear_layer_dim).to("cuda")]
+        self.run_test(
+            DistributedGatherModel(linear_layer_dim).cuda(),
+            inputs,
+            use_dynamo_tracer=True,
+            fuse_distributed_ops=True,
+        )
+
+    @parameterized.expand([(8)])
+    def test_nccl_ops_scatter(self, linear_layer_dim):
+
+        class DistributedReduceScatterModel(nn.Module):
+            def __init__(self, input_dim):
+                super().__init__()
+                self.fc = torch.nn.Linear(input_dim, input_dim)
+
+            def forward(self, x):
+                x = self.fc(x)
+                scatter_reduce_tensor = (
+                    torch.ops._c10d_functional.reduce_scatter_tensor(
+                        x, "sum", world_size, group_name
+                    )
+                )
+                scatter_reduce_tensor = torch.ops._c10d_functional.wait_tensor(
+                    scatter_reduce_tensor
+                )
+                return scatter_reduce_tensor
+
+        inputs = [torch.zeros(1, linear_layer_dim).to("cuda")]
+
+        self.run_test(
+            DistributedReduceScatterModel(linear_layer_dim).cuda(),
+            inputs,
+            use_dynamo_tracer=True,
+            fuse_distributed_ops=True,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
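The new test module initializes a single-rank NCCL process group at import time and never tears it down. A minimal cleanup sketch (not part of this commit) if the group should be released once the tests finish:

import atexit

import torch.distributed as dist

# Release the NCCL process group created at module import; registered with atexit
# so it runs after run_tests() returns.
atexit.register(dist.destroy_process_group)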
