
Commit bba4153

backend changes, addressing review comments
1 parent 8319c29 commit bba4153

File tree

9 files changed: +72 −23 lines


examples/distributed_inference/README.md

Lines changed: 18 additions & 0 deletions

@@ -18,15 +18,33 @@ torchrun --nproc_per_node=2 tensor_parallel_llama2.py
 3. Tensor parallel distributed inference using nccl ops plugin

 apt install libmpich-dev
+
 apt install libopenmpi-dev
+
+#For python3.10
+
 pip install tensorrt-llm
+
+For other python versions, you need to load libnvinfer_plugin_tensorrt_llm.so manually. Set its path in the environment variable trtllm_env (export trtllm_env={lib_path}); for example, initialize_distributed_env() already sets this variable. Note that setting it from within the running example will not work, since it must be preset for the converter library to pick it up.
+
 #then pip install the tensorrt and torch version compatible with installed torchTRT
+
 mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py

+#For other python
+
 4. Tensor parallel distributed llama3 inference using nccl ops plugin

 apt install libmpich-dev
+
 apt install libopenmpi-dev
+
+#For python3.10
+
 pip install tensorrt-llm
+
+For other python versions, you need to load libnvinfer_plugin_tensorrt_llm.so
+
 #then pip install the tensorrt and torch version compatible with installed torchTRT
+
 mpirun -n 2 --allow-run-as-root python tensor_parallel_llama3.py
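A minimal sketch of the preloading step the README describes, assuming only that trtllm_env holds the path to libnvinfer_plugin_tensorrt_llm.so (the surrounding script is illustrative, not part of this commit):

    import ctypes
    import os

    # Must be set before the converter library loads, e.g.
    #   export trtllm_env=/path/to/libnvinfer_plugin_tensorrt_llm.so
    plugin_lib = os.environ.get("trtllm_env")
    if plugin_lib is None:
        raise RuntimeError("trtllm_env is not set to the plugin .so path")

    # Load the TensorRT-LLM plugin library into the process so its plugin
    # creators (e.g. AllGather) are visible to TensorRT.
    ctypes.CDLL(plugin_lib)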

Lines changed: 0 additions & 1 deletion

@@ -1,5 +1,4 @@
 accelerate
 transformers
 diffusers
-site
 tensorrt-llm

examples/distributed_inference/tensor_parallel_llama3.py

Lines changed: 1 addition & 0 deletions

@@ -48,6 +48,7 @@
         "use_python_runtime": True,
         "workspace_size": 1 << 33,
         "debug": False,
+        "use_aot_joint_export": False,
     },
     dynamic=False,
 )

examples/distributed_inference/tensor_parallel_nccl_ops.py

Lines changed: 20 additions & 9 deletions

@@ -19,8 +19,8 @@
     dynamo_tensorrt_converter,
 )
 from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
-    custom_fused_all_gather_op,
-    custom_fused_reduce_scatter_op,
+    tensorrt_fused_nccl_all_gather_op,
+    tensorrt_fused_nccl_reduce_scatter_op,
 )
 from torch_tensorrt.dynamo.types import TRTTensor
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
@@ -105,7 +105,7 @@ def register_nccl_ops(logger_file_name):
         f"Plugin Name: {plugin_creator.name}, Namespace: {plugin_creator.plugin_namespace}, Version: {plugin_creator.plugin_version}"
     )

-    @dynamo_tensorrt_converter(custom_fused_all_gather_op)
+    @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op)
     def insert_nccl_gather_op(
         ctx: ConversionContext,
         target: Target,
@@ -118,12 +118,18 @@ def insert_nccl_gather_op(
             "AllGather", "1", "tensorrt_llm"
         )
         assert allgather_plg_creator is not None
-        _world_size = int(os.environ["WORLD_SIZE"])
+        _world_size = os.environ.get("WORLD_SIZE")
+        if _world_size is not None:
+            _world_size = int(_world_size)
+        else:
+            raise RuntimeError(
+                "The WORLD_SIZE env variable is not set in distributed environment"
+            )
         group = list(range(_world_size))
         group = trt.PluginField(
             "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
         )
-        p_dtype = trt.float16
+        p_dtype = trt.float32
         pf_type = trt.PluginField(
             "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
         )
@@ -133,7 +139,7 @@ def insert_nccl_gather_op(
         set_layer_name(layer, target, name)
         return layer.get_output(0)

-    @dynamo_tensorrt_converter(custom_fused_reduce_scatter_op)
+    @dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op)
     def insert_nccl_reduce_scatter_plugin(
         ctx: ConversionContext,
         target: Target,
@@ -151,9 +157,14 @@ def insert_nccl_reduce_scatter_plugin(
         counter = 0
         strategy = AllReduceStrategy.NCCL
         config = AllReduceConfig(0)
-
-        world_size = dist.get_world_size()
-        group = list(range(world_size))
+        _world_size = os.environ.get("WORLD_SIZE")
+        if _world_size is not None:
+            _world_size = int(_world_size)
+        else:
+            raise RuntimeError(
+                "The WORLD_SIZE env variable is not set in distributed environment"
+            )
+        group = list(range(_world_size))
         group = trt.PluginField(
             "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
         )
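The WORLD_SIZE lookup above is now duplicated in both converters; a hypothetical helper (not in this commit) could factor it out:

    import os

    def _require_world_size() -> int:
        """Return WORLD_SIZE as an int, failing loudly when it is unset."""
        world_size = os.environ.get("WORLD_SIZE")
        if world_size is None:
            raise RuntimeError(
                "The WORLD_SIZE env variable is not set in distributed environment"
            )
        return int(world_size)

insert_nccl_gather_op and insert_nccl_reduce_scatter_plugin could then both call _require_world_size() instead of repeating the check.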

examples/distributed_inference/tensor_parallel_simple_example.py

Lines changed: 1 addition & 0 deletions

@@ -75,6 +75,7 @@ def forward(self, x):
         "enabled_precisions": {torch.float32, torch.float16},
         "use_python_runtime": True,
         "min_block_size": 1,
+        "use_aot_joint_export": False,
     },
     dynamic=False,
 )
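For context, a sketch of how such an options dict reaches the backend through torch.compile; the model and inputs are stand-ins mirroring the example above, not part of this diff:

    import torch

    compiled = torch.compile(
        model,  # hypothetical tensor-parallel module
        backend="torch_tensorrt",
        options={
            "use_python_runtime": True,
            "min_block_size": 1,
            # Skip aot_export_joint_simple so the backend is wrapped with
            # aot_autograd, as required for distributed tensors.
            "use_aot_joint_export": False,
        },
        dynamic=False,
    )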

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 0 deletions

@@ -46,6 +46,7 @@
 IMMUTABLE_WEIGHTS = True
 ENABLE_WEIGHT_STREAMING = False
 ENABLE_CROSS_COMPILE_FOR_WINDOWS = False
+USE_AOT_JOINT_EXPORT = True


 def default_device() -> Device:

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 3 additions & 0 deletions

@@ -33,6 +33,7 @@
     STRIP_ENGINE_WEIGHTS,
     TIMING_CACHE_PATH,
     TRUNCATE_DOUBLE,
+    USE_AOT_JOINT_EXPORT,
     USE_EXPLICIT_TYPING,
     USE_FAST_PARTITIONER,
     USE_FP32_ACC,
@@ -91,6 +92,7 @@ class CompilationSettings:
         enable_weight_streaming (bool): Enable weight streaming.
         enable_cross_compile_for_windows (bool): By default this is False means TensorRT engines can only be executed on the same platform where they were built.
             True will enable cross-platform compatibility which allows the engine to be built on Linux and run on Windows
+        use_aot_joint_export (bool): Use aot_export_joint_simple to export the graph; otherwise wrap the backend with aot_autograd, which is required for distributed tensors
     """

     enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -131,6 +133,7 @@ class CompilationSettings:
     immutable_weights: bool = IMMUTABLE_WEIGHTS
     enable_weight_streaming: bool = ENABLE_WEIGHT_STREAMING
     enable_cross_compile_for_windows: bool = ENABLE_CROSS_COMPILE_FOR_WINDOWS
+    use_aot_joint_export: bool = USE_AOT_JOINT_EXPORT

_SETTINGS_TO_BE_ENGINE_INVARIANT = (
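A quick sketch of the resulting dataclass behavior (illustrative only):

    from torch_tensorrt.dynamo import CompilationSettings

    # Defaults to USE_AOT_JOINT_EXPORT (True); override it per compilation.
    settings = CompilationSettings(use_aot_joint_export=False)
    assert settings.use_aot_joint_export is False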

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 24 additions & 9 deletions

@@ -1,11 +1,13 @@
 from __future__ import annotations

+import functools
 import logging
 import unittest
 from typing import Any, Callable, Sequence

 import torch
 import torch._dynamo as td
+from torch._dynamo.backends.common import aot_autograd
 from torch._dynamo.utils import detect_fake_mode
 from torch._functorch.aot_autograd import aot_export_joint_simple
 from torch_tensorrt.dynamo import CompilationSettings
@@ -49,7 +51,19 @@ def aot_torch_tensorrt_aten_backend(
     gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any
 ) -> torch.nn.Module:
     settings, engine_cache = parse_dynamo_kwargs(kwargs)
-    return _pretraced_backend(gm, sample_inputs, settings, engine_cache)
+    if settings.use_aot_joint_export:
+        return _pretraced_backend(gm, sample_inputs, settings, engine_cache)
+    logger.debug("Wrapping the backend with aot_autograd\n")
+    _pretraced_backend_autograd = functools.partial(
+        _pretraced_backend, settings=settings, engine_cache=engine_cache
+    )
+    settings_aot_autograd = {}
+    settings_aot_autograd["decompositions"] = get_decompositions(
+        settings.enable_experimental_decompositions
+    )
+    return aot_autograd(
+        fw_compiler=_pretraced_backend_autograd,
+        decompositions=settings_aot_autograd["decompositions"],
+    )(gm, sample_inputs)


 def _pretraced_backend(
@@ -90,14 +104,15 @@ def _pretraced_backend(
     remove_detach(gm, settings)

     # Invoke AOTAutograd to translate operators to aten
-    gm = aot_export_joint_simple(
-        gm,
-        sample_inputs,
-        trace_joint=False,
-        decompositions=get_decompositions(
-            settings.enable_experimental_decompositions
-        ),
-    )
+    if settings.use_aot_joint_export:
+        gm = aot_export_joint_simple(
+            gm,
+            sample_inputs,
+            trace_joint=False,
+            decompositions=get_decompositions(
+                settings.enable_experimental_decompositions
+            ),
+        )

     logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))

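A conceptual sketch of the aot_autograd wrapping used above: aot_autograd turns a forward-compiler function into a dynamo backend that runs AOTAutograd, applying the given decompositions, before handing the lowered aten graph to fw_compiler (the toy compiler below is a stand-in for _pretraced_backend):

    from torch._dynamo.backends.common import aot_autograd

    def toy_fw_compiler(gm, sample_inputs):
        # A real compiler would lower gm to TensorRT here; returning
        # gm.forward leaves execution in eager mode.
        return gm.forward

    backend = aot_autograd(fw_compiler=toy_fw_compiler, decompositions={})
    # backend(gm, sample_inputs) is usable wherever a dynamo backend is expected.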
py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py

Lines changed: 4 additions & 4 deletions

@@ -12,13 +12,13 @@
 logger = logging.getLogger(__name__)


-def custom_fused_all_gather_op(args0, args1, args2):
+def tensorrt_fused_nccl_all_gather_op(args0, args1, args2):
     return torch.ops._c10d_functional.wait_tensor.default(
         torch.ops._c10d_functional.all_gather_into_tensor.default(args0, args1, args2)
     )


-def custom_fused_reduce_scatter_op(args0, args1, args2, args3):
+def tensorrt_fused_nccl_reduce_scatter_op(args0, args1, args2, args3):
     return torch.ops._c10d_functional.wait_tensor.default(
         torch.ops._c10d_functional.reduce_scatter_tensor.default(
             args0, args1, args2, args3
@@ -44,10 +44,10 @@ def fuse_distributed_ops(
         wait_tensor_node = list(node.users)[0]
         fused_op = None
         if node.target == torch.ops._c10d_functional.all_gather_into_tensor.default:
-            fused_op = custom_fused_all_gather_op
+            fused_op = tensorrt_fused_nccl_all_gather_op
             fused_op_args = (node.args[0], node.args[1], node.args[2])
         else:
-            fused_op = custom_fused_reduce_scatter_op
+            fused_op = tensorrt_fused_nccl_reduce_scatter_op
             fused_op_args = (node.args[0], node.args[1], node.args[2], node.args[3])
         with gm.graph.inserting_after(wait_tensor_node):
             fused_node = gm.graph.create_node(
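Schematically, the pass collapses the two-node collective pattern into a single call node (argument names are illustrative):

    # Before fusion, as produced by the dynamo trace:
    #   y = torch.ops._c10d_functional.all_gather_into_tensor.default(x, group_size, group_name)
    #   out = torch.ops._c10d_functional.wait_tensor.default(y)
    #
    # After fusion, a single node the NCCL converter above can match:
    #   out = tensorrt_fused_nccl_all_gather_op(x, group_size, group_name)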
