
Commit 722457b

feat: Add validators for dynamic shapes in converter registration (#2796)
1 parent 2a4c37b commit 722457b

8 files changed, +290 −88 lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 12 additions & 0 deletions
@@ -47,6 +47,7 @@ def compile(
     *,
     device: Optional[Union[Device, torch.device, str]] = _defaults.DEVICE,
     disable_tf32: bool = _defaults.DISABLE_TF32,
+    assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT,
     sparse_weights: bool = _defaults.SPARSE_WEIGHTS,
     enabled_precisions: (
         Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
@@ -106,6 +107,7 @@ def compile(
             device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True)

         disable_tf32 (bool): Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
+        assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False
         sparse_weights (bool): Enable sparsity for convolution and fully connected layers.
         enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
         refit (bool): Enable refitting
@@ -189,6 +191,7 @@ def compile(
         ),
         "debug": debug,
         "device": device,
+        "assume_dynamic_shape_support": assume_dynamic_shape_support,
         "workspace_size": workspace_size,
         "min_block_size": min_block_size,
         "torch_executed_ops": (
@@ -239,6 +242,9 @@ def compile_module(
     """
     dryrun_tracker = DryRunTracker()

+    # Assume converters support dynamic shapes and disable validation
+    CONVERTERS.set_dynamic_shape_support(settings.assume_dynamic_shape_support)
+
     # Set torch-executed ops
     CONVERTERS.set_disallowed_targets(settings.torch_executed_ops)

@@ -443,6 +449,7 @@ def convert_module_to_trt_engine(
         Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
     ) = _defaults.ENABLED_PRECISIONS,
     debug: bool = _defaults.DEBUG,
+    assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT,
     workspace_size: int = _defaults.WORKSPACE_SIZE,
     min_block_size: int = _defaults.MIN_BLOCK_SIZE,
     torch_executed_ops: Optional[Set[str]] = None,
@@ -550,6 +557,7 @@ def convert_module_to_trt_engine(
     enabled_precisions = {dtype._from(e) for e in enabled_precisions}

     compilation_options = {
+        "assume_dynamic_shape_support": assume_dynamic_shape_support,
         "enabled_precisions": enabled_precisions,
         "debug": debug,
         "workspace_size": workspace_size,
@@ -589,6 +597,10 @@ def convert_module_to_trt_engine(

     settings = CompilationSettings(**compilation_options)
     logger.info("Compilation Settings: %s\n", settings)
+
+    # Assume converters support dynamic shapes and disable validation
+    CONVERTERS.set_dynamic_shape_support(settings.assume_dynamic_shape_support)
+
     try:
         interpreter_result = interpret_module_to_result(gm, input_list, settings)
     except UnsupportedOperatorException:
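
For orientation, here is a minimal sketch of how the new flag can be passed from user code through the dynamo frontend; the model and input shapes below are illustrative and not part of this commit:

import torch
import torch_tensorrt

# Illustrative module; any module accepted by the dynamo path works
model = torch.nn.Linear(64, 32).eval().cuda()

# A dynamic batch dimension expressed via min/opt/max shapes
inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 64), opt_shape=(8, 64), max_shape=(32, 64), dtype=torch.float32
    )
]

# assume_dynamic_shape_support=True makes the registry treat every converter as
# dynamic-shape capable, skipping the per-converter validation added in this commit
trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    assume_dynamic_shape_support=True,
)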

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
 DEBUG = False
 DEVICE = None
 DISABLE_TF32 = False
+ASSUME_DYNAMIC_SHAPE_SUPPORT = False
 DLA_LOCAL_DRAM_SIZE = 1073741824
 DLA_GLOBAL_DRAM_SIZE = 536870912
 DLA_SRAM_SIZE = 1048576

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 3 additions & 0 deletions
@@ -5,6 +5,7 @@
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import EngineCapability, dtype
 from torch_tensorrt.dynamo._defaults import (
+    ASSUME_DYNAMIC_SHAPE_SUPPORT,
     DEBUG,
     DISABLE_TF32,
     DLA_GLOBAL_DRAM_SIZE,
@@ -57,6 +58,7 @@ class CompilationSettings:
         device (Device): GPU to compile the model on
         require_full_compilation (bool): Whether to require the graph is fully compiled in TensorRT.
             Only applicable for `ir="dynamo"`; has no effect for `torch.compile` path
+        assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False
         disable_tf32 (bool): Whether to disable TF32 computation for TRT layers
         sparse_weights (bool): Whether to allow the builder to use sparse weights
         refit (bool): Whether to build a refittable engine
@@ -87,6 +89,7 @@ class CompilationSettings:
     device: Device = field(default_factory=default_device)
     require_full_compilation: bool = REQUIRE_FULL_COMPILATION
     disable_tf32: bool = DISABLE_TF32
+    assume_dynamic_shape_support: bool = ASSUME_DYNAMIC_SHAPE_SUPPORT
     sparse_weights: bool = SPARSE_WEIGHTS
     refit: bool = REFIT
     engine_capability: EngineCapability = field(
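
As a quick sketch of how the new setting reaches the registry, mirroring the calls added in _compiler.py; the import alias below is an assumption about how the registry instance is exposed:

from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    DYNAMO_CONVERTERS as CONVERTERS,  # assumed alias, as used in _compiler.py
)

# The field defaults to ASSUME_DYNAMIC_SHAPE_SUPPORT (False); enable it explicitly here
settings = CompilationSettings(assume_dynamic_shape_support=True)

# Propagate the flag to the converter registry before conversion starts
CONVERTERS.set_dynamic_shape_support(settings.assume_dynamic_shape_support)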

py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py

Lines changed: 124 additions & 6 deletions
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import functools
 import logging
 from dataclasses import dataclass, field
 from enum import Enum, auto
@@ -17,13 +18,14 @@
     cast,
 )

+import tensorrt as trt
+import torch
+from torch import SymBool, SymFloat, SymInt
 from torch._ops import OpOverloadPacket
 from torch.fx.node import Argument, Node, Target, _get_qualified_name
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.fx.converter_registry import CONVERTERS as FX_CONVERTERS

-import tensorrt as trt
-
 logger = logging.getLogger(__name__)

 LegacyConverterImplSignature = Callable[
@@ -76,22 +78,119 @@ class ConverterSupport:
         capability_validator: Function which takes in a Node and returns a bool indicating
             whether that node can be supported by its companion converter. Note that
             this function must not modify the node or its graph
+        supports_dynamic_shapes: Boolean flag indicating if the converter has support for dynamic inputs.
     """

     converter_implementation: ConverterImplSignature
     capability_validator: Callable[[Node], bool] = field(default=lambda node: True)
+    supports_dynamic_shapes: bool = False


 # Dictionary representing Dynamo aten-only converters
 # Each converter maps to a sequence of at least one ConverterSupport object(s)
 DYNAMO_ATEN_CONVERTERS: Dict[Target, Sequence[ConverterSupport]] = {}


+def has_static_shapes(node: torch.fx.Node) -> bool:
+    """Returns True if a node has static args, kwargs, or outputs"""
+    return not _has_dynamic_shapes(node=node)
+
+
+def has_dynamic_shapes(node: torch.fx.Node) -> bool:
+    """Returns True if a node has dynamic args, kwargs, or outputs"""
+    return _has_dynamic_shapes(node=node)
+
+
+def has_dynamic_shapes_in_args(
+    arg_positions_to_check: Optional[List[int]] = None,
+) -> Callable[[torch.fx.Node], bool]:
+    """Returns True if a node has dynamic inputs in node.args at specified positions"""
+    return functools.partial(
+        _has_dynamic_shapes, arg_positions_to_check=arg_positions_to_check
+    )
+
+
+def has_static_shapes_in_args(
+    arg_positions_to_check: Optional[List[int]] = None,
+) -> Callable[[torch.fx.Node], bool]:
+    """Returns True if a node has static inputs in node.args at specified positions"""
+    _has_static_shapes = lambda node, arg_positions_to_check: not _has_dynamic_shapes(
+        node, arg_positions_to_check
+    )
+    return functools.partial(
+        _has_static_shapes, arg_positions_to_check=arg_positions_to_check
+    )
+
+
+def _has_dynamic_shapes(
+    node: torch.fx.Node, arg_positions_to_check: Optional[List[int]] = None
+) -> bool:
+    # Validate that none of the inputs to the node have Dynamic shapes
+    assert isinstance(
+        node, torch.fx.Node
+    ), "Inputs to validator functions must be FX Nodes"
+
+    def _is_subnode_dynamic(subnode: torch.fx.Node) -> bool:
+        """Checks if a node itself has Dynamic properties"""
+        _has_symbolic_sizes_strides, is_shape_dynamic = False, False
+        if "val" in subnode.meta:
+            _has_symbolic_sizes_strides = getattr(
+                subnode.meta["val"], "_has_symbolic_sizes_strides", False
+            )
+            meta_val = subnode.meta["val"]
+            if isinstance(meta_val, (list, tuple)):
+                for val in meta_val:
+                    shape = val.size()
+                    if any(
+                        isinstance(dim, (SymFloat, SymInt, SymBool)) for dim in shape
+                    ):
+                        is_shape_dynamic = True
+                        break
+            elif isinstance(meta_val, (SymFloat, SymInt, SymBool)):
+                is_shape_dynamic = True
+            else:
+                shape = subnode.meta["val"].size()
+                is_shape_dynamic = any(
+                    isinstance(dim, (SymFloat, SymInt, SymBool)) for dim in shape
+                )
+
+        return _has_symbolic_sizes_strides or is_shape_dynamic
+
+    # Check node value itself
+    if arg_positions_to_check is None and _is_subnode_dynamic(node):
+        return True
+
+    # Check node arguments individually
+    if arg_positions_to_check is None and any(
+        _is_subnode_dynamic(arg) for arg in node.args if isinstance(arg, torch.fx.Node)
+    ):
+        return True
+    # Check specific arg positions if the caller has specified positions to check
+    elif arg_positions_to_check is not None and any(
+        _is_subnode_dynamic(node.args[i])
+        for i in arg_positions_to_check
+        if isinstance(node.args[i], torch.fx.Node)
+    ):
+        return True
+
+    # Check node keyword arguments individually
+    if arg_positions_to_check is None and any(
+        _is_subnode_dynamic(kwarg)
+        for kwarg in node.kwargs.values()
+        if isinstance(kwarg, torch.fx.Node)
+    ):
+        return True
+
+    return False
+
+
 def dynamo_tensorrt_converter(
     key: Target,
+    *,
     enabled: bool = True,
     capability_validator: Optional[Callable[[Node], bool]] = None,
     priority: ConverterPriority = ConverterPriority.STANDARD,
+    supports_dynamic_shapes: bool = False,
 ) -> Callable[[ConverterImplSignature], ConverterImplSignature]:
     """Decorator for Dynamo TensorRT Converter

@@ -117,14 +216,18 @@ def register_converter(converter: ConverterImplSignature) -> ConverterImplSignat

         # If no capability_validator function is specified, use the default function - always return true
         if capability_validator is None:
-            converter_support = ConverterSupport(converter_implementation=converter)
+            converter_support = ConverterSupport(
+                converter_implementation=converter,
+                supports_dynamic_shapes=supports_dynamic_shapes,
+            )
         else:
             assert callable(
                 capability_validator
             ), "Argument checking function must be callable"
             converter_support = ConverterSupport(
                 converter_implementation=converter,
                 capability_validator=capability_validator,
+                supports_dynamic_shapes=supports_dynamic_shapes,
             )

         # OpOverloadPackets are only valid if they have a single overload, or
@@ -194,6 +297,7 @@ def __init__(
         ],
         registry_names: Optional[Sequence[str]] = None,
         registry_calling_conventions: Optional[Sequence[CallingConvention]] = None,
+        assume_dynamic_shape_support: bool = False,
     ):
         # Copy reference to each dictionary object into attribute list
         self.registries = list(registries)
@@ -215,9 +319,12 @@ def __init__(
         ]

         self.disallowed_targets: Collection[Target] = set()
-
+        self.assume_dynamic_shape_support = assume_dynamic_shape_support
         self.validate_invariants()

+    def set_dynamic_shape_support(self, assume_dynamic_shape_support: bool) -> None:
+        self.assume_dynamic_shape_support = assume_dynamic_shape_support
+
     def set_disallowed_targets(self, torch_executed_ops: Collection[Target]) -> None:
         self.disallowed_targets = torch_executed_ops

@@ -324,13 +431,24 @@ def __getitem__(

                 if isinstance(converters, (list, tuple)):
                     for candidate in converters:
-                        if candidate.capability_validator(node):
+                        # We enable the converter under 4 conditions
+                        # 1) capability validator is True
+                        # 2) Assume dynamic_shape support is True
+                        # 3) Node only has static shaped inputs
+                        # 4) Node has dynamic inputs and the converter has supports_dynamic_shapes=True
+                        if candidate.capability_validator(node) and (
+                            self.assume_dynamic_shape_support
+                            or not has_dynamic_shapes(node)
+                            or candidate.supports_dynamic_shapes
+                        ):
                             return (
                                 candidate.converter_implementation,
                                 calling_convention,
                             )
                 else:
-                    return converters, calling_convention
+                    # Assuming FX converters don't have dynamic shapes supported
+                    if not has_dynamic_shapes(node):
+                        return converters, calling_convention

         raise KeyError(
             f"None of the converter registries have a validated entry for {key}, with node {node}"
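
To show how a converter opts in under the new scheme, here is a sketch of a registration using the updated decorator; the target op, validator position, and converter body are placeholders for illustration, not converters added by this commit:

from typing import Dict, Tuple

import torch
from torch.fx.node import Argument, Target
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    dynamo_tensorrt_converter,
    has_static_shapes_in_args,
)


# Selected only when arg position 1 (e.g. a shape argument) is statically shaped,
# while advertising dynamic-shape support for the remaining inputs.
@dynamo_tensorrt_converter(
    torch.ops.aten.expand.default,  # placeholder target
    capability_validator=has_static_shapes_in_args([1]),
    supports_dynamic_shapes=True,
)
def example_expand_converter(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
):
    # Converter body elided; it would emit the TensorRT layers for this op
    ...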
