Skip to content

Commit a4ee3e2

Browse files
committed
feat: Add sample torch.compile backend for tensorrt aten path
- Add backend adapted from previous `fx2trt_compiler` provided by Dynamo - Currently, the TRTSplitter needs work to fully support the `aten` path - Additionally, the existing `aten` pass was reworked to exclude the `torch._dynamo.export` call, which may be necessary here
1 parent a1d4af0 commit a4ee3e2

File tree

2 files changed

+113
-2
lines changed

2 files changed

+113
-2
lines changed

py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def trace(f, args, *rest):
134134

135135

136136
@req_torch_version("2.dev")
137-
def opt_trace(f, args, *rest):
137+
def opt_trace(f, args, perform_trace=True, *rest):
138138
"""
139139
Optimized trace with necessary passes which re-compose some ops or replace some ops
140140
These passes should be general and functional purpose
@@ -152,7 +152,11 @@ def opt_trace(f, args, *rest):
152152
replace_inplace_ops, # remove it once functionalization is enabled
153153
]
154154

155-
fx_module, _ = trace(f, args)
155+
if perform_trace:
156+
fx_module, _ = trace(f, args)
157+
else:
158+
fx_module = f
159+
156160
print(fx_module.graph)
157161
for passes in passes_list:
158162
pr: PassResult = passes(fx_module)
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import torch
2+
import traceback
3+
import torch._dynamo as td
4+
5+
from torch_tensorrt.fx.fx2trt import (
6+
InputTensorSpec,
7+
TRTInterpreter,
8+
)
9+
import tensorrt as trt
10+
from torch_tensorrt.fx.tools.trt_splitter import (
11+
TRTSplitter,
12+
TRTSplitterSetting,
13+
)
14+
from torch_tensorrt.fx.tracer.dispatch_tracer import aten_tracer
15+
from torch_tensorrt.fx.trt_module import TRTModule
16+
from torch_tensorrt.fx.utils import LowerPrecision
17+
18+
from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
19+
20+
MAX_SPLITS_THRESHOLD = 10
21+
22+
23+
def tensorrt_backend(gm, sample_inputs):
24+
# Invoke AOTAutograd to compile model
25+
return aot_module_simplified(
26+
gm,
27+
sample_inputs,
28+
fw_compiler=make_boxed_compiler(fx2trt_compiler),
29+
)
30+
31+
32+
def fx2trt(gm: torch.fx.GraphModule, example_inputs, **kwargs):
33+
model = gm
34+
inputs = example_inputs
35+
36+
# Perform lowering pass on model
37+
model = aten_tracer.opt_trace(model, inputs, perform_trace=False)
38+
39+
# Split out unsupported ops --> Needs rewrite/revision for ATEN
40+
splitter_setting = TRTSplitterSetting()
41+
splitter_setting.use_implicit_batch_dim = False
42+
splitter = TRTSplitter(model, inputs, settings=splitter_setting)
43+
44+
splitter.node_support_preview()
45+
split_mod = splitter()
46+
num_piece = 0
47+
48+
for name, _ in split_mod.named_children():
49+
print(f"Graph is split into {name}")
50+
num_pieces += 1
51+
52+
# Select threshold above which segmentation is not beneficial and run graph in Torch
53+
if num_pieces > MAX_SPLITS_THRESHOLD:
54+
raise AssertionError(
55+
f"The graph module is split into {num_piece} which is large than the \
56+
threshold={MAX_SPLITS_THRESHOLD}. Falling back to non-TRT module."
57+
)
58+
59+
precision = LowerPrecision.FP32
60+
61+
def get_submod_inputs(mod, submod, inputs):
62+
acc_inputs = None
63+
64+
def get_input(self, inputs):
65+
nonlocal acc_inputs
66+
acc_inputs = inputs
67+
68+
handle = submod.register_forward_pre_hook(get_input)
69+
mod(*inputs)
70+
handle.remove()
71+
return acc_inputs
72+
73+
for name, _ in split_mod.named_children():
74+
if "_run_on_acc" in name:
75+
submod = getattr(split_mod, name)
76+
acc_inputs = get_submod_inputs(split_mod, submod, inputs)
77+
78+
interp = TRTInterpreter(
79+
submod,
80+
InputTensorSpec.from_tensors(acc_inputs),
81+
explicit_batch_dimension=True,
82+
logger_level=trt.Logger.VERBOSE,
83+
)
84+
r = interp.run(
85+
max_workspace_size=20 << 30,
86+
lower_precision=precision,
87+
profiling_verbosity=trt.ProfilingVerbosity.VERBOSE,
88+
)
89+
90+
trt_mod = TRTModule(*r)
91+
92+
setattr(split_mod, name, trt_mod)
93+
94+
return split_mod
95+
96+
97+
@td.register_backend
98+
def fx2trt_compiler(gm: torch.fx.GraphModule, example_inputs):
99+
try:
100+
trt_compiled = fx2trt(gm, example_inputs)
101+
return trt_compiled
102+
except Exception:
103+
traceback.print_exc()
104+
print(
105+
"FX2TRT conversion failed on the subgraph. See trace above. Returning GraphModule forward instead"
106+
)
107+
return gm.forward

0 commit comments

Comments
 (0)