
Commit ddf72cd

Implement the Option 1 warning, recommending that users set use_distributed_mode_trace=True
1 parent 40df0e2 commit ddf72cd

File tree

examples/distributed_inference/tensor_parallel_simple_example.py
tests/py/dynamo/distributed/test_distributed_simple_example.py
tests/py/dynamo/distributed/test_nccl_ops.sh

3 files changed: +70 -31 lines changed
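The change in both Python files is the same: instead of hard-coding a distributed-trace option at the call site, the example first tries a plain torch.compile and falls back when the backend reports that aot_export cannot handle traceable tensor subclasses. A minimal, self-contained sketch of that retry pattern follows; the compile_fn callable and compile_with_fallback helper are hypothetical stand-ins for illustration, not part of the commit.

    import logging

    logger = logging.getLogger(__name__)

    AOT_EXPORT_ERROR = (
        "aot_export is not currently supported with traceable tensor subclass"
    )

    def compile_with_fallback(compile_fn, model, options):
        """Try compiling once; on the known aot_export error, retry with
        use_distributed_mode_trace=True."""
        try:
            return compile_fn(model, options)
        except RuntimeError as e:
            if AOT_EXPORT_ERROR not in str(e):
                raise
            logger.warning(
                "It is recommended to run the model with use_distributed_mode_trace=True. "
                "Retrying with that option."
            )
            # Leave the caller's options untouched; retry with a copy.
            retry_options = dict(options, use_distributed_mode_trace=True)
            return compile_fn(model, retry_options)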

examples/distributed_inference/tensor_parallel_simple_example.py

Lines changed: 30 additions & 14 deletions
@@ -21,6 +21,35 @@
 """
 
 
+def compile_tp_model(tp_model, backend):
+    compile_options = {
+        "truncate_long_and_double": True,
+        "enabled_precisions": {torch.float32, torch.float16},
+        "use_python_runtime": True,
+        "min_block_size": 1,
+    }
+
+    try:
+        return torch.compile(
+            tp_model, backend=backend, options=compile_options, dynamic=False
+        )
+    except RuntimeError as e:
+        if (
+            "aot_export is not currently supported with traceable tensor subclass"
+            in str(e)
+        ):
+            logger.warning(
+                "It is recommended to run the model with use_distributed_mode_trace=True. Retrying with that option."
+            )
+            compile_options["use_distributed_mode_trace"] = True
+            return torch.compile(
+                tp_model, backend=backend, options=compile_options, dynamic=False
+            )
+        else:
+            logger.debug("The distributed model failed with the following error")
+            raise
+
+
 class ToyModel(nn.Module):
     """MLP based model"""
 
@@ -64,20 +93,7 @@ def forward(self, x):
 inp = torch.rand(20, 10, device="cuda")
 python_result = tp_model(inp)
 
-
-backend = "torch_tensorrt"
-tp_model = torch.compile(
-    tp_model,
-    backend=backend,
-    options={
-        "truncate_long_and_double": True,
-        "enabled_precisions": {torch.float32, torch.float16},
-        "use_python_runtime": True,
-        "min_block_size": 1,
-        "use_distributed_mode_trace": True,
-    },
-    dynamic=False,
-)
+compile_tp_model(tp_model, backend="torch_tensorrt")
 
 for i in range(10):
     # For TP, input needs to be same across all TP ranks.
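To see the new helper in isolation, here is a hedged usage sketch. It assumes torch_tensorrt is installed and imported so the "torch_tensorrt" backend is registered, and that a CUDA device is available; the plain nn.Linear stand-in is not part of the example, which compiles the sharded ToyModel instead.

    import torch
    import torch.nn as nn

    # Hypothetical stand-in module for demonstration purposes.
    model = nn.Linear(10, 10).to("cuda")
    compiled = compile_tp_model(model, backend="torch_tensorrt")
    out = compiled(torch.rand(4, 10, device="cuda"))  # first call triggers compilation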

tests/py/dynamo/distributed/test_distributed_simple_example.py

Lines changed: 31 additions & 14 deletions
@@ -16,6 +16,36 @@
     "./tensor_parallel_simple_example"
 )
 
+
+def compile_tp_model(tp_model, backend):
+    compile_options = {
+        "truncate_long_and_double": True,
+        "enabled_precisions": {torch.float32, torch.float16},
+        "use_python_runtime": True,
+        "min_block_size": 1,
+    }
+
+    try:
+        return torch.compile(
+            tp_model, backend=backend, options=compile_options, dynamic=False
+        )
+    except RuntimeError as e:
+        if (
+            "aot_export is not currently supported with traceable tensor subclass"
+            in str(e)
+        ):
+            logger.warning(
+                "It is recommended to run the model with use_distributed_mode_trace=True. Retrying with that option."
+            )
+            compile_options["use_distributed_mode_trace"] = True
+            return torch.compile(
+                tp_model, backend=backend, options=compile_options, dynamic=False
+            )
+        else:
+            logger.debug("The distributed model failed with the following error")
+            raise
+
+
 """
 This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
 """
@@ -60,20 +90,7 @@ def forward(self, x):
 inp = torch.rand(20, 10, device="cuda")
 python_result = tp_model(inp)
 
-
-backend = "torch_tensorrt"
-tp_model = torch.compile(
-    tp_model,
-    backend=backend,
-    options={
-        "truncate_long_and_double": True,
-        "enabled_precisions": {torch.float32, torch.float16},
-        "use_python_runtime": True,
-        "min_block_size": 1,
-        "use_aot_joint_export": False,
-    },
-    dynamic=False,
-)
+compile_tp_model(tp_model, backend="torch_tensorrt")
 
 for i in range(10):
     # For TP, input needs to be same across all TP ranks.
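The test script exercises the same helper against the eager result it computes earlier. One way such a test might assert eager/compiled parity is sketched below; the assert_close call and its tolerances are assumptions for illustration, not taken from this commit.

    import torch

    # Continuing from the test's setup: tp_model, inp, and python_result
    # are defined earlier in the script.
    compiled_model = compile_tp_model(tp_model, backend="torch_tensorrt")
    trt_result = compiled_model(inp)
    torch.testing.assert_close(python_result, trt_result, rtol=1e-2, atol=1e-2)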

tests/py/dynamo/distributed/test_nccl_ops.sh

Lines changed: 9 additions & 3 deletions
@@ -88,11 +88,17 @@ fi
 URL="https://pypi.nvidia.com/tensorrt-llm/$FILE"
 echo "Downloading $FILE from $URL..."
 
-echo "Downloading ...."
+echo "Downloading here...."
 #Installing wget
 ensure_installed wget
-#Downloading the package
-wget "$URL"
+
+#Downloading the file
+filename=$(basename "$URL")
+if [ -f "$filename" ]; then
+    echo "File already exists: $filename"
+else
+    wget "$URL"
+fi
 echo "Download complete: $FILE"
 
 UNZIP_DIR="tensorrt_llm_unzip"
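The shell change makes the download idempotent: wget only runs when the file is not already on disk, so rerunning the test script does not re-fetch the wheel. For comparison, the same guard sketched in Python; download_once is a hypothetical helper, not part of the commit.

    from pathlib import Path
    from urllib.request import urlretrieve

    def download_once(url: str, dest_dir: str = ".") -> Path:
        # Mirror the script's basename "$URL": the local name is the last URL segment.
        target = Path(dest_dir) / url.rsplit("/", 1)[-1]
        if target.exists():
            print(f"File already exists: {target}")
        else:
            urlretrieve(url, target)
        return target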
