 import tensorrt as trt
 import torch
 import torch.fx
-from torch._ops import OpOverload
 from torch.fx.node import _get_qualified_name
 from torch.fx.passes.shape_prop import TensorMetadata
 
-from torch_tensorrt.fx import CONVERTERS
+from torch_tensorrt.dynamo import DYNAMO_CONVERTERS as CONVERTERS
 from torch_tensorrt import Input
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.fx.utils import (
@@ -69,6 +68,7 @@ def __init__(
         self.input_specs = input_specs
         self.input_specs_iter = 0
         self._cur_node_name: Optional[str] = None
+        self._cur_node: Optional[torch.fx.Node] = None
         self._input_names: List[str] = []
         self._output_names: List[str] = []
         self._itensor_to_tensor_meta: Dict[
@@ -82,14 +82,14 @@ def validate_conversion(self):
         missing_converter = set()
 
         for node in self.module.graph.nodes:
-            if node.op == "call_function" and not CONVERTERS.get(node.target):
+            if node.op == "call_function" and not CONVERTERS.get(node):
                 missing_converter.add(f"{node.op} {_get_qualified_name(node.target)}")
-            elif node.op == "call_method" and not CONVERTERS.get(node.target):
+            elif node.op == "call_method" and not CONVERTERS.get(node):
                 missing_converter.add(f"{node.op} torch.Tensor.{node.target}")
             elif node.op == "call_module":
                 submod = self.fetch_attr(node.target)
                 submod_type = getattr(submod, "_base_class_origin", type(submod))
-                if not CONVERTERS.get(submod_type):
+                if not CONVERTERS.get(node):
                     missing_converter.add(f"{node.op} {torch.typename(submod_type)}")
 
         return missing_converter
@@ -226,6 +226,7 @@ def run(
 
     def run_node(self, n):
         self._cur_node_name = str(n)
+        self._cur_node = n
         # add "_itensor_to_tensor_meta"
         kwargs = dict(n.kwargs)
         kwargs["_itensor_to_tensor_meta"] = self._itensor_to_tensor_meta
@@ -276,7 +277,7 @@ def call_module(self, target, args, kwargs):
         assert isinstance(target, str)
         submod = self.fetch_attr(target)
         submod_type = getattr(submod, "_base_class_origin", type(submod))
-        converter = CONVERTERS.get(submod_type)
+        converter = CONVERTERS.get(self._cur_node)
 
         if not converter:
             raise RuntimeError(
@@ -287,7 +288,7 @@ def call_module(self, target, args, kwargs):
         return converter(self.network, submod, args, kwargs, self._cur_node_name)
 
     def call_function(self, target, args, kwargs):
-        converter = CONVERTERS.get(target)
+        converter = CONVERTERS.get(self._cur_node)
         if not converter:
             raise RuntimeError(
                 f"Conversion of function {torch.typename(target)} not currently supported!"
@@ -298,7 +299,7 @@ def call_function(self, target, args, kwargs):
 
     def call_method(self, target, args, kwargs):
         assert isinstance(target, str)
-        converter = CONVERTERS.get(target)
+        converter = CONVERTERS.get(self._cur_node)
 
         if not converter:
             raise RuntimeError(
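The common thread in these hunks is that converter lookup now passes the whole FX node (CONVERTERS.get(node) in validate_conversion, CONVERTERS.get(self._cur_node) in call_module/call_function/call_method) instead of node.target or the submodule type, which is why run_node starts caching self._cur_node. The following is a minimal, hypothetical sketch of that calling convention, not the real DYNAMO_CONVERTERS registry; NodeKeyedConverterRegistry, relu_converter, and Demo are illustrative names only.

# Hypothetical sketch, not the actual DYNAMO_CONVERTERS implementation:
# a converter registry whose get() takes the whole torch.fx.Node, which is
# the calling convention the diff above switches the interpreter to.
from typing import Any, Callable, Dict, Optional

import torch
import torch.fx


class NodeKeyedConverterRegistry:
    """Registers converters per call target but is queried with a Node."""

    def __init__(self) -> None:
        self._converters: Dict[Any, Callable[..., Any]] = {}

    def register(self, target: Any) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
        def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
            self._converters[target] = fn
            return fn

        return decorator

    def get(self, node: torch.fx.Node) -> Optional[Callable[..., Any]]:
        # Dispatch on node.target; having the full node here also allows
        # checks against node.kwargs or node.meta before accepting a match.
        return self._converters.get(node.target)


CONVERTERS = NodeKeyedConverterRegistry()


@CONVERTERS.register(torch.relu)
def relu_converter(network: Any, target: Any, args: Any, kwargs: Any, name: str) -> str:
    # A real converter would add TensorRT layers to `network`; this stub
    # only demonstrates the registration and lookup flow.
    return f"converted {name}"


class Demo(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sin(torch.relu(x))


if __name__ == "__main__":
    gm = torch.fx.symbolic_trace(Demo())
    for node in gm.graph.nodes:
        if node.op != "call_function":
            continue
        converter = CONVERTERS.get(node)  # lookup by node, as in the diff
        if converter is None:
            print(f"missing converter for {node.target}")  # e.g. torch.sin here
        else:
            print(converter(None, node.target, node.args, node.kwargs, node.name))

Keying the query on the node rather than the bare target leaves room for a registry to consult node metadata (shapes, dtypes, kwargs) when deciding whether a converter applies, at the cost of the interpreter having to track the node currently being converted.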