Skip to content

Commit 846ec17

Browse files
committed
feat: Data Structure update for Dynamo Registry
- Add custom class overriding default Dictionary class to access converters from various registries - Add new dictionary type `Dict[Target, Sequence[ConverterSupport]]` as well as ConverterSupport class which stores a converter and its validation implementation - Add unified `DYNAMO_CONVERTERS` dictionary which coalesces both the FX and Dynamo converter dictionaries and acts as a single unified dictionary
1 parent e720dec commit 846ec17

File tree

3 files changed

+253
-9
lines changed

3 files changed

+253
-9
lines changed

py/torch_tensorrt/dynamo/backend/lowering/_partition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def is_node_supported(
124124
)
125125

126126
if (
127-
node.target in CONVERTERS.keys()
127+
CONVERTERS.contains_validated(node)
128128
and node_name not in self.torch_executed_ops
129129
):
130130
# If node is a proper, supported computational node, store the operator
Lines changed: 247 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,266 @@
1-
from typing import Any, Callable, Dict
1+
from dataclasses import dataclass, field
2+
from typing import Any, Callable, Dict, Optional, Sequence, Union
23

3-
from torch.fx.node import Target
4+
from torch.fx.node import Target, Node
45
from torch_tensorrt.fx.converter_registry import CONVERTERS
56

6-
DYNAMO_CONVERTERS: Dict[Target, Any] = dict(CONVERTERS)
7+
8+
@dataclass(frozen=True)
9+
class GenericNode:
10+
"""Class representing a generic Torch FX Node
11+
12+
Convenience class for when we do not have access to the node,
13+
and instead only its target, args, and kwargs, as is the case
14+
in the TRTInterpreter
15+
"""
16+
17+
target: Any
18+
args: Sequence[Any]
19+
kwargs: Dict[str, Any]
20+
21+
22+
# Defines the different types of valid nodes
23+
NodeKinds = Union[GenericNode, Node]
24+
25+
26+
@dataclass(frozen=True)
27+
class ConverterSupport:
28+
"""Class representing a converter implementation and support function
29+
30+
Args:
31+
converter_implementation: Function which converts said node to a TRT equivalent
32+
check_args: Function which takes in a NodeKind and returns a bool indicating
33+
whether that node can be supported by its companion converter. Note that
34+
this function must only access the .target, .args, and .kwargs fields
35+
of the node, and cannot modify the node or its graph
36+
"""
37+
38+
converter_implementation: Callable
39+
check_args: Callable[[NodeKinds], bool] = field(default=lambda node: True)
40+
41+
42+
# Dictionary representing Dynamo aten-only converters
43+
# Each converter maps to a sequence of at least one ConverterSupport object(s)
44+
DYNAMO_ATEN_CONVERTERS: Dict[Target, Sequence[ConverterSupport]] = {}
745

846

947
def dynamo_tensorrt_converter(
1048
key: Target,
1149
enabled: bool = True,
50+
check_args: Optional[Callable[[NodeKinds], bool]] = None,
1251
) -> Callable[[Any], Any]:
52+
"""Decorator for Dynamo TensorRT Converter
53+
54+
Registers the decorated function in the DYNAMO_ATEN_CONVERTERS registry
55+
56+
Args:
57+
key: Node target for which the converter is implemented for
58+
(for example, torch.ops.add.Tensor)
59+
enabled: Whether the converter should be enabled/cached or not
60+
check_args: Function which evaluates whether a node is valid for conversion
61+
by the decorated converter. See ConverterSupport for more details.
62+
Defaults to None, implying the check_args function is always true -
63+
this means all nodes of "key" kind can be supported by this converter
64+
Returns:
65+
The converter being decorated
66+
"""
67+
1368
def register_converter(converter):
14-
DYNAMO_CONVERTERS[key] = converter
69+
"""Helper function to register the converter, then return it"""
70+
assert callable(converter), "Converter function must be callable"
71+
72+
# If no check_args function is specified, use the default function - always return true
73+
if check_args is None:
74+
converter_support = ConverterSupport(converter_implementation=converter)
75+
else:
76+
assert callable(check_args), "Argument checking function must be callable"
77+
converter_support = ConverterSupport(
78+
converter_implementation=converter, check_args=check_args
79+
)
80+
81+
# If a converter for this operator already exists, append the new converter to the list
82+
# Otherwise, start a new list
83+
if key in DYNAMO_ATEN_CONVERTERS:
84+
DYNAMO_ATEN_CONVERTERS[key].append(converter_support)
85+
else:
86+
DYNAMO_ATEN_CONVERTERS[key] = [converter_support]
87+
1588
return converter
1689

1790
def disable_converter(converter):
1891
return converter
1992

93+
# Select whether to cache/enable the converter
2094
if enabled:
2195
return register_converter
2296
else:
2397
return disable_converter
98+
99+
100+
class ConverterRegistry:
101+
"""Registry for storing multiple converter dictionaries
102+
103+
Capable of storing dictionaries with the following signature:
104+
Dict[Target, Union[Callable, Sequence[ConverterSupport]]]
105+
106+
Also able to validate converter implementations against user-provided
107+
argument-checking functions
108+
109+
Args:
110+
registries: List of dictionaries representing converter registries.
111+
The order of the provided dictionaries is the order in which they
112+
will be traversed. This is only significant when using non-validated
113+
methods.
114+
"""
115+
116+
def __init__(
117+
self,
118+
registries: Sequence[Dict[Target, Union[Callable, Sequence[ConverterSupport]]]],
119+
):
120+
# Copy reference to each dictionary object into attribute list
121+
self.registries = [registry for registry in registries]
122+
self.validate_invariants()
123+
124+
def validate_invariants(self):
125+
"""Validates the invariants required of the dictionaries in the registries
126+
127+
Raises AssertionError if any invariants have been violated
128+
"""
129+
# All registries must be dictionaries
130+
assert all(isinstance(elt, dict) for elt in self.registries)
131+
132+
# Every dictionary in the registry must have one of two signatures:
133+
# Dict[Target, Callable] or Dict[Target, Sequence[ConverterSupport]]
134+
# Where, for the latter, the sequence must be non-empty
135+
for registry in self.registries:
136+
for converters in registry.values():
137+
if isinstance(converters, (list, tuple)):
138+
assert (
139+
all(isinstance(c, ConverterSupport) for c in converters)
140+
and len(converters) > 0
141+
)
142+
else:
143+
assert callable(converters), "Converter function must be callable"
144+
145+
def __getitem__(self, key: Target):
146+
"""Get the first-found converter in any registry
147+
148+
Searches all registries in order and returns the first converter encountered
149+
"""
150+
self.validate_invariants()
151+
152+
# Iterate over all registries and return the first converter found
153+
for registry in self.registries:
154+
if key in registry:
155+
converters = registry[key]
156+
157+
if isinstance(converters, (list, tuple)):
158+
return converters[0].converter_implementation
159+
else:
160+
return converters
161+
162+
raise KeyError(f"None of the converter registries have an entry for {key}")
163+
164+
def __getitem_with_validation__(self, node: NodeKinds):
165+
"""Get the first-found validated converter in any registry
166+
167+
Searches all registries in order and returns the first converter
168+
which passes validation on the input node
169+
"""
170+
self.validate_invariants()
171+
key = node.target
172+
173+
# Iterate over all registries, validating the converter on the input node
174+
# If no check_args function is found, assume full coverage
175+
for registry in self.registries:
176+
if key in registry:
177+
converters = registry[key]
178+
179+
if isinstance(converters, (list, tuple)):
180+
for candidate in converters:
181+
if candidate.check_args(node):
182+
return candidate.converter_implementation
183+
else:
184+
return converters
185+
186+
raise KeyError(
187+
f"None of the converter registries have a validated entry for {key}, with node {node}"
188+
)
189+
190+
def keys(self):
191+
"""Get all unique targets across all dictionaries"""
192+
return self.unique_targets()
193+
194+
def get(self, key: Target, value=None):
195+
"""Get converter for input target with a default return"""
196+
try:
197+
return self.__getitem__(key)
198+
except KeyError:
199+
return value
200+
201+
def __contains__(self, key: Target):
202+
"""Check whether a converter for input target exists"""
203+
return any(key in registry for registry in self.registries)
204+
205+
def get_validated(self, node: NodeKinds, value=None):
206+
"""Get validated converter for input node with a default return"""
207+
try:
208+
return self.__getitem_with_validation__(node)
209+
except KeyError:
210+
return value
211+
212+
def contains_validated(self, node: NodeKinds):
213+
"""Check whether a validated converter for input node exists"""
214+
try:
215+
self.__getitem_with_validation__(node)
216+
return True
217+
except KeyError:
218+
return False
219+
220+
def get_all_converters_with_target(self, key: Target):
221+
"""Get all converters across all registries for the target
222+
223+
Returns a list of all converterts having the specified target
224+
"""
225+
self.validate_invariants()
226+
converters_with_target = []
227+
228+
for registry in self.registries:
229+
if key in registry:
230+
converters = registry[key]
231+
232+
if isinstance(converters, (list, tuple)):
233+
converters_with_target.extend(
234+
[c.converter_implementation for c in converters]
235+
)
236+
else:
237+
converters_with_target.append(converters)
238+
239+
return converters_with_target
240+
241+
def __setitem__(self, key, value):
242+
raise AssertionError(
243+
f"Do not set registry members directly through the ConverterRegistry object. "
244+
+ f"Attempted to set {key}: {value} via direct assignment to ConverterRegistry."
245+
)
246+
247+
def __delitem__(self, key):
248+
raise AssertionError(
249+
f"Do not delete registry members directly through the ConverterRegistry object. "
250+
+ f"Attempted to delete {key} via direct del on ConverterRegistry."
251+
)
252+
253+
def __len__(self):
254+
"""Returns the sum of lengths of all registries stored"""
255+
return sum(len(registry) for registry in self.registries)
256+
257+
def unique_targets(self):
258+
"""Returns the set of unique converter targets stored across all registries"""
259+
return set.union(*[set(registry.keys()) for registry in self.registries])
260+
261+
262+
# Initialize dynamo converter registry with the FX and Dynamo aten registries
263+
# Note the Dynamo registry is listed first, for precedence
264+
DYNAMO_CONVERTERS: ConverterRegistry = ConverterRegistry(
265+
[DYNAMO_ATEN_CONVERTERS, CONVERTERS]
266+
)

py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from torch.fx.passes.shape_prop import TensorMetadata
1616

1717
from torch_tensorrt.dynamo import DYNAMO_CONVERTERS as CONVERTERS
18+
from torch_tensorrt.dynamo.converter_registry import GenericNode
1819
from .input_tensor_spec import InputTensorSpec
1920
from torch_tensorrt.fx.observer import Observer
2021
from torch_tensorrt.fx.utils import (
@@ -141,9 +142,9 @@ def validate_conversion(self):
141142
missing_converter = set()
142143

143144
for node in self.module.graph.nodes:
144-
if node.op == "call_function" and not CONVERTERS.get(node.target):
145+
if node.op == "call_function" and not CONVERTERS.get_validated(node):
145146
missing_converter.add(f"{node.op} {_get_qualified_name(node.target)}")
146-
elif node.op == "call_method" and not CONVERTERS.get(node.target):
147+
elif node.op == "call_method" and not CONVERTERS.get_validated(node):
147148
missing_converter.add(f"{node.op} torch.Tensor.{node.target}")
148149
elif node.op == "call_module":
149150
submod = self.fetch_attr(node.target)
@@ -347,7 +348,7 @@ def call_module(self, target, args, kwargs):
347348
return converter(self.network, submod, args, kwargs, self._cur_node_name)
348349

349350
def call_function(self, target, args, kwargs):
350-
converter = CONVERTERS.get(target)
351+
converter = CONVERTERS.get_validated(GenericNode(target, args, kwargs))
351352
if not converter:
352353
raise RuntimeError(
353354
f"Conversion of function {torch.typename(target)} not currently supported!"
@@ -358,7 +359,7 @@ def call_function(self, target, args, kwargs):
358359

359360
def call_method(self, target, args, kwargs):
360361
assert isinstance(target, str)
361-
converter = CONVERTERS.get(target)
362+
converter = CONVERTERS.get_validated(GenericNode(target, args, kwargs))
362363

363364
if not converter:
364365
raise RuntimeError(

0 commit comments

Comments
 (0)