1
1
from __future__ import annotations
2
2
3
+ import functools
3
4
import logging
4
5
from dataclasses import dataclass , field
5
6
from enum import Enum , auto
17
18
cast ,
18
19
)
19
20
21
+ import tensorrt as trt
22
+ import torch
23
+ from torch import SymBool , SymFloat , SymInt
20
24
from torch ._ops import OpOverloadPacket
21
25
from torch .fx .node import Argument , Node , Target , _get_qualified_name
22
26
from torch_tensorrt .dynamo .conversion ._ConversionContext import ConversionContext
23
27
from torch_tensorrt .fx .converter_registry import CONVERTERS as FX_CONVERTERS
24
28
25
- import tensorrt as trt
26
-
27
29
logger = logging .getLogger (__name__ )
28
30
29
31
LegacyConverterImplSignature = Callable [
@@ -76,22 +78,119 @@ class ConverterSupport:
76
78
capability_validator: Function which takes in a Node and returns a bool indicating
77
79
whether that node can be supported by its companion converter. Note that
78
80
this function must not modify the node or its graph
81
+ supports_dynamic_shapes: Boolean flag indicating if the converter has support for dynamic inputs.
79
82
"""
80
83
81
84
converter_implementation : ConverterImplSignature
82
85
capability_validator : Callable [[Node ], bool ] = field (default = lambda node : True )
86
+ supports_dynamic_shapes : bool = False
83
87
84
88
85
89
# Dictionary representing Dynamo aten-only converters
86
90
# Each converter maps to a sequence of at least one ConverterSupport object(s)
87
91
DYNAMO_ATEN_CONVERTERS : Dict [Target , Sequence [ConverterSupport ]] = {}
88
92
89
93
94
def has_static_shapes(node: torch.fx.Node) -> bool:
    """Return True when none of the node's args, kwargs, or outputs are dynamic."""
    node_is_dynamic = _has_dynamic_shapes(node=node)
    return not node_is_dynamic
97
+
98
+
99
def has_dynamic_shapes(node: torch.fx.Node) -> bool:
    """Return True when the node has dynamic args, kwargs, or outputs."""
    return _has_dynamic_shapes(node=node)
102
+
103
+
104
def has_dynamic_shapes_in_args(
    arg_positions_to_check: Optional[List[int]] = None,
) -> Callable[[torch.fx.Node], bool]:
    """Build a validator that reports dynamic inputs in ``node.args`` at the given positions.

    Args:
        arg_positions_to_check: Indices into ``node.args`` to inspect; ``None``
            checks the node's args, kwargs, and outputs as a whole.

    Returns:
        A callable usable as a converter ``capability_validator``.
    """
    validator = functools.partial(
        _has_dynamic_shapes, arg_positions_to_check=arg_positions_to_check
    )
    return validator
111
+
112
+
113
def has_static_shapes_in_args(
    arg_positions_to_check: Optional[List[int]] = None,
) -> Callable[[torch.fx.Node], bool]:
    """Build a validator that reports static inputs in ``node.args`` at the given positions.

    Args:
        arg_positions_to_check: Indices into ``node.args`` to inspect; ``None``
            checks the node's args, kwargs, and outputs as a whole.

    Returns:
        A callable usable as a converter ``capability_validator``.
    """

    # Named inner function instead of a lambda assigned to a name (PEP 8 E731);
    # also yields clearer tracebacks when a validator fails.
    def _has_static_shapes(
        node: torch.fx.Node, arg_positions_to_check: Optional[List[int]] = None
    ) -> bool:
        return not _has_dynamic_shapes(node, arg_positions_to_check)

    return functools.partial(
        _has_static_shapes, arg_positions_to_check=arg_positions_to_check
    )
123
+
124
+
125
+ def _has_dynamic_shapes (
126
+ node : torch .fx .Node , arg_positions_to_check : Optional [List [int ]] = None
127
+ ) -> bool :
128
+ # Validate that none of the inputs to the node have Dynamic shapes
129
+ assert isinstance (
130
+ node , torch .fx .Node
131
+ ), "Inputs to validator functions must be FX Nodes"
132
+
133
+ def _is_subnode_dynamic (subnode : torch .fx .Node ) -> bool :
134
+ """Checks if a node itself has Dynamic properties"""
135
+ _has_symbolic_sizes_strides , is_shape_dynamic = False , False
136
+ if "val" in subnode .meta :
137
+ _has_symbolic_sizes_strides = getattr (
138
+ subnode .meta ["val" ], "_has_symbolic_sizes_strides" , False
139
+ )
140
+ meta_val = subnode .meta ["val" ]
141
+ if isinstance (meta_val , (list , tuple )):
142
+ for val in meta_val :
143
+ shape = val .size ()
144
+ if any (
145
+ isinstance (dim , (SymFloat , SymInt , SymBool )) for dim in shape
146
+ ):
147
+ is_shape_dynamic = True
148
+ break
149
+ elif isinstance (meta_val , (SymFloat , SymInt , SymBool )):
150
+ is_shape_dynamic = True
151
+ else :
152
+ shape = subnode .meta ["val" ].size ()
153
+ is_shape_dynamic = any (
154
+ isinstance (dim , (SymFloat , SymInt , SymBool )) for dim in shape
155
+ )
156
+
157
+ return _has_symbolic_sizes_strides or is_shape_dynamic
158
+
159
+ # Check node value itself
160
+ if arg_positions_to_check is None and _is_subnode_dynamic (node ):
161
+ return True
162
+
163
+ # Check node arguments individually
164
+ if arg_positions_to_check is None and any (
165
+ _is_subnode_dynamic (arg ) for arg in node .args if isinstance (arg , torch .fx .Node )
166
+ ):
167
+ return True
168
+ # Check specific arg positions if the caller has specified positions to check
169
+ elif arg_positions_to_check is not None and any (
170
+ _is_subnode_dynamic (node .args [i ])
171
+ for i in arg_positions_to_check
172
+ if isinstance (node .args [i ], torch .fx .Node )
173
+ ):
174
+ return True
175
+
176
+ # Check node keyword arguments individually
177
+ if arg_positions_to_check is None and any (
178
+ _is_subnode_dynamic (kwarg )
179
+ for kwarg in node .kwargs .values ()
180
+ if isinstance (kwarg , torch .fx .Node )
181
+ ):
182
+ return True
183
+
184
+ return False
185
+
186
+
90
187
def dynamo_tensorrt_converter (
91
188
key : Target ,
189
+ * ,
92
190
enabled : bool = True ,
93
191
capability_validator : Optional [Callable [[Node ], bool ]] = None ,
94
192
priority : ConverterPriority = ConverterPriority .STANDARD ,
193
+ supports_dynamic_shapes : bool = False ,
95
194
) -> Callable [[ConverterImplSignature ], ConverterImplSignature ]:
96
195
"""Decorator for Dynamo TensorRT Converter
97
196
@@ -117,14 +216,18 @@ def register_converter(converter: ConverterImplSignature) -> ConverterImplSignat
117
216
118
217
# If no capability_validator function is specified, use the default function - always return true
119
218
if capability_validator is None :
120
- converter_support = ConverterSupport (converter_implementation = converter )
219
+ converter_support = ConverterSupport (
220
+ converter_implementation = converter ,
221
+ supports_dynamic_shapes = supports_dynamic_shapes ,
222
+ )
121
223
else :
122
224
assert callable (
123
225
capability_validator
124
226
), "Argument checking function must be callable"
125
227
converter_support = ConverterSupport (
126
228
converter_implementation = converter ,
127
229
capability_validator = capability_validator ,
230
+ supports_dynamic_shapes = supports_dynamic_shapes ,
128
231
)
129
232
130
233
# OpOverloadPackets are only valid if they have a single overload, or
@@ -194,6 +297,7 @@ def __init__(
194
297
],
195
298
registry_names : Optional [Sequence [str ]] = None ,
196
299
registry_calling_conventions : Optional [Sequence [CallingConvention ]] = None ,
300
+ assume_dynamic_shape_support : bool = False ,
197
301
):
198
302
# Copy reference to each dictionary object into attribute list
199
303
self .registries = list (registries )
@@ -215,9 +319,12 @@ def __init__(
215
319
]
216
320
217
321
self .disallowed_targets : Collection [Target ] = set ()
218
-
322
+ self . assume_dynamic_shape_support = assume_dynamic_shape_support
219
323
self .validate_invariants ()
220
324
325
def set_dynamic_shape_support(self, assume_dynamic_shape_support: bool) -> None:
    """Toggle whether all registered converters are assumed to handle dynamic shapes."""
    self.assume_dynamic_shape_support = assume_dynamic_shape_support
327
+
221
328
def set_disallowed_targets(self, torch_executed_ops: Collection[Target]) -> None:
    """Record targets that must not be converted (they will run in Torch instead)."""
    self.disallowed_targets = torch_executed_ops
223
330
@@ -324,13 +431,24 @@ def __getitem__(
324
431
325
432
if isinstance (converters , (list , tuple )):
326
433
for candidate in converters :
327
- if candidate .capability_validator (node ):
434
+ # We enable the converter under 4 conditions
435
+ # 1) capability validator is True
436
+ # 2) Assume dynamic_shape support is True
437
+ # 3) Node only has static shaped inputs
438
+ # 4) Node has dynamic inputs and the converter has supports_dynamic_shapes=True
439
+ if candidate .capability_validator (node ) and (
440
+ self .assume_dynamic_shape_support
441
+ or not has_dynamic_shapes (node )
442
+ or candidate .supports_dynamic_shapes
443
+ ):
328
444
return (
329
445
candidate .converter_implementation ,
330
446
calling_convention ,
331
447
)
332
448
else :
333
- return converters , calling_convention
449
+ # Assuming FX converters don't have dynamic shapes supported
450
+ if not has_dynamic_shapes (node ):
451
+ return converters , calling_convention
334
452
335
453
raise KeyError (
336
454
f"None of the converter registries have a validated entry for { key } , with node { node } "
0 commit comments