Commit 2479300

feat: Add new convert_module function
- Improve overall documentation and commenting, improve code delineation and separation of functionality
1 parent 0c5befd commit 2479300

5 files changed: +155 additions, −69 deletions

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 6 additions & 59 deletions
@@ -1,26 +1,17 @@
 import torch
 import logging
+from typing import Sequence, Any
 
 from torch_tensorrt import EngineCapability, Device
 
-from torch_tensorrt.dynamo.lowering._partition import partition
 from torch_tensorrt.dynamo import create_backend
 
-from torch_tensorrt.fx.fx2trt import (
-    InputTensorSpec,
-    TRTInterpreter,
-)
-import tensorrt as trt
-
-from torch_tensorrt.fx.trt_module import TRTModule
-from torch_tensorrt.fx.utils import LowerPrecision
-
 logger = logging.getLogger(__name__)
 
 
 def compile(
     gm: torch.Module,
-    example_inputs,
+    example_inputs: Sequence[Any],
     *,
     device=Device._current_device(),
     disable_tf32=False,
@@ -30,7 +21,7 @@ def compile(
     debug=False,
     capability=EngineCapability.default,
     num_avg_timing_iters=1,
-    workspace_size=0,
+    workspace_size=20 << 30,
     dla_sram_size=1048576,
     dla_local_dram_size=1073741824,
     dla_global_dram_size=536870912,
@@ -63,52 +54,8 @@ def compile(
     )
 
     model = torch.compile(gm, backend=custom_backend)
-    # Ensure compilation
-    model(example_inputs)
-
-    return model
-
-
-def compile_logic(gm: torch.fx.GraphModule, example_inputs):
-    partitioned = partition(gm)
-
-    precision = LowerPrecision.FP32
-
-    def get_submod_inputs(mod, submod, inputs):
-        """Helper function to get inputs to submodule"""
-        acc_inputs = None
 
-        def get_input(self, inputs):
-            nonlocal acc_inputs
-            acc_inputs = inputs
+    # Ensure compilation occurs by calling the function with provided inputs
+    model(*example_inputs)
 
-        handle = submod.register_forward_pre_hook(get_input)
-        mod(*inputs)
-        handle.remove()
-        return acc_inputs
-
-    for name, _ in partitioned.named_children():
-        submod = getattr(partitioned, name)
-
-        # Get submodule inputs
-        acc_inputs = get_submod_inputs(partitioned, submod, example_inputs)
-
-        # Create TRT Module from submodule
-        interp = TRTInterpreter(
-            submod,
-            InputTensorSpec.from_tensors(acc_inputs),
-            explicit_batch_dimension=True,
-            logger_level=trt.Logger.VERBOSE,
-        )
-
-        r = interp.run(
-            max_workspace_size=20 << 30,
-            lower_precision=precision,
-            profiling_verbosity=trt.ProfilingVerbosity.VERBOSE,
-        )
-        trt_mod = TRTModule(*r)
-
-        # Replace FX Module with TRT Module
-        setattr(partitioned, name, trt_mod)
-
-    return partitioned
+    return model
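
Two behavioral fixes stand out here: the default workspace_size rises from 0 to 20 << 30 (20 GiB), and the warm-up call now unpacks the input sequence (model(*example_inputs)) so each tensor is passed positionally rather than as one list argument. A minimal usage sketch of the entrypoint, assuming a CUDA device; the toy model and shapes are illustrative and not part of the commit:

import torch
import torch_tensorrt

# Hypothetical toy model for illustration only
model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU()).eval().cuda()

# A Sequence[Any] of sample tensors, matching the new annotation
example_inputs = [torch.randn(8, 16).cuda()]

# compile() wraps torch.compile with the Torch-TensorRT dynamo backend; the
# warm-up call model(*example_inputs) inside compile() triggers tracing
trt_model = torch_tensorrt.dynamo.compile(model, example_inputs)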

py/torch_tensorrt/dynamo/backends.py

Lines changed: 58 additions & 2 deletions
@@ -6,11 +6,22 @@
 from torch_tensorrt import EngineCapability, Device
 from torch_tensorrt.dynamo import compile
 
+from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
+from torch_tensorrt.dynamo.lowering._partition import partition, get_submod_inputs
+from torch_tensorrt.dynamo.conversion import convert_module
+
 from torch._dynamo.backends.common import fake_tensor_unsupported
 
 from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
 
-from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
+from torch_tensorrt.fx.fx2trt import (
+    InputTensorSpec,
+    TRTInterpreter,
+)
+import tensorrt as trt
+
+from torch_tensorrt.fx.trt_module import TRTModule
+from torch_tensorrt.fx.utils import LowerPrecision
 
 logger = logging.getLogger(__name__)
 
@@ -97,7 +108,7 @@ def fx_dynamo_backend(
 ):
     """Helper function to manage translation of FX module to TRT engines"""
     try:
-        trt_compiled = compile(gm, example_inputs)
+        trt_compiled = compile_module(gm, example_inputs)
         return trt_compiled
     except:
         traceback.print_exc()
@@ -106,3 +117,48 @@ def fx_dynamo_backend(
             + "Returning GraphModule forward instead."
         )
         return gm.forward
+
+
+def compile_module(
+    gm: torch.fx.GraphModule,
+    example_inputs,
+    debug: bool = False,
+    workspace_size: int = 20 << 30,
+    precision: LowerPrecision = LowerPrecision.FP32,
+) -> torch.fx.GraphModule:
+    """Convert an FX module to a TRT module
+    Args:
+        module: FX GraphModule to convert
+        inputs: Inputs to the module
+        debug: Whether to print out verbose debugging information
+        workspace_size: Maximum workspace TRT is allowed to use for the module
+        precision: Model Layer precision
+    Returns:
+        TRTModule or TRTModuleNext
+    """
+    # Partition module into components that can be TRT-accelerated
+    partitioned_module = partition(gm)
+
+    # Iterate over all components that can be accelerated
+    # Generate the corresponding TRT Module for those
+    for name, _ in partitioned_module.named_children():
+        submodule = getattr(partitioned_module, name)
+
+        # Get submodule inputs
+        submodule_inputs = get_submod_inputs(
+            partitioned_module, submodule, example_inputs
+        )
+
+        # Create TRT Module from submodule
+        trt_mod = convert_module(
+            submodule,
+            submodule_inputs,
+            debug=debug,
+            workspace_size=workspace_size,
+            precision=precision,
+        )
+
+        # Replace FX Module with TRT Module
+        setattr(partitioned_module, name, trt_mod)
+
+    return partitioned_module
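
compile_module is the new orchestration layer: partition() splits the graph by converter support, get_submod_inputs() records sample inputs for each accelerable child, convert_module() builds a TRT engine for it, and setattr() swaps the engine back into the parent graph. A hedged usage sketch; the symbolically traced toy module below stands in for the aten-level GraphModule the dynamo backend would normally supply:

import torch
from torch_tensorrt.dynamo.backends import compile_module

class Toy(torch.nn.Module):  # illustrative module, not from the commit
    def forward(self, x):
        return torch.relu(x) + 1.0

gm = torch.fx.symbolic_trace(Toy())
example_inputs = [torch.randn(4, 8).cuda()]

# Partitions gm, converts supported submodules to TRT engines, and returns
# the stitched-together GraphModule
trt_gm = compile_module(gm, example_inputs, debug=False)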
py/torch_tensorrt/dynamo/conversion.py (new file; path inferred from the torch_tensorrt.dynamo.conversion import above)

Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
+from typing import Sequence, Union
+import torch
+from torch_tensorrt.fx.trt_module import TRTModule
+from torch_tensorrt import TRTModuleNext
+from torch_tensorrt.fx.fx2trt import (
+    InputTensorSpec,
+    TRTInterpreter,
+)
+from torch_tensorrt.fx.utils import LowerPrecision
+
+import tensorrt as trt
+
+
+def convert_module(
+    module: torch.fx.GraphModule,
+    inputs: Sequence[torch.Tensor],
+    debug: bool = False,
+    workspace_size: int = 20 << 30,
+    precision: LowerPrecision = LowerPrecision.FP32,
+) -> Union[TRTModuleNext, TRTModule]:
+    """Convert an FX module to a TRT module
+    Args:
+        module: FX GraphModule to convert
+        inputs: Sequence of Tensors representing inputs to the module
+        debug: Whether to print out verbose debugging information
+        workspace_size: Maximum workspace TRT is allowed to use for the module
+        precision: Model Layer precision
+    Returns:
+        TRTModule or TRTModuleNext
+    """
+    interp = TRTInterpreter(
+        module,
+        InputTensorSpec.from_tensors(inputs),
+        explicit_batch_dimension=True,
+        logger_level=(trt.Logger.VERBOSE if debug else trt.Logger.WARNING),
+    )
+
+    r = interp.run(
+        max_workspace_size=workspace_size,
+        lower_precision=precision,
+        profiling_verbosity=(
+            trt.ProfilingVerbosity.VERBOSE
+            if debug
+            else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
+        ),
+    )
+
+    return TRTModule(*r)
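
convert_module keys both TensorRT knobs off a single debug flag: logger_level toggles between VERBOSE and WARNING, and profiling_verbosity between VERBOSE and LAYER_NAMES_ONLY. A minimal sketch of calling it directly on one submodule; the traced ReLU, input shape, and workspace cap are illustrative assumptions:

import torch
from torch_tensorrt.dynamo.conversion import convert_module

# Illustrative single submodule, standing in for one produced by partition()
submodule = torch.fx.symbolic_trace(torch.nn.ReLU())
inputs = [torch.randn(2, 4).cuda()]

# debug=True selects trt.Logger.VERBOSE and ProfilingVerbosity.VERBOSE
trt_mod = convert_module(submodule, inputs, debug=True, workspace_size=1 << 30)
out = trt_mod(*inputs)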
py/torch_tensorrt/dynamo/lowering/__init__.py (path inferred from the re-exported contents)

Lines changed: 1 addition & 1 deletion

@@ -1,2 +1,2 @@
 from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
-from torch_tensorrt.dynamo.lowering._partition import partition
+from torch_tensorrt.dynamo.lowering._partition import partition, get_submod_inputs

py/torch_tensorrt/dynamo/lowering/_partition.py

Lines changed: 42 additions & 7 deletions
@@ -1,4 +1,4 @@
-from typing import Dict
+from typing import Dict, Optional, Sequence
 
 import torch
 
@@ -12,7 +12,7 @@
 
 
 class TorchTensorRTOperatorSupport(OperatorSupport):
-    """Class to determine whether the aten operators have converters"""
+    """Class to determine whether operators within a module are supported"""
 
     def __init__(self, support_dict=None):
         super().__init__(support_dict)
@@ -38,7 +38,7 @@ def is_node_supported(
 
         return False
 
-    def print_support_overview(self, num_trt_blocks=None):
+    def print_support_overview(self, num_trt_blocks: Optional[int] = None):
         if num_trt_blocks is not None:
             print(f"Number of TensorRT-Accelerated Subgraphs: {num_trt_blocks}\n")
 
@@ -51,9 +51,20 @@ def print_support_overview(self, num_trt_blocks=None):
             print(node_name)
 
 
-def partition(gm: torch.fx.GraphModule, verbose=True):
+def partition(
+    gm: torch.fx.GraphModule,
+    verbose: bool = True,
+    max_num_trt_engines: int = MAX_NUM_TRT_ENGINES,
+) -> torch.fx.GraphModule:
     """Partition an FX GraphModule with aten ops into TRT engines
-    Partitioning is based on operator support
+    Partitioning is based on converter operator support
+
+    Args:
+        gm: FX GraphModule to partition
+        verbose: Bool representing whether to print operator support
+        max_num_trt_engines: Maximum number of allowed TRT engines in partitioning
+    Returns:
+        torch.fx.GraphModule
     """
     supported_ops = TorchTensorRTOperatorSupport()
     partitioner = CapabilityBasedPartitioner(gm, supported_ops)
@@ -62,10 +73,10 @@ def partition(gm: torch.fx.GraphModule, verbose=True):
     # exceeds a specified threshold
     partitions = partitioner.propose_partitions()
     num_blocks = len(partitions)
-    if num_blocks > MAX_NUM_TRT_ENGINES:
+    if num_blocks > max_num_trt_engines:
        raise AssertionError(
             f"The graph module has {num_blocks} TRT Engines which is larger than the "
-            + f"threshold={MAX_NUM_TRT_ENGINES}. Falling back to non-TRT module."
+            + f"threshold={max_num_trt_engines}. Falling back to non-TRT module."
         )
 
     # Fuse partitions and display overview of supported/unsupported operators
@@ -76,3 +87,27 @@ def partition(gm: torch.fx.GraphModule, verbose=True):
         supported_ops.print_support_overview(num_blocks)
 
     return fused_graph
+
+
+def get_submod_inputs(
+    mod: torch.fx.GraphModule, submod: torch.fx.GraphModule, inputs
+) -> Sequence[torch.Tensor]:
+    """Helper function to get inputs to a Torch submodule
+
+    Args:
+        mod: Parent FX GraphModule
+        submod: Child FX GraphModule
+        inputs: Sample inputs to parent module
+    Returns:
+        Sequence of Tensors representing inputs to child module
+    """
+    acc_inputs = None
+
+    def get_input(self, inputs):
+        nonlocal acc_inputs
+        acc_inputs = inputs
+
+    handle = submod.register_forward_pre_hook(get_input)
+    mod(*inputs)
+    handle.remove()
+    return acc_inputs
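
get_submod_inputs captures a child module's runtime inputs by registering a temporary forward pre-hook, running the parent once on the sample inputs, and removing the hook. A short demonstration using the re-exported helper; the toy parent/child modules are assumptions, not part of the commit:

import torch
from torch_tensorrt.dynamo.lowering import get_submod_inputs

# Illustrative parent module with a named child
parent = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
child = parent[1]  # the submodule whose inputs we want

# One forward pass of the parent fires the temporary pre-hook on the child
sub_inputs = get_submod_inputs(parent, child, [torch.randn(2, 4)])
print(sub_inputs[0].shape)  # torch.Size([2, 4]): the ReLU's actual input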
