Skip to content

Commit eea3884

Browse files
committed
fix: Improve torch_tensorrt Dynamo path
- Add dedicated settings and defaults files to centralize data, improve code readability, and reduce code duplication
- Improve documentation of functions, types, and comments
- Rework logic to make the compiler more uniform with the existing torch_tensorrt compilers, while retaining key Dynamo keywords needed for compilation via the torch.compile path
1 parent 2479300 commit eea3884

File tree

5 files changed

+114
-90
lines changed

5 files changed

+114
-90
lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
11
import torch
22
import logging
3-
from typing import Sequence, Any
3+
import torch_tensorrt
44

5+
from typing import Sequence, Any
56
from torch_tensorrt import EngineCapability, Device
6-
77
from torch_tensorrt.dynamo import create_backend
8+
from torch_tensorrt.fx.utils import LowerPrecision
9+
10+
from torch_tensorrt.dynamo._defaults import (
11+
PRECISION,
12+
DEBUG,
13+
MAX_WORKSPACE_SIZE,
14+
)
15+
816

917
logger = logging.getLogger(__name__)
1018

@@ -18,10 +26,10 @@ def compile(
1826
sparse_weights=False,
1927
enabled_precisions=set(),
2028
refit=False,
21-
debug=False,
29+
debug=DEBUG,
2230
capability=EngineCapability.default,
2331
num_avg_timing_iters=1,
24-
workspace_size=20 << 30,
32+
workspace_size=MAX_WORKSPACE_SIZE,
2533
dla_sram_size=1048576,
2634
dla_local_dram_size=1073741824,
2735
dla_global_dram_size=536870912,
@@ -31,26 +39,38 @@ def compile(
3139
min_block_size=3,
3240
torch_executed_ops=[],
3341
torch_executed_modules=[],
42+
**kwargs,
3443
):
44+
45+
logger.warn(
46+
"The Dynamo backend is an experimental feature, for which only the "
47+
+ "following arguments are supported: "
48+
+ "{enabled_precisions, debug, workspace_size, max_num_trt_engines}"
49+
)
50+
51+
if (
52+
torch.float16 in enabled_precisions
53+
or torch_tensorrt.dtype.half in enabled_precisions
54+
):
55+
lower_precision = LowerPrecision.FP16
56+
elif (
57+
torch.float32 in enabled_precisions
58+
or torch_tensorrt.dtype.float in enabled_precisions
59+
):
60+
lower_precision = LowerPrecision.FP32
61+
elif len(enabled_precisions) == 0:
62+
logger.info(f"No precision specified, defaulting to {PRECISION}")
63+
lower_precision = PRECISION
64+
else:
65+
raise ValueError(
66+
f"Precision {enabled_precisions} not supported in the Dynamo Path"
67+
)
68+
3569
custom_backend = create_backend(
36-
device=device,
37-
disable_tf32=disable_tf32,
38-
sparse_weights=sparse_weights,
39-
enabled_precisions=enabled_precisions,
40-
refit=refit,
70+
precision=lower_precision,
4171
debug=debug,
42-
capability=capability,
43-
num_avg_timing_iters=num_avg_timing_iters,
4472
workspace_size=workspace_size,
45-
dla_sram_size=dla_sram_size,
46-
dla_local_dram_size=dla_local_dram_size,
47-
dla_global_dram_size=dla_global_dram_size,
48-
calibrator=calibrator,
49-
truncate_long_and_double=truncate_long_and_double,
50-
require_full_compilation=require_full_compilation,
51-
min_block_size=min_block_size,
52-
torch_executed_ops=torch_executed_ops,
53-
torch_executed_modules=torch_executed_modules,
73+
**kwargs,
5474
)
5575

5676
model = torch.compile(gm, backend=custom_backend)

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from torch_tensorrt.fx.utils import LowerPrecision
2+
3+
4+
# Precision used when the caller does not specify one.
PRECISION = LowerPrecision.FP32
5+
# Verbose debugging output is disabled by default.
DEBUG = False
6+
# 20 << 30 == 20 * 2**30; presumably a byte count (20 GiB) - TODO confirm units.
MAX_WORKSPACE_SIZE = 20 << 30
7+
# Default value for the max_num_trt_engines setting.
MAX_NUM_TRT_ENGINES = 10

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from dataclasses import dataclass
2+
3+
from torch_tensorrt.fx.utils import LowerPrecision
4+
from torch_tensorrt.dynamo._defaults import (
5+
PRECISION,
6+
DEBUG,
7+
MAX_WORKSPACE_SIZE,
8+
MAX_NUM_TRT_ENGINES,
9+
)
10+
11+
12+
@dataclass(frozen=True)
class CompilationSettings:
    """Immutable bundle of options for the Dynamo compilation path.

    Each field defaults to the matching constant imported from
    ``torch_tensorrt.dynamo._defaults``. The defaults are the bare constants:
    a trailing comma after a default (``= (PRECISION,)``) would turn it into a
    one-element tuple, which is always truthy and has the wrong type.
    """

    # Layer precision handed to the conversion step.
    precision: LowerPrecision = PRECISION
    # Whether verbose debugging information is printed.
    debug: bool = DEBUG
    # Maximum workspace TRT is allowed to use for the module.
    workspace_size: int = MAX_WORKSPACE_SIZE
    # Value forwarded to the partitioner as ``max_num_trt_engines``.
    max_num_trt_engines: int = MAX_NUM_TRT_ENGINES

py/torch_tensorrt/dynamo/backends.py

Lines changed: 49 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,14 @@
33
import traceback
44
from functools import partial
55
import torch._dynamo as td
6-
from torch_tensorrt import EngineCapability, Device
7-
from torch_tensorrt.dynamo import compile
86

7+
from torch_tensorrt.dynamo._defaults import (
8+
PRECISION,
9+
DEBUG,
10+
MAX_WORKSPACE_SIZE,
11+
MAX_NUM_TRT_ENGINES,
12+
)
13+
from torch_tensorrt.dynamo._settings import CompilationSettings
914
from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
1015
from torch_tensorrt.dynamo.lowering._partition import partition, get_submod_inputs
1116
from torch_tensorrt.dynamo.conversion import convert_module
@@ -14,55 +19,38 @@
1419

1520
from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
1621

17-
from torch_tensorrt.fx.fx2trt import (
18-
InputTensorSpec,
19-
TRTInterpreter,
20-
)
21-
import tensorrt as trt
22-
23-
from torch_tensorrt.fx.trt_module import TRTModule
2422
from torch_tensorrt.fx.utils import LowerPrecision
2523

2624
logger = logging.getLogger(__name__)
2725

2826

2927
def create_backend(
    precision: LowerPrecision = PRECISION,
    debug: bool = DEBUG,
    workspace_size: int = MAX_WORKSPACE_SIZE,
    max_num_trt_engines: int = MAX_NUM_TRT_ENGINES,
    **kwargs
):
    """Create torch.compile backend given specified arguments

    Args:
        precision: Model Layer precision
        debug: Whether to print out verbose debugging information
        workspace_size: Maximum workspace TRT is allowed to use for the module
        max_num_trt_engines: Value stored in the compilation settings and
            later passed to the partitioner under the same name
        **kwargs: Extra keyword arguments, accepted so callers can forward
            additional options; they have no effect on the backend
    Returns:
        Backend for torch.compile
    """
    if kwargs:
        # These options are accepted for signature compatibility only.
        # Report them rather than dropping them without a trace.
        logger.warning(
            "Ignoring unsupported Dynamo backend arguments: %s", sorted(kwargs)
        )

    settings = CompilationSettings(
        debug=debug,
        precision=precision,
        workspace_size=workspace_size,
        max_num_trt_engines=max_num_trt_engines,
    )

    # Bind the settings now; torch.compile supplies the graph module and
    # sample inputs when it invokes the backend.
    return partial(
        tensorrt_backend,
        settings=settings,
    )
6755

6856

@@ -71,19 +59,12 @@ def create_backend(
7159
def tensorrt_backend(
7260
gm: torch.Module,
7361
sample_inputs,
74-
*,
75-
debug=False,
76-
enabled_precisions=set(),
77-
device=Device._current_device(),
78-
workspace_size=20 << 30,
62+
settings: CompilationSettings = CompilationSettings(),
7963
):
8064

8165
custom_backend = partial(
8266
fx_dynamo_backend,
83-
debug=debug,
84-
enabled_precisions=enabled_precisions,
85-
device=device,
86-
workspace_size=workspace_size,
67+
settings=settings,
8768
)
8869

8970
# Invoke AOTAutograd to translate operators to aten
@@ -100,15 +81,15 @@ def tensorrt_backend(
10081
def fx_dynamo_backend(
10182
gm: torch.fx.GraphModule,
10283
example_inputs,
103-
*,
104-
debug=False,
105-
enabled_precisions=set(),
106-
device=Device._current_device(),
107-
workspace_size=20 << 30,
84+
settings: CompilationSettings = CompilationSettings(),
10885
):
10986
"""Helper function to manage translation of FX module to TRT engines"""
11087
try:
111-
trt_compiled = compile_module(gm, example_inputs)
88+
trt_compiled = compile_module(
89+
gm,
90+
example_inputs,
91+
settings=settings,
92+
)
11293
return trt_compiled
11394
except:
11495
traceback.print_exc()
@@ -122,22 +103,23 @@ def fx_dynamo_backend(
122103
def compile_module(
123104
gm: torch.fx.GraphModule,
124105
example_inputs,
125-
debug: bool = False,
126-
workspace_size: int = 20 << 30,
127-
precision: LowerPrecision = LowerPrecision.FP32,
106+
settings: CompilationSettings = CompilationSettings(),
128107
) -> torch.fx.GraphModule:
129-
"""Convert an FX module to a TRT module
108+
"""Compile an FX module
109+
110+
Includes: Partitioning + Conversion Phases
111+
130112
Args:
131113
module: FX GraphModule to convert
132114
inputs: Inputs to the module
133-
debug: Whether to print out verbose debugging information
134-
workspace_size: Maximum workspace TRT is allowed to use for the module
135-
precision: Model Layer precision
115+
settings: Compilation settings
136116
Returns:
137-
TRTModule or TRTModuleNext
117+
Compiled FX GraphModule
138118
"""
139119
# Partition module into components that can be TRT-accelerated
140-
partitioned_module = partition(gm)
120+
partitioned_module = partition(
121+
gm, verbose=settings.debug, max_num_trt_engines=settings.max_num_trt_engines
122+
)
141123

142124
# Iterate over all components that can be accelerated
143125
# Generate the corresponding TRT Module for those
@@ -153,9 +135,9 @@ def compile_module(
153135
trt_mod = convert_module(
154136
submodule,
155137
submodule_inputs,
156-
debug=debug,
157-
workspace_size=workspace_size,
158-
precision=precision,
138+
debug=settings.debug,
139+
workspace_size=settings.workspace_size,
140+
precision=settings.precision,
159141
)
160142

161143
# Replace FX Module with TRT Module

py/torch_tensorrt/dynamo/lowering/_partition.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,13 @@
22

33
import torch
44

5+
from torch_tensorrt.dynamo._defaults import MAX_NUM_TRT_ENGINES
56
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
67
from torch.fx.passes.operator_support import OperatorSupport
78

89
from torch_tensorrt.fx.converter_registry import CONVERTERS
910

1011

11-
MAX_NUM_TRT_ENGINES = 10
12-
13-
1412
class TorchTensorRTOperatorSupport(OperatorSupport):
1513
"""Class to determine whether operators within a module are supported"""
1614

0 commit comments

Comments
 (0)