@@ -1,11 +1,12 @@
 import functools
 import logging
 import re
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union, Callable
 
 import numpy as np
 import tensorrt as trt
 import torch
+from torch import SymBool, SymFloat, SymInt
 from torch.fx.node import Target
 from torch_tensorrt.fx.converters.converter_utils import (
     Frameworks,
@@ -60,34 +61,51 @@ def is_only_operator_on_placeholder(node: torch.fx.Node) -> bool:
 
 
 def dynamic_unsupported(node: torch.fx.Node) -> bool:
+    """Validates that a node has no dynamic args, kwargs, or outputs"""
+    return _dynamic_unsupported(node=node)
+
+
+def dynamic_unsupported_with_args(
+    arg_positions_to_check: Optional[List[int]] = None,
+) -> Callable[[torch.fx.Node], bool]:
+    """Returns a validator checking that a node has no dynamic args at specific positions"""
+    return functools.partial(_dynamic_unsupported, arg_positions_to_check=arg_positions_to_check)
+
+
+def _dynamic_unsupported(
+    node: torch.fx.Node, arg_positions_to_check: Optional[List[int]] = None
+) -> bool:
     # Validate that none of the inputs to the node have Dynamic shapes
     assert isinstance(
         node, torch.fx.Node
     ), "Inputs to validator functions must be FX Nodes"
 
+    def _is_subnode_dynamic(subnode: torch.fx.Node) -> bool:
+        """Checks if a node itself has Dynamic properties"""
+        return getattr(
+            subnode.meta["val"], "_has_symbolic_sizes_strides", False
+        ) or isinstance(subnode.meta["val"], (SymFloat, SymInt, SymBool))
+
     # Check node value itself
-    if ("val" in node.meta) and getattr(
-        node.meta["val"], "_has_symbolic_sizes_strides", False
-    ):
+    if arg_positions_to_check is None and _is_subnode_dynamic(node):
         return False
 
     # Check node arguments individually
-    if any(
-        (
-            ("val" in arg.meta)
-            and getattr(arg.meta["val"], "_has_symbolic_sizes_strides", False)
-        )
-        for arg in node.args
-        if isinstance(arg, torch.fx.Node)
+    if arg_positions_to_check is None and any(
+        _is_subnode_dynamic(arg) for arg in node.args if isinstance(arg, torch.fx.Node)
+    ):
+        return False
+    # Check specific arg positions if the caller has specified positions to check
+    elif arg_positions_to_check is not None and any(
+        _is_subnode_dynamic(node.args[i])
+        for i in arg_positions_to_check
+        if isinstance(node.args[i], torch.fx.Node)
     ):
         return False
 
     # Check node keyword arguments individually
-    if any(
-        (
-            ("val" in kwarg.meta)
-            and getattr(kwarg.meta["val"], "_has_symbolic_sizes_strides", False)
-        )
+    if arg_positions_to_check is None and any(
+        _is_subnode_dynamic(kwarg)
         for kwarg in node.kwargs.values()
         if isinstance(kwarg, torch.fx.Node)
     ):
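
For context, here is a minimal sketch (not part of this diff) of how the two public validators might be exercised. The import path in the comment is an assumption inferred from the file's surroundings, and `_FakeDynamicVal` is a hypothetical stand-in object: `torch.fx.symbolic_trace` does not populate `node.meta["val"]` on its own (dynamo/export tracing does), so the sketch fills that metadata in manually.

```python
import torch

# Assumed import path for the helpers added above (hypothetical; adjust to
# wherever this file lives in the installed package):
# from torch_tensorrt.dynamo.conversion.converter_utils import (
#     dynamic_unsupported,
#     dynamic_unsupported_with_args,
# )


def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y


gm = torch.fx.symbolic_trace(add)

# symbolic_trace leaves node.meta["val"] unset; fill it with statically
# shaped tensors so _is_subnode_dynamic finds no symbolic sizes and no
# SymInt/SymFloat/SymBool values.
for node in gm.graph.nodes:
    node.meta["val"] = torch.empty(2, 3)

call_node = next(n for n in gm.graph.nodes if n.op == "call_function")

# Full check: the node's output, all args, and all kwargs must be static.
print(dynamic_unsupported(call_node))  # True -> no dynamic shapes found

# Positional check: only argument 0 is inspected.
validator = dynamic_unsupported_with_args(arg_positions_to_check=[0])
print(validator(call_node))  # True -> arg 0 is statically shaped


# Hypothetical stand-in whose metadata claims symbolic sizes/strides, to
# show the rejection path without constructing a real symbolic tensor.
class _FakeDynamicVal:
    _has_symbolic_sizes_strides = True


x_node = next(n for n in gm.graph.nodes if n.op == "placeholder")
x_node.meta["val"] = _FakeDynamicVal()
print(validator(call_node))  # False -> arg 0 now looks dynamic
```

The `functools.partial` in `dynamic_unsupported_with_args` is what lets a caller hand converter registrations a plain `Callable[[torch.fx.Node], bool]` while still pinning down which argument positions get inspected.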