Skip to content

Commit 09cd47a

Browse files
committed
feat: Add draft Dynamo backend based on PR #1751
1 parent 5a45f6b commit 09cd47a

File tree

1 file changed

+89
-0
lines changed

1 file changed

+89
-0
lines changed
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import torch
2+
import traceback
3+
import torch._dynamo as td
4+
5+
from torch_tensorrt.fx.fx2trt import (
6+
InputTensorSpec,
7+
TRTInterpreter,
8+
)
9+
import tensorrt as trt
10+
from torch_tensorrt.fx.tracer.dispatch_tracer import aten_tracer
11+
from torch_tensorrt.fx.trt_module import TRTModule
12+
from torch_tensorrt.fx.utils import LowerPrecision
13+
14+
from torch._dynamo.backends.common import fake_tensor_unsupported
15+
16+
from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
17+
18+
from torch._inductor.decomposition import decompositions
19+
20+
21+
def partition(gm: torch.fx.GraphModule):
22+
pass
23+
24+
25+
DECOMPOSITIONS = decompositions.copy()
26+
27+
28+
def tensorrt_backend(gm, sample_inputs):
29+
# Invoke AOTAutograd to compile model
30+
return aot_module_simplified(
31+
gm,
32+
sample_inputs,
33+
fw_compiler=make_boxed_compiler(fx2trt_compiler),
34+
decompositions=DECOMPOSITIONS,
35+
)
36+
37+
38+
def fx2trt(model: torch.fx.GraphModule, inputs, **kwargs):
39+
partitioned_model = partition(model)
40+
41+
precision = LowerPrecision.FP32
42+
43+
def get_submod_inputs(mod, submod, inputs):
44+
acc_inputs = None
45+
46+
def get_input(self, inputs):
47+
nonlocal acc_inputs
48+
acc_inputs = inputs
49+
50+
handle = submod.register_forward_pre_hook(get_input)
51+
mod(*inputs)
52+
handle.remove()
53+
return acc_inputs
54+
55+
for name, _ in partitioned_model.named_children():
56+
submod = getattr(partitioned_model, name)
57+
acc_inputs = get_submod_inputs(partitioned_model, submod, inputs)
58+
59+
interp = TRTInterpreter(
60+
submod,
61+
InputTensorSpec.from_tensors(acc_inputs),
62+
explicit_batch_dimension=True,
63+
logger_level=trt.Logger.VERBOSE,
64+
)
65+
r = interp.run(
66+
max_workspace_size=20 << 30,
67+
lower_precision=precision,
68+
profiling_verbosity=trt.ProfilingVerbosity.VERBOSE,
69+
)
70+
71+
trt_mod = TRTModule(*r)
72+
73+
setattr(partitioned_model, name, trt_mod)
74+
75+
return partitioned_model
76+
77+
78+
@td.register_backend
79+
@fake_tensor_unsupported
80+
def fx2trt_compiler(gm: torch.fx.GraphModule, example_inputs):
81+
try:
82+
trt_compiled = fx2trt(gm, example_inputs)
83+
return trt_compiled
84+
except Exception:
85+
traceback.print_exc()
86+
print(
87+
"FX2TRT conversion failed on the subgraph. See trace above. Returning GraphModule forward instead"
88+
)
89+
return gm.forward

0 commit comments

Comments
 (0)