3
3
import traceback
4
4
from functools import partial
5
5
import torch ._dynamo as td
6
- from torch_tensorrt import EngineCapability , Device
7
- from torch_tensorrt .dynamo import compile
8
6
7
+ from torch_tensorrt .dynamo ._defaults import (
8
+ PRECISION ,
9
+ DEBUG ,
10
+ MAX_WORKSPACE_SIZE ,
11
+ MAX_NUM_TRT_ENGINES ,
12
+ )
13
+ from torch_tensorrt .dynamo ._settings import CompilationSettings
9
14
from torch_tensorrt .dynamo .lowering ._decompositions import get_decompositions
10
15
from torch_tensorrt .dynamo .lowering ._partition import partition , get_submod_inputs
11
16
from torch_tensorrt .dynamo .conversion import convert_module
14
19
15
20
from torch ._functorch .aot_autograd import aot_module_simplified , make_boxed_compiler
16
21
17
- from torch_tensorrt .fx .fx2trt import (
18
- InputTensorSpec ,
19
- TRTInterpreter ,
20
- )
21
- import tensorrt as trt
22
-
23
- from torch_tensorrt .fx .trt_module import TRTModule
24
22
from torch_tensorrt .fx .utils import LowerPrecision
25
23
26
24
# Module-level logger, namespaced to this module (standard library logging convention).
logger = logging .getLogger (__name__ )
27
25
28
26
29
27
def create_backend(
    precision: LowerPrecision = PRECISION,
    debug: bool = DEBUG,
    workspace_size: int = MAX_WORKSPACE_SIZE,
    max_num_trt_engines: int = MAX_NUM_TRT_ENGINES,
    **kwargs,
):
    """Create a torch.compile backend configured with the given arguments.

    Args:
        precision: Model layer precision (a ``LowerPrecision`` value)
        debug: Whether to print out verbose debugging information
        workspace_size: Maximum workspace TRT is allowed to use for the module
        max_num_trt_engines: Cap on the number of TRT engines produced during
            partitioning (forwarded to the partitioner via CompilationSettings)
        **kwargs: Additional arguments are accepted but currently ignored
            # NOTE(review): silently dropping kwargs may hide caller typos — confirm intended
    Returns:
        Backend callable suitable for torch.compile
    """
    # Bundle the user-facing options into a single CompilationSettings object so
    # the backend chain threads one argument instead of a growing parameter list.
    settings = CompilationSettings(
        debug=debug,
        precision=precision,
        workspace_size=workspace_size,
        max_num_trt_engines=max_num_trt_engines,
    )

    # Bind the settings onto the real backend entry point; torch.compile will
    # invoke the returned partial with (gm, sample_inputs).
    return partial(
        tensorrt_backend,
        settings=settings,
    )
67
55
68
56
@@ -71,19 +59,12 @@ def create_backend(
71
59
def tensorrt_backend (
72
60
gm : torch .Module ,
73
61
sample_inputs ,
74
- * ,
75
- debug = False ,
76
- enabled_precisions = set (),
77
- device = Device ._current_device (),
78
- workspace_size = 20 << 30 ,
62
+ settings : CompilationSettings = CompilationSettings (),
79
63
):
80
64
81
65
custom_backend = partial (
82
66
fx_dynamo_backend ,
83
- debug = debug ,
84
- enabled_precisions = enabled_precisions ,
85
- device = device ,
86
- workspace_size = workspace_size ,
67
+ settings = settings ,
87
68
)
88
69
89
70
# Invoke AOTAutograd to translate operators to aten
@@ -100,15 +81,15 @@ def tensorrt_backend(
100
81
def fx_dynamo_backend (
101
82
gm : torch .fx .GraphModule ,
102
83
example_inputs ,
103
- * ,
104
- debug = False ,
105
- enabled_precisions = set (),
106
- device = Device ._current_device (),
107
- workspace_size = 20 << 30 ,
84
+ settings : CompilationSettings = CompilationSettings (),
108
85
):
109
86
"""Helper function to manage translation of FX module to TRT engines"""
110
87
try :
111
- trt_compiled = compile_module (gm , example_inputs )
88
+ trt_compiled = compile_module (
89
+ gm ,
90
+ example_inputs ,
91
+ settings = settings ,
92
+ )
112
93
return trt_compiled
113
94
except :
114
95
traceback .print_exc ()
@@ -122,22 +103,23 @@ def fx_dynamo_backend(
122
103
def compile_module (
123
104
gm : torch .fx .GraphModule ,
124
105
example_inputs ,
125
- debug : bool = False ,
126
- workspace_size : int = 20 << 30 ,
127
- precision : LowerPrecision = LowerPrecision .FP32 ,
106
+ settings : CompilationSettings = CompilationSettings (),
128
107
) -> torch .fx .GraphModule :
129
- """Convert an FX module to a TRT module
108
+ """Compile an FX module
109
+
110
+ Includes: Partitioning + Conversion Phases
111
+
130
112
Args:
131
113
module: FX GraphModule to convert
132
114
inputs: Inputs to the module
133
- debug: Whether to print out verbose debugging information
134
- workspace_size: Maximum workspace TRT is allowed to use for the module
135
- precision: Model Layer precision
115
+ settings: Compilation settings
136
116
Returns:
137
- TRTModule or TRTModuleNext
117
+ Compiled FX GraphModule
138
118
"""
139
119
# Partition module into components that can be TRT-accelerated
140
- partitioned_module = partition (gm )
120
+ partitioned_module = partition (
121
+ gm , verbose = settings .debug , max_num_trt_engines = settings .max_num_trt_engines
122
+ )
141
123
142
124
# Iterate over all components that can be accelerated
143
125
# Generate the corresponding TRT Module for those
@@ -153,9 +135,9 @@ def compile_module(
153
135
trt_mod = convert_module (
154
136
submodule ,
155
137
submodule_inputs ,
156
- debug = debug ,
157
- workspace_size = workspace_size ,
158
- precision = precision ,
138
+ debug = settings . debug ,
139
+ workspace_size = settings . workspace_size ,
140
+ precision = settings . precision ,
159
141
)
160
142
161
143
# Replace FX Module with TRT Module
0 commit comments