
Commit 84f4755

feat: Data Structure update for Dynamo Registry
- Add custom class overriding the default dictionary class to access converters from various registries
- Add new dictionary type `Dict[Target, Sequence[ConverterSupport]]`, as well as a `ConverterSupport` class which stores a converter and its validation implementation
- Add unified `DYNAMO_CONVERTERS` dictionary which coalesces both the FX and Dynamo converter dictionaries and acts as a single unified dictionary
- Streamline dictionary accesses via get/contains accessors
- Add priority converter decorator enum to prioritize user-provided converters, and name argument checking "capability validation" to clarify its utility
- Add boilerplate `no_dynamic` converter capability validator for easy use in specifying converters as unable to handle dynamic shapes

(A sketch of the registry layout these notes describe follows the file summary below.)
1 parent 0ff692b commit 84f4755

File tree

4 files changed, +352 -16 lines changed
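Before the diffs, here is a minimal sketch of the registry shape the commit message describes, under illustrative names (`ConverterSupport`, `ConverterRegistry`): a dict subclass keyed by op target whose `get`/`__contains__` accept a whole FX node, so each candidate converter's capability validator can inspect the node before dispatch. This is a reading of the message, not the actual Torch-TensorRT implementation.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional

import torch


@dataclass
class ConverterSupport:
    """Pairs a converter implementation with its capability validator."""

    converter_implementation: Callable[..., Any]
    # Defaults to accepting every node; a validator such as dynamic_unsupported
    # (added in the first file below) can reject nodes it cannot handle.
    capability_validator: Callable[[torch.fx.Node], bool] = lambda node: True


class ConverterRegistry(Dict[Any, List[ConverterSupport]]):
    """Dict subclass keyed by op target but queried with a whole FX node,
    so each candidate's validator can vet the node before dispatch."""

    def get(self, node: torch.fx.Node, default: Optional[Callable[..., Any]] = None):
        # Walk this target's candidates in order; return the first whose
        # capability validator accepts the node.
        for support in super().get(node.target, []):
            if support.capability_validator(node):
                return support.converter_implementation
        return default

    def __contains__(self, node: object) -> bool:
        return isinstance(node, torch.fx.Node) and self.get(node) is not None
```

Under this layout, the unified `DYNAMO_CONVERTERS` dictionary would merge the FX and Dynamo registries behind one interface, and the priority converter decorator enum could simply decide whether a user-provided `ConverterSupport` is prepended (checked first) or appended to a target's sequence.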
Lines changed: 30 additions & 0 deletions
```diff
@@ -0,0 +1,30 @@
+import torch
+
+
+def dynamic_unsupported(node: torch.fx.Node) -> bool:
+    # Validate that none of the inputs to the node have Dynamic shapes
+    assert isinstance(
+        node, torch.fx.Node
+    ), "Inputs to validator functions must be FX Nodes"
+
+    # Check node value itself
+    if node.meta["val"]._has_symbolic_sizes_strides:
+        return False
+
+    # Check node arguments individually
+    if any(
+        arg.meta["val"]._has_symbolic_sizes_strides
+        for arg in node.args
+        if isinstance(arg, torch.fx.Node)
+    ):
+        return False
+
+    # Check node keyword arguments individually
+    if any(
+        kwarg.meta["val"]._has_symbolic_sizes_strides
+        for kwarg in node.kwargs.values()
+        if isinstance(kwarg, torch.fx.Node)
+    ):
+        return False
+
+    return True
```
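As a hedged usage sketch, the validator above can be exercised on any FX graph whose nodes carry `meta["val"]` tensors, e.g. one traced with `make_fx` in fake-tensor mode (which populates `node.meta["val"]` in recent PyTorch releases). The function `f` and the loop below are illustrative only; the validator's module path is not shown in this capture.

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

# dynamic_unsupported is the validator added in the new file above; its
# import path is not shown in this diff.


def f(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) + 1


# Fake-tensor tracing populates node.meta["val"], which the validator inspects.
gm = make_fx(f, tracing_mode="fake")(torch.randn(2, 3))
for node in gm.graph.nodes:
    if node.op == "call_function":
        # True means no symbolic (dynamic) shapes were found, so a converter
        # gated by this validator would be allowed to run.
        print(node.target, dynamic_unsupported(node))
```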

py/torch_tensorrt/dynamo/conversion/trt_interpreter.py

Lines changed: 8 additions & 7 deletions
```diff
@@ -10,7 +10,6 @@
 import tensorrt as trt
 import torch
 import torch.fx
-from torch._ops import OpOverload
 from torch.fx.node import _get_qualified_name
 from torch.fx.passes.shape_prop import TensorMetadata
 
@@ -69,6 +68,7 @@ def __init__(
         self.input_specs = input_specs
         self.input_specs_iter = 0
         self._cur_node_name: Optional[str] = None
+        self._cur_node: Optional[torch.fx.Node] = None
         self._input_names: List[str] = []
         self._output_names: List[str] = []
         self._itensor_to_tensor_meta: Dict[
@@ -82,14 +82,14 @@ def validate_conversion(self):
         missing_converter = set()
 
         for node in self.module.graph.nodes:
-            if node.op == "call_function" and not CONVERTERS.get(node.target):
+            if node.op == "call_function" and not CONVERTERS.get(node):
                 missing_converter.add(f"{node.op} {_get_qualified_name(node.target)}")
-            elif node.op == "call_method" and not CONVERTERS.get(node.target):
+            elif node.op == "call_method" and not CONVERTERS.get(node):
                 missing_converter.add(f"{node.op} torch.Tensor.{node.target}")
             elif node.op == "call_module":
                 submod = self.fetch_attr(node.target)
                 submod_type = getattr(submod, "_base_class_origin", type(submod))
-                if not CONVERTERS.get(submod_type):
+                if not CONVERTERS.get(node):
                     missing_converter.add(f"{node.op} {torch.typename(submod_type)}")
 
         return missing_converter
@@ -226,6 +226,7 @@ def run(
 
     def run_node(self, n):
         self._cur_node_name = str(n)
+        self._cur_node = n
         # add "_itensor_to_tensor_meta"
         kwargs = dict(n.kwargs)
         kwargs["_itensor_to_tensor_meta"] = self._itensor_to_tensor_meta
@@ -276,7 +277,7 @@ def call_module(self, target, args, kwargs):
         assert isinstance(target, str)
         submod = self.fetch_attr(target)
         submod_type = getattr(submod, "_base_class_origin", type(submod))
-        converter = CONVERTERS.get(submod_type)
+        converter = CONVERTERS.get(self._cur_node)
 
         if not converter:
             raise RuntimeError(
@@ -287,7 +288,7 @@ def call_module(self, target, args, kwargs):
         return converter(self.network, submod, args, kwargs, self._cur_node_name)
 
     def call_function(self, target, args, kwargs):
-        converter = CONVERTERS.get(target)
+        converter = CONVERTERS.get(self._cur_node)
         if not converter:
             raise RuntimeError(
                 f"Conversion of function {torch.typename(target)} not currently supported!"
@@ -298,7 +299,7 @@ def call_function(self, target, args, kwargs):
 
     def call_method(self, target, args, kwargs):
         assert isinstance(target, str)
-        converter = CONVERTERS.get(target)
+        converter = CONVERTERS.get(self._cur_node)
 
         if not converter:
             raise RuntimeError(
```
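Taken together, these hunks make the interpreter remember the node it is executing (`self._cur_node`, set in `run_node`) and key every converter lookup on the node itself rather than on `node.target` or `submod_type`, so capability validators run before a converter is chosen. Below is an end-to-end illustration under the sketch registry from earlier on this page; the toy converter and registry contents are assumptions, not code from this commit.

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

# Assumes ConverterRegistry, ConverterSupport, and dynamic_unsupported from the
# sketches above are in scope; toy_relu_converter is a stand-in converter.


def toy_relu_converter(network, target, args, kwargs, name):
    return f"trt-relu({name})"


CONVERTERS = ConverterRegistry()
CONVERTERS[torch.ops.aten.relu.default] = [
    ConverterSupport(toy_relu_converter, capability_validator=dynamic_unsupported)
]

gm = make_fx(lambda x: torch.relu(x), tracing_mode="fake")(torch.randn(2, 3))
for node in gm.graph.nodes:
    if node.op == "call_function":
        # Node-keyed lookup: the validator sees node.meta["val"] and can veto
        # dispatch for dynamic shapes, which is what .get(self._cur_node) enables.
        print(node.target, CONVERTERS.get(node))
```

The `call_module` hunk follows the same pattern: the node, unlike `submod_type` alone, carries the metadata a validator needs.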
