Commit e5b7120

gs-olive authored and apbose committed
feat/fix: Update dynamic unsupported implementation
- Add support for selecting individual argument positions to check, and expand the checks to include symbolic types (`SymInt`, `SymFloat`, `SymBool`), which are sometimes passed in as arguments
1 parent 3f4a2e8 commit e5b7120

2 files changed (+38 additions, −31 deletions)


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 4 additions & 15 deletions
```diff
@@ -352,36 +352,25 @@ def aten_ops_softmax(
     return impl.normalization.softmax(
         network, target, SourceIR.ATEN, name, args[0], args[1]
     )
-
-def dynamic_unsupported_split(node: torch.fx.Node) -> bool:
-    # Validate that none of the inputs to the node have Dynamic shapes
-    assert isinstance(
-        node, torch.fx.Node
-    ), "Inputs to validator functions must be FX Nodes"
-
-    if isinstance(node.args[1], torch.fx.Node):
-        if getattr(node.args[1].meta["val"], "_has_symbolic_sizes_strides", True):
-            return False
-    return True
 
 
 @dynamo_tensorrt_converter(
-    torch.ops.aten.split.Tensor, capability_validator=dynamic_unsupported_split
+    torch.ops.aten.split.Tensor, capability_validator=dynamic_unsupported_with_args([1])
 )
 @dynamo_tensorrt_converter(
-    torch.ops.aten.split.sizes, capability_validator=dynamic_unsupported_split
+    torch.ops.aten.split.sizes, capability_validator=dynamic_unsupported_with_args([1])
 )
 @dynamo_tensorrt_converter(
     torch.ops.aten.split_with_sizes.default,
-    capability_validator=dynamic_unsupported_split,
+    capability_validator=dynamic_unsupported_with_args([1]),
 )
 def aten_ops_split(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
     return impl.split.split(
         network,
         target,
```
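
For context, a minimal sketch (not part of the commit) of how the new validator factory is applied here. The `[1]` selects `args[1]` of the split node; for `aten.split.Tensor` the schema is roughly `split.Tensor(Tensor self, SymInt split_size, int dim=0)`, so position 1 is the split size, which can arrive as a `SymInt` under dynamic shapes. The `node` below is hypothetical.

```python
# Sketch only: what dynamic_unsupported_with_args([1]) produces and implies.
from torch_tensorrt.dynamo.conversion.converter_utils import (
    dynamic_unsupported_with_args,
)

# Build a validator that inspects only args[1] of a node.
validator = dynamic_unsupported_with_args([1])

# Hypothetical: given a torch.fx.Node `node` for aten.split.Tensor whose
# args[1].meta["val"] is a SymInt, the validator returns False, the converter
# is skipped, and the op falls back to PyTorch execution.
# is_supported = validator(node)
```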

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 34 additions & 16 deletions
```diff
@@ -1,11 +1,12 @@
 import functools
 import logging
 import re
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union, Callable
 
 import numpy as np
 import tensorrt as trt
 import torch
+from torch import SymBool, SymFloat, SymInt
 from torch.fx.node import Target
 from torch_tensorrt.fx.converters.converter_utils import (
     Frameworks,
```
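
A brief note on the new imports: under dynamic shapes, a scalar argument such as a split size is traced as a symbolic scalar, so its `meta["val"]` is a `SymInt`/`SymFloat`/`SymBool` rather than a tensor, and the tensor-only `_has_symbolic_sizes_strides` check would miss it. A tiny illustrative check mirroring the helper added below (`looks_symbolic` is a hypothetical name, not part of the commit):

```python
import torch
from torch import SymBool, SymFloat, SymInt

# Illustrative only: symbolic scalars are ordinary Python classes, so this
# isinstance check is cheap and catches non-tensor dynamic values.
def looks_symbolic(val: object) -> bool:
    return isinstance(val, (SymInt, SymFloat, SymBool))

print(looks_symbolic(3))                  # False: a concrete Python int
print(looks_symbolic(torch.empty(2, 2)))  # False: a concrete tensor
```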
```diff
@@ -60,34 +61,51 @@ def is_only_operator_on_placeholder(node: torch.fx.Node) -> bool:
 
 
 def dynamic_unsupported(node: torch.fx.Node) -> bool:
+    """Validates that a node has no dynamic args, kwargs, or outputs"""
+    return _dynamic_unsupported(node=node)
+
+
+def dynamic_unsupported_with_args(
+    arg_positions_to_check: Optional[List[int]] = None,
+) -> Callable[[torch.fx.Node], bool]:
+    """Returns a validator that a node has no dynamic args at specific positions"""
+    return functools.partial(_dynamic_unsupported, arg_positions_to_check=arg_positions_to_check)
+
+
+def _dynamic_unsupported(
+    node: torch.fx.Node, arg_positions_to_check: Optional[List[int]] = None
+) -> bool:
     # Validate that none of the inputs to the node have Dynamic shapes
     assert isinstance(
         node, torch.fx.Node
     ), "Inputs to validator functions must be FX Nodes"
 
+    def _is_subnode_dynamic(subnode: torch.fx.Node) -> bool:
+        """Checks if a node itself has Dynamic properties"""
+        return getattr(
+            subnode.meta["val"], "_has_symbolic_sizes_strides", False
+        ) or isinstance(subnode.meta["val"], (SymFloat, SymInt, SymBool))
+
     # Check node value itself
-    if ("val" in node.meta) and getattr(
-        node.meta["val"], "_has_symbolic_sizes_strides", False
-    ):
+    if arg_positions_to_check is None and _is_subnode_dynamic(node):
         return False
 
     # Check node arguments individually
-    if any(
-        (
-            ("val" in arg.meta)
-            and getattr(arg.meta["val"], "_has_symbolic_sizes_strides", False)
-        )
-        for arg in node.args
-        if isinstance(arg, torch.fx.Node)
+    if arg_positions_to_check is None and any(
+        _is_subnode_dynamic(arg) for arg in node.args if isinstance(arg, torch.fx.Node)
+    ):
+        return False
+    # Check specific arg positions if the caller has specified positions to check
+    elif arg_positions_to_check is not None and any(
+        _is_subnode_dynamic(node.args[i])
+        for i in arg_positions_to_check
+        if isinstance(node.args[i], torch.fx.Node)
     ):
         return False
 
     # Check node keyword arguments individually
-    if any(
-        (
-            ("val" in kwarg.meta)
-            and getattr(kwarg.meta["val"], "_has_symbolic_sizes_strides", False)
-        )
+    if arg_positions_to_check is None and any(
+        _is_subnode_dynamic(kwarg)
         for kwarg in node.kwargs.values()
         if isinstance(kwarg, torch.fx.Node)
     ):
```
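
To make the factory pattern concrete, here is a self-contained sketch (toy types and names, not the torch_tensorrt API) of why `functools.partial` is used: it pre-binds `arg_positions_to_check` while keeping the single-argument `Callable[[node], bool]` signature that `capability_validator` expects.

```python
import functools
from typing import Callable, Dict, List, Optional

# Toy stand-in for _dynamic_unsupported: a "node" is just a dict here, and
# the string "dynamic" marks an argument with symbolic shape.
def _check(
    node: Dict[str, list], arg_positions_to_check: Optional[List[int]] = None
) -> bool:
    positions = (
        range(len(node["args"]))
        if arg_positions_to_check is None
        else arg_positions_to_check
    )
    return not any(node["args"][i] == "dynamic" for i in positions)

def check_with_args(
    arg_positions_to_check: Optional[List[int]] = None,
) -> Callable[[Dict[str, list]], bool]:
    # partial pre-binds the positions; the result still takes only `node`.
    return functools.partial(_check, arg_positions_to_check=arg_positions_to_check)

validator = check_with_args([1])
print(validator({"args": ["dynamic", "static"]}))  # True: args[0] is ignored
print(validator({"args": ["static", "dynamic"]}))  # False: args[1] is dynamic
```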
