Commit e77c7d3

gs-olive authored and apbose committed

feat/fix: Update dynamic unsupported implementation

- Add support for selecting individual argument positions to check, and expand the check to cover symbolic types (SymInt, SymFloat, SymBool), which are sometimes passed in as arguments

1 parent 49a55a5 · commit e77c7d3
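As a quick usage illustration (a hypothetical sketch, not code from this commit), the new factory returns an ordinary validator callable, so converters can opt into checking only the argument positions that may legally be symbolic:

    from torch_tensorrt.dynamo.conversion.converter_utils import (
        dynamic_unsupported,
        dynamic_unsupported_with_args,
    )

    # Rejects a node if its output, any arg, or any kwarg is dynamic:
    validate_all = dynamic_unsupported

    # Rejects a node only if args[1] is dynamic -- for aten::split that is
    # the split size, which may arrive as a SymInt rather than a Tensor:
    validate_split = dynamic_unsupported_with_args([1])

    # Both take a torch.fx.Node and return bool:
    # supported = validate_split(split_node)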

File tree

2 files changed: +37 -22 lines

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 4 additions & 15 deletions
@@ -349,36 +349,25 @@ def aten_ops_softmax(
     return impl.normalization.softmax(
         network, target, SourceIR.ATEN, name, args[0], args[1]
     )
-
-def dynamic_unsupported_split(node: torch.fx.Node) -> bool:
-    # Validate that none of the inputs to the node have Dynamic shapes
-    assert isinstance(
-        node, torch.fx.Node
-    ), "Inputs to validator functions must be FX Nodes"
-
-    if isinstance(node.args[1], torch.fx.Node):
-        if getattr(node.args[1].meta["val"], "_has_symbolic_sizes_strides", True):
-            return False
-    return True
 
 
 @dynamo_tensorrt_converter(
-    torch.ops.aten.split.Tensor, capability_validator=dynamic_unsupported_split
+    torch.ops.aten.split.Tensor, capability_validator=dynamic_unsupported_with_args([1])
 )
 @dynamo_tensorrt_converter(
-    torch.ops.aten.split.sizes, capability_validator=dynamic_unsupported_split
+    torch.ops.aten.split.sizes, capability_validator=dynamic_unsupported_with_args([1])
 )
 @dynamo_tensorrt_converter(
     torch.ops.aten.split_with_sizes.default,
-    capability_validator=dynamic_unsupported_split,
+    capability_validator=dynamic_unsupported_with_args([1]),
 )
 def aten_ops_split(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
     return impl.split.split(
         network,
         target,
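The three registrations above now share one parameterized validator in place of the bespoke dynamic_unsupported_split. A minimal sketch of what the factory call expands to (assuming the functools.partial-based implementation shown in converter_utils.py below):

    import functools

    # dynamic_unsupported_with_args([1]) pre-binds the positions to check,
    # handing each @dynamo_tensorrt_converter a plain Node -> bool callable:
    validator = functools.partial(_dynamic_unsupported, arg_positions_to_check=[1])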

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 33 additions & 7 deletions
@@ -6,6 +6,7 @@
 import numpy as np
 import tensorrt as trt
 import torch
+from torch import SymBool, SymFloat, SymInt
 from torch.fx.node import Target
 from torch_tensorrt.fx.converters.converter_utils import (
     Frameworks,
@@ -46,26 +47,51 @@ def get_node_name(node: torch.fx.Node) -> str:
 
 
 def dynamic_unsupported(node: torch.fx.Node) -> bool:
+    """Validates that a node has no dynamic args, kwargs, or outputs"""
+    return _dynamic_unsupported(node=node)
+
+
+def dynamic_unsupported_with_args(
+    arg_positions_to_check: Optional[List[int]] = None,
+) -> Callable[[torch.fx.Node], bool]:
+    """Returns a validator that a node has no dynamic args at specific positions"""
+    return functools.partial(_dynamic_unsupported, arg_positions_to_check=arg_positions_to_check)
+
+
+def _dynamic_unsupported(
+    node: torch.fx.Node, arg_positions_to_check: Optional[List[int]] = None
+) -> bool:
     # Validate that none of the inputs to the node have Dynamic shapes
     assert isinstance(
         node, torch.fx.Node
     ), "Inputs to validator functions must be FX Nodes"
 
+    def _is_subnode_dynamic(subnode: torch.fx.Node) -> bool:
+        """Checks if a node itself has Dynamic properties"""
+        return getattr(
+            subnode.meta["val"], "_has_symbolic_sizes_strides", False
+        ) or isinstance(subnode.meta["val"], (SymFloat, SymInt, SymBool))
+
     # Check node value itself
-    if getattr(node.meta["val"], "_has_symbolic_sizes_strides", False):
+    if arg_positions_to_check is None and _is_subnode_dynamic(node):
         return False
 
     # Check node arguments individually
-    if any(
-        getattr(arg.meta["val"], "_has_symbolic_sizes_strides", False)
-        for arg in node.args
-        if isinstance(arg, torch.fx.Node)
+    if arg_positions_to_check is None and any(
+        _is_subnode_dynamic(arg) for arg in node.args if isinstance(arg, torch.fx.Node)
+    ):
+        return False
+    # Check specific arg positions if the caller has specified positions to check
+    elif arg_positions_to_check is not None and any(
+        _is_subnode_dynamic(node.args[i])
+        for i in arg_positions_to_check
+        if isinstance(node.args[i], torch.fx.Node)
     ):
         return False
 
     # Check node keyword arguments individually
-    if any(
-        getattr(kwarg.meta["val"], "_has_symbolic_sizes_strides", False)
+    if arg_positions_to_check is None and any(
+        _is_subnode_dynamic(kwarg)
         for kwarg in node.kwargs.values()
         if isinstance(kwarg, torch.fx.Node)
     ):
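For reference, a self-contained sketch (not code from the repo) of the expanded predicate: a value now counts as dynamic either when its fake-tensor metadata has symbolic sizes/strides or when the value is itself a symbolic scalar.

    import torch

    # Mirrors _is_subnode_dynamic above, but applied to raw values rather
    # than FX node metadata (illustrative only):
    def is_dynamic_value(val) -> bool:
        return getattr(val, "_has_symbolic_sizes_strides", False) or isinstance(
            val, (torch.SymFloat, torch.SymInt, torch.SymBool)
        )

    assert not is_dynamic_value(torch.zeros(2, 3))  # concrete tensor: static
    # A SymInt split size (e.g. captured under torch.compile with dynamic
    # shapes) would return True, so the converter is skipped for that node.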
