|
6 | 6 | import numpy as np
|
7 | 7 | import tensorrt as trt
|
8 | 8 | import torch
|
| 9 | +from torch import SymBool, SymFloat, SymInt |
9 | 10 | from torch.fx.node import Target
|
10 | 11 | from torch_tensorrt.fx.converters.converter_utils import (
|
11 | 12 | Frameworks,
|
@@ -46,26 +47,51 @@ def get_node_name(node: torch.fx.Node) -> str:
|
46 | 47 |
|
47 | 48 |
|
def dynamic_unsupported(node: torch.fx.Node) -> bool:
    """Validator that rejects a node having any dynamic args, kwargs, or outputs.

    Delegates to the shared ``_dynamic_unsupported`` check with no position
    restriction, so the node's own value and every argument are inspected.
    """
    return _dynamic_unsupported(node)
| 52 | + |
| 53 | + |
def dynamic_unsupported_with_args(
    arg_positions_to_check: Optional[List[int]] = None,
) -> Callable[[torch.fx.Node], bool]:
    """Build a validator that rejects dynamic values at specific arg positions.

    Args:
        arg_positions_to_check: Indices into ``node.args`` to inspect for
            dynamic shapes/values. ``None`` requests the unrestricted check.

    Returns:
        A callable that takes an FX node and returns ``False`` when a checked
        position is dynamic (per ``_dynamic_unsupported``), ``True`` otherwise.
    """
    validator = functools.partial(
        _dynamic_unsupported,
        arg_positions_to_check=arg_positions_to_check,
    )
    return validator
| 59 | + |
| 60 | + |
| 61 | +def _dynamic_unsupported( |
| 62 | + node: torch.fx.Node, arg_positions_to_check: Optional[List[int]] = None |
| 63 | +) -> bool: |
49 | 64 | # Validate that none of the inputs to the node have Dynamic shapes
|
50 | 65 | assert isinstance(
|
51 | 66 | node, torch.fx.Node
|
52 | 67 | ), "Inputs to validator functions must be FX Nodes"
|
53 | 68 |
|
| 69 | + def _is_subnode_dynamic(subnode: torch.fx.Node) -> bool: |
| 70 | + """Checks if a node itself has Dynamic properties""" |
| 71 | + return getattr( |
| 72 | + subnode.meta["val"], "_has_symbolic_sizes_strides", False |
| 73 | + ) or isinstance(subnode.meta["val"], (SymFloat, SymInt, SymBool)) |
| 74 | + |
54 | 75 | # Check node value itself
|
55 |
| - if getattr(node.meta["val"], "_has_symbolic_sizes_strides", False): |
| 76 | + if arg_positions_to_check is None and _is_subnode_dynamic(node): |
56 | 77 | return False
|
57 | 78 |
|
58 | 79 | # Check node arguments individually
|
59 |
| - if any( |
60 |
| - getattr(arg.meta["val"], "_has_symbolic_sizes_strides", False) |
61 |
| - for arg in node.args |
62 |
| - if isinstance(arg, torch.fx.Node) |
| 80 | + if arg_positions_to_check is None and any( |
| 81 | + _is_subnode_dynamic(arg) for arg in node.args if isinstance(arg, torch.fx.Node) |
| 82 | + ): |
| 83 | + return False |
| 84 | + # Check specific arg positions if the caller has specified positions to check |
| 85 | + elif arg_positions_to_check is not None and any( |
| 86 | + _is_subnode_dynamic(node.args[i]) |
| 87 | + for i in arg_positions_to_check |
| 88 | + if isinstance(node.args[i], torch.fx.Node) |
63 | 89 | ):
|
64 | 90 | return False
|
65 | 91 |
|
66 | 92 | # Check node keyword arguments individually
|
67 |
| - if any( |
68 |
| - getattr(kwarg.meta["val"], "_has_symbolic_sizes_strides", False) |
| 93 | + if arg_positions_to_check is None and any( |
| 94 | + _is_subnode_dynamic(kwarg) |
69 | 95 | for kwarg in node.kwargs.values()
|
70 | 96 | if isinstance(kwarg, torch.fx.Node)
|
71 | 97 | ):
|
|
0 commit comments