Skip to content

Commit eaeaa46

Browse files
committed
feat: Data Structure update for Dynamo Registry
- Add custom class overriding default Dictionary class to access converters from various registries - Add new dictionary type `Dict[Target, Sequence[ConverterSupport]]` as well as ConverterSupport class which stores a converter and its validation implementation - Add unified `DYNAMO_CONVERTERS` dictionary which coalesces both the FX and Dynamo converter dictionaries and acts as a single unified dictionary
1 parent e720dec commit eaeaa46

File tree

3 files changed

+240
-9
lines changed

3 files changed

+240
-9
lines changed

py/torch_tensorrt/dynamo/backend/lowering/_partition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def is_node_supported(
124124
)
125125

126126
if (
127-
node.target in CONVERTERS.keys()
127+
CONVERTERS.contains_validated(node)
128128
and node_name not in self.torch_executed_ops
129129
):
130130
# If node is a proper, supported computational node, store the operator
Lines changed: 234 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,253 @@
1-
from typing import Any, Callable, Dict
1+
from dataclasses import dataclass, field
2+
from typing import Any, Callable, Dict, Optional, Sequence, Union
23

3-
from torch.fx.node import Target
4+
from torch.fx.node import Target, Node
45
from torch_tensorrt.fx.converter_registry import CONVERTERS
56

6-
DYNAMO_CONVERTERS: Dict[Target, Any] = dict(CONVERTERS)
7+
8+
@dataclass(frozen=True)
9+
class GenericNode:
10+
"""Class representing a generic Torch FX Node
11+
12+
Convenience class for when we do not have access to the node,
13+
and instead only its target, args, and kwargs, as is the case
14+
in the TRTInterpreter
15+
"""
16+
17+
target: Any
18+
args: Sequence[Any]
19+
kwargs: Dict[str, Any]
20+
21+
22+
# Defines the different types of valid nodes
23+
NodeKinds = Union[GenericNode, Node]
24+
25+
26+
@dataclass(frozen=True)
27+
class ConverterSupport:
28+
"""Class representing a converter implementation and support function
29+
30+
Args:
31+
check_args: Function which takes in a NodeKind and returns a bool indicating
32+
whether that node can be supported by its companion converter. Note that
33+
this function must only access the .target, .args, and .kwargs fields
34+
of the node, and cannot modify the node or its graph
35+
converter_implementation: Function which converts said node to a TRT equivalent
36+
"""
37+
38+
check_args: Callable[[NodeKinds], bool] = field(default=lambda node: True)
39+
converter_implementation: Callable
40+
41+
42+
# Dictionary representing Dynamo aten-only converters
43+
# Each converter maps to a sequence of at least one ConverterSupport object(s)
44+
DYNAMO_ATEN_CONVERTERS: Dict[Target, Sequence[ConverterSupport]] = {}
745

846

947
def dynamo_tensorrt_converter(
1048
key: Target,
1149
enabled: bool = True,
50+
check_args: Optional[Callable[[NodeKinds], bool]] = None,
1251
) -> Callable[[Any], Any]:
52+
"""Decorator for Dynamo TensorRT Converter
53+
54+
Registers the decorated function in the DYNAMO_ATEN_CONVERTERS registry
55+
56+
Args:
57+
key: Node target for which the converter is implemented for
58+
(for example, torch.ops.add.Tensor)
59+
enabled: Whether the converter should be enabled/cached or not
60+
check_args: Function which evaluates whether a node is valid for conversion
61+
by the decorated converter. See ConverterSupport for more details.
62+
Defaults to None, implying the check_args function is always true -
63+
this means all nodes of "key" kind can be supported by this converter
64+
Returns:
65+
The converter being decorated
66+
"""
67+
1368
def register_converter(converter):
14-
DYNAMO_CONVERTERS[key] = converter
69+
"""Helper function to register the converter, then return it"""
70+
assert callable(converter), "Converter function must be callable"
71+
72+
# If no check_args function is specified, use the default function - always return true
73+
if check_args is None:
74+
converter_support = ConverterSupport(converter_implementation=converter)
75+
else:
76+
assert callable(check_args), "Argument checking function must be callable"
77+
converter_support = ConverterSupport(
78+
check_args=check_args, converter_implementation=converter
79+
)
80+
81+
# If a converter for this operator already exists, append the new converter to the list
82+
# Otherwise, start a new list
83+
if key in DYNAMO_ATEN_CONVERTERS:
84+
DYNAMO_ATEN_CONVERTERS[key].append(converter_support)
85+
else:
86+
DYNAMO_ATEN_CONVERTERS[key] = [converter_support]
87+
1588
return converter
1689

1790
def disable_converter(converter):
1891
return converter
1992

93+
# Select whether to cache/enable the converter
2094
if enabled:
2195
return register_converter
2296
else:
2397
return disable_converter
98+
99+
100+
class ConverterRegistry:
101+
"""Registry for storing multiple converter dictionaries
102+
103+
Capable of storing dictionaries with two different signatures:
104+
Dict[Target, Callable] and Dict[Target, Sequence[ConverterSupport]]
105+
"""
106+
107+
def __init__(self, registries: Dict[Target, Any]):
108+
# Copy reference to each dictionary object into attribute list
109+
self.registries = [registry for registry in registries]
110+
self.validate_invariants()
111+
112+
def validate_invariants(self):
113+
"""Validates the invariants required of the dictionaries in the registries
114+
115+
Raises AssertionError if any invariants have been violated
116+
"""
117+
# All registries must be dictionaries
118+
assert all(isinstance(elt, dict) for elt in self.registries)
119+
120+
# Every dictionary in the registry must have one of two signatures:
121+
# Dict[Target, Callable] or Dict[Target, Sequence[ConverterSupport]]
122+
# Where, for the latter, the sequence must be non-empty
123+
for registry in self.registries:
124+
for converters in registry.values():
125+
if isinstance(converters, (list, tuple)):
126+
assert (
127+
all(isinstance(c, ConverterSupport) for c in converters)
128+
and len(converters) > 0
129+
)
130+
else:
131+
assert callable(converters), "Converter function must be callable"
132+
133+
def __getitem__(self, key: Target):
134+
"""Get the first-found converter in any registry
135+
136+
Searches all registries in order and returns the first converter encountered
137+
"""
138+
self.validate_invariants()
139+
140+
# Iterate over all registries and return the first converter found
141+
for registry in self.registries:
142+
if key in registry:
143+
converters = registry[key]
144+
145+
if isinstance(converters, (list, tuple)):
146+
return converters[0].converter_implementation
147+
else:
148+
return converters
149+
150+
raise KeyError(f"None of the converter registries have an entry for {key}")
151+
152+
def __getitem_with_validation__(self, node: NodeKinds):
153+
"""Get the first-found validated converter in any registry
154+
155+
Searches all registries in order and returns the first converter
156+
which passes validation on the input node
157+
"""
158+
self.validate_invariants()
159+
key = node.target
160+
161+
# Iterate over all registries, validating the converter on the input node
162+
# If no check_args function is found, assume full coverage
163+
for registry in self.registries:
164+
if key in registry:
165+
converters = registry[key]
166+
167+
if isinstance(converters, (list, tuple)):
168+
for candidate in converters:
169+
if candidate.check_args(node):
170+
return candidate.converter_implementation
171+
else:
172+
return converters
173+
174+
raise KeyError(
175+
f"None of the converter registries have a validated entry for {key}, with node {node}"
176+
)
177+
178+
def keys(self):
179+
"""Get all unique targets across all dictionaries"""
180+
return self.unique_targets()
181+
182+
def get(self, key: Target, value=None):
183+
"""Get converter for input target with a default return"""
184+
try:
185+
return self.__getitem__(key)
186+
except KeyError:
187+
return value
188+
189+
def __contains__(self, key: Target):
190+
"""Check whether a converter for input target exists"""
191+
return any(key in registry for registry in self.registries)
192+
193+
def get_validated(self, node: NodeKinds, value=None):
194+
"""Get validated converter for input node with a default return"""
195+
try:
196+
return self.__getitem_with_validation__(node)
197+
except KeyError:
198+
return value
199+
200+
def contains_validated(self, node: NodeKinds):
201+
"""Check whether a validated converter for input node exists"""
202+
try:
203+
self.__getitem_with_validation__(node)
204+
return True
205+
except KeyError:
206+
return False
207+
208+
def get_all_converters_with_target(self, key: Target):
209+
"""Get all converters across all registries for the target
210+
211+
Returns a list of all converterts having the specified target
212+
"""
213+
self.validate_invariants()
214+
converters_with_target = []
215+
216+
for registry in self.registries:
217+
if key in registry:
218+
converters = registry[key]
219+
220+
if isinstance(converters, (list, tuple)):
221+
converters_with_target.extend(
222+
[c.converter_implementation for c in converters]
223+
)
224+
else:
225+
converters_with_target.append(converters)
226+
227+
return converters_with_target
228+
229+
def __setitem__(self, key, value):
230+
raise AssertionError(
231+
f"Do not set registry members directly through the ConverterRegistry object. "
232+
+ f"Attempted to set {key}: {value} via direct assignment to ConverterRegistry."
233+
)
234+
235+
def __delitem__(self, key):
236+
raise AssertionError(
237+
f"Do not delete registry members directly through the ConverterRegistry object. "
238+
+ f"Attempted to delete {key} via direct del on ConverterRegistry."
239+
)
240+
241+
def __len__(self):
242+
"""Returns the sum of lengths of all registries stored"""
243+
return sum(len(registry) for registry in self.registries)
244+
245+
def unique_targets(self):
246+
"""Returns the set of unique converter targets stored across all registries"""
247+
return set.union(*[set(registry.keys()) for registry in self.registries])
248+
249+
250+
# Initialize dynamo converter registry with the FX and Dynamo aten registries
251+
DYNAMO_CONVERTERS: ConverterRegistry = ConverterRegistry(
252+
[CONVERTERS, DYNAMO_ATEN_CONVERTERS]
253+
)

py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from torch.fx.passes.shape_prop import TensorMetadata
1616

1717
from torch_tensorrt.dynamo import DYNAMO_CONVERTERS as CONVERTERS
18+
from torch_tensorrt.dynamo.converter_registry import GenericNode
1819
from .input_tensor_spec import InputTensorSpec
1920
from torch_tensorrt.fx.observer import Observer
2021
from torch_tensorrt.fx.utils import (
@@ -141,9 +142,9 @@ def validate_conversion(self):
141142
missing_converter = set()
142143

143144
for node in self.module.graph.nodes:
144-
if node.op == "call_function" and not CONVERTERS.get(node.target):
145+
if node.op == "call_function" and not CONVERTERS.get_validated(node):
145146
missing_converter.add(f"{node.op} {_get_qualified_name(node.target)}")
146-
elif node.op == "call_method" and not CONVERTERS.get(node.target):
147+
elif node.op == "call_method" and not CONVERTERS.get_validated(node):
147148
missing_converter.add(f"{node.op} torch.Tensor.{node.target}")
148149
elif node.op == "call_module":
149150
submod = self.fetch_attr(node.target)
@@ -347,7 +348,7 @@ def call_module(self, target, args, kwargs):
347348
return converter(self.network, submod, args, kwargs, self._cur_node_name)
348349

349350
def call_function(self, target, args, kwargs):
350-
converter = CONVERTERS.get(target)
351+
converter = CONVERTERS.get_validated(GenericNode(target, args, kwargs))
351352
if not converter:
352353
raise RuntimeError(
353354
f"Conversion of function {torch.typename(target)} not currently supported!"
@@ -358,7 +359,7 @@ def call_function(self, target, args, kwargs):
358359

359360
def call_method(self, target, args, kwargs):
360361
assert isinstance(target, str)
361-
converter = CONVERTERS.get(target)
362+
converter = CONVERTERS.get_validated(GenericNode(target, args, kwargs))
362363

363364
if not converter:
364365
raise RuntimeError(

0 commit comments

Comments
 (0)