Skip to content

Commit 828fec0

Browse files
committed
feat: Add _to_copy, operator.getitem and clone
- Add ATen converters for key operators in the pipeline of multiple models
- Add robust testing and patch issues in interpreter
- Add evaluator and casting utilities to the converter utils
1 parent f7b03f4 commit 828fec0

File tree

10 files changed

+328
-12
lines changed

10 files changed

+328
-12
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,13 @@ def aten_ops_div(
7070
kwargs_new["input"].dtype == trt.int8 or kwargs_new["input"].dtype == trt.int32
7171
):
7272
kwargs_new["input"] = cast_trt_tensor(
73-
network, kwargs_new["input"], trt.float32, name
73+
network, kwargs_new["input"], trt.float32, name, target
7474
)
7575
elif isinstance(args[1], TRTTensor) and (
7676
kwargs_new["other"].dtype == trt.int8 or kwargs_new["other"].dtype == trt.int32
7777
):
7878
kwargs_new["other"] = cast_trt_tensor(
79-
network, kwargs_new["other"], trt.float32, name
79+
network, kwargs_new["other"], trt.float32, name, target
8080
)
8181
rounding_mode = kwargs.get("rounding_mode")
8282
if rounding_mode is None:
@@ -377,3 +377,77 @@ def aten_ops_permute(
377377
args[0],
378378
args[1],
379379
)
380+
381+
382+
def to_copy_dtype_validator(to_copy_node: Node):
    """Decide whether an ``aten._to_copy`` node can be handled by the converter.

    A node is accepted only when it carries a ``dtype`` keyword argument and
    that dtype is one of the supported cast targets. Every rejection is
    reported at debug level.

    Args:
        to_copy_node (Node): The FX node to inspect.
    Returns:
        True if the node has a supported ``dtype`` kwarg, False otherwise.
    """
    supported_dtypes = {
        torch.float,
        torch.int32,
        torch.bool,
        torch.int8,
        torch.float16,
    }
    node_kwargs = to_copy_node.kwargs

    # Without an explicit dtype there is nothing for the cast converter to do
    if "dtype" not in node_kwargs:
        _LOGGER.debug(
            f"_to_copy converter rejected node {to_copy_node} with kwargs {to_copy_node.kwargs}"
        )
        return False

    # A dtype was requested, but it is not one of the supported cast targets
    if node_kwargs["dtype"] not in supported_dtypes:
        _LOGGER.debug(
            f"_to_copy converter rejected node {to_copy_node} with dtype {to_copy_node.kwargs['dtype']}"
        )
        return False

    return True
399+
400+
401+
@dynamo_tensorrt_converter(
    torch.ops.aten._to_copy.default, capability_validator=to_copy_dtype_validator
)
def aten_ops_to_copy_dtype(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Converter for ``aten._to_copy.default`` dtype casts.

    Registered with ``to_copy_dtype_validator`` as its capability validator.
    Forwards the first positional argument and the ``dtype`` keyword argument
    to ``impl.cast.to_copy``, tagging the call with ``SourceIR.ATEN``.
    """
    source_tensor = args[0]
    requested_dtype = kwargs["dtype"]
    return impl.cast.to_copy(
        network, target, SourceIR.ATEN, name, source_tensor, requested_dtype
    )
419+
420+
421+
@dynamo_tensorrt_converter(operator.getitem)
def operator_getitem(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Converter for ``operator.getitem`` nodes.

    Passes the container (``args[0]``) and the index (``args[1]``) to
    ``impl.evaluators.getitem``, tagging the call with ``SourceIR.ATEN``.
    """
    container, index = args[0], args[1]
    return impl.evaluators.getitem(
        network, target, SourceIR.ATEN, name, container, index
    )
437+
438+
439+
@dynamo_tensorrt_converter(torch.ops.aten.clone.default)
def aten_ops_clone(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Converter for ``aten.clone.default``.

    Hands the first positional argument to ``impl.evaluators.clone``,
    tagging the call with ``SourceIR.ATEN``.
    """
    source_tensor = args[0]
    return impl.evaluators.clone(network, target, SourceIR.ATEN, name, source_tensor)

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import torch
22

3+
from torch.fx.node import _get_qualified_name, Target
4+
35
from torch_tensorrt.fx.types import (
46
TRTDataType,
57
TRTNetwork,
@@ -12,7 +14,9 @@
1214
)
1315

1416
import tensorrt as trt
15-
from typing import List
17+
from typing import List, Optional, Union
18+
19+
from .._SourceIR import SourceIR
1620

1721

1822
def dynamic_unsupported(node: torch.fx.Node) -> bool:
def cast_trt_tensor(
    network: TRTNetwork,
    input_val: TRTTensor,
    dtype: TRTDataType,
    name: str,
    target: Union[Target, str] = "",
    source_ir: Optional[SourceIR] = None,
) -> TRTTensor:
    """
    Given a TRT Tensor, convert that Tensor to the specified dtype
    Adds an Identity layer to the network which performs the conversion

    Args:
        network (TRTNetwork): A TensorRT network
        input_val (TRTTensor): A TRT Tensor to cast to a new data type
        dtype (TRTDataType, torch.dtype, np.dtype): The data type to cast the input Tensor to
        name (str): Name of the calling layer
        target (Target, str): Target of the calling layer, used only to build
            the name of the inserted layer. May be left as the empty string.
        source_ir (SourceIR): IR of the calling converter, used only to build
            the name of the inserted layer. Defaults to SourceIR.UNKNOWN.
    Returns:
        A TensorRT ITensor which has been casted to the specified dtype
    """
    trt_dtype = unified_dtype_converter(dtype, Frameworks.TRT)

    if input_val.dtype != trt_dtype:
        source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN

        # Only non-string targets are resolved through _get_qualified_name.
        # Strings are used verbatim, so the default empty target is never
        # passed to a helper that expects a callable.
        target_str = target if isinstance(target, str) else _get_qualified_name(target)

        # Omit the ".<target>" suffix entirely when no target was supplied
        target_name = (
            f"{source_ir}_ops.{target_str}" if target_str else f"{source_ir}_ops"
        )

        identity_layer = network.add_identity(input_val)
        identity_layer.set_output_type(0, trt_dtype)
        identity_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} -{name}-[{target_name}]-[{name}]"
        return identity_layer.get_output(0)
    else:
        return input_val

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,5 @@
1212
from . import squeeze
1313
from . import unsqueeze
1414
from . import permutation
15+
from . import cast
16+
from . import evaluators
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from typing import Optional
2+
from torch.fx.node import Target
3+
4+
from torch_tensorrt.dynamo._SourceIR import SourceIR
5+
from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
6+
7+
from torch_tensorrt.fx.types import (
8+
TRTNetwork,
9+
TRTTensor,
10+
TRTDataType,
11+
)
12+
13+
14+
def to_copy(
    network: TRTNetwork,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    dtype: TRTDataType,
) -> TRTTensor:
    """Cast a TensorRT tensor to ``dtype`` via ``cast_trt_tensor``.

    Args:
        network (TRTNetwork): A TensorRT network
        target (Target): Target of the calling node, forwarded to the cast helper
        source_ir (SourceIR): IR of the calling converter, forwarded to the cast helper
        name (str): Name of the calling layer
        input (TRTTensor): Tensor to cast
        dtype: Requested data type, forwarded unchanged to ``cast_trt_tensor``
    Returns:
        The tensor returned by ``cast_trt_tensor``
    Raises:
        RuntimeError: If ``input`` is not a TensorRT ITensor
    """
    if isinstance(input, TRTTensor):
        return cast_trt_tensor(network, input, dtype, name, target, source_ir)

    raise RuntimeError(
        f"to_copy received input {input} that is not a TensorRT ITensor"
    )

py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,13 @@ def convert_binary_elementwise(
137137
trt_promoted_type = unified_dtype_converter(promoted_type, Frameworks.TRT)
138138

139139
if trt_promoted_type != lhs_val.dtype:
140-
lhs_val = cast_trt_tensor(network, lhs_val, trt_promoted_type, name)
140+
lhs_val = cast_trt_tensor(
141+
network, lhs_val, trt_promoted_type, name, target, source_ir
142+
)
141143
if trt_promoted_type != rhs_val.dtype:
142-
rhs_val = cast_trt_tensor(network, rhs_val, trt_promoted_type, name)
144+
rhs_val = cast_trt_tensor(
145+
network, rhs_val, trt_promoted_type, name, target, source_ir
146+
)
143147

144148
# Check the limitation in the doc string.
145149
if network.has_implicit_batch_dimension:
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import operator
2+
import logging
3+
from typing import Optional, Sequence
4+
from torch.fx.node import Target
5+
6+
from torch_tensorrt.dynamo.conversion import SourceIR
7+
8+
from torch_tensorrt.fx.types import (
9+
TRTNetwork,
10+
TRTTensor,
11+
)
12+
13+
14+
LOGGER: logging.Logger = logging.getLogger(__name__)
15+
16+
17+
def getitem(
    network: TRTNetwork,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: Sequence[TRTTensor],
    index: int,
) -> TRTTensor:
    """Evaluate a getitem by indexing ``input`` directly.

    ``network``, ``target`` and ``source_ir`` are accepted but not used here;
    ``name`` is used only in the debug message.

    Args:
        input (Sequence[TRTTensor]): Object to index into
        index (int): Position to select
    Returns:
        The element of ``input`` at ``index``
    """
    LOGGER.debug(f"Evaluating getitem on object with name: {name}")

    # Plain subscription; no network layer is involved
    selected = input[index]
    return selected
29+
30+
31+
def clone(
    network: TRTNetwork,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
) -> TRTTensor:
    """Evaluate a clone by returning the input tensor itself.

    ``network``, ``target`` and ``source_ir`` are accepted but not used here;
    ``name`` is used only in the debug message.

    Args:
        input (TRTTensor): Tensor being cloned
    Returns:
        The same ``input`` object that was passed in
    Raises:
        RuntimeError: If ``input`` is not a TensorRT ITensor
    """
    if isinstance(input, TRTTensor):
        LOGGER.debug(f"Evaluating clone on object with name: {name}")
        return input

    raise RuntimeError(
        f"clone received input {input} that is not a TensorRT ITensor"
    )

py/torch_tensorrt/dynamo/conversion/trt_interpreter.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@
2929
] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
3030

3131

32+
class UnsupportedOperatorException(RuntimeError):
    """Raised by the interpreter when no converter is found for a module, function, or method."""
34+
35+
3236
class TRTInterpreterResult(NamedTuple):
3337
engine: Any
3438
input_names: Sequence[str]
@@ -288,7 +292,7 @@ def call_module(self, target, args, kwargs):
288292
converter = CONVERTERS.get(self._cur_node)
289293

290294
if not converter:
291-
raise RuntimeError(
295+
raise UnsupportedOperatorException(
292296
f"Conversion of module of type {submod_type} not currently supported!"
293297
)
294298

@@ -298,7 +302,7 @@ def call_module(self, target, args, kwargs):
298302
def call_function(self, target, args, kwargs):
299303
converter = CONVERTERS.get(self._cur_node)
300304
if not converter:
301-
raise RuntimeError(
305+
raise UnsupportedOperatorException(
302306
f"Conversion of function {torch.typename(target)} not currently supported!"
303307
)
304308

@@ -310,7 +314,7 @@ def call_method(self, target, args, kwargs):
310314
converter = CONVERTERS.get(self._cur_node)
311315

312316
if not converter:
313-
raise RuntimeError(
317+
raise UnsupportedOperatorException(
314318
f"Conversion of method {target} not currently supported!"
315319
)
316320

py/torch_tensorrt/dynamo/test_utils.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ def generate_graph(
217217
expected_ops: Set[Callable],
218218
unexpected_ops: Optional[Set[Callable]] = None,
219219
customized_passes: List[Callable] = None,
220+
disable_passes: bool = False,
220221
):
221222
# Torchdynamo+aot proxytensor tracer
222223
# Below are common passes
@@ -234,6 +235,10 @@ def generate_graph(
234235
# Combine with customized passes specific to any model
235236
if customized_passes:
236237
passes_list.extend(customized_passes)
238+
239+
if disable_passes:
240+
passes_list = []
241+
237242
fx_module, _ = aten_tracer.trace(mod, original_inputs)
238243
for passes in passes_list:
239244
pr: PassResult = passes(fx_module)
@@ -261,9 +266,17 @@ def run_test(
261266
atol=1e-03,
262267
precision=torch.float,
263268
check_dtype=True,
269+
disable_passes=False,
264270
):
265271
mod.eval()
266-
mod = self.generate_graph(mod, inputs, expected_ops, unexpected_ops, None)
272+
mod = self.generate_graph(
273+
mod,
274+
inputs,
275+
expected_ops,
276+
unexpected_ops,
277+
None,
278+
disable_passes=disable_passes,
279+
)
267280

268281
if apply_passes is not None:
269282
pass_tracer = chain_passes(*apply_passes)
@@ -293,10 +306,18 @@ def run_test_with_dynamic_shape(
293306
unexpected_ops=None,
294307
rtol=1e-03,
295308
atol=1e-03,
309+
disable_passes=False,
296310
):
297311
mod.eval()
298312
inputs = [spec.example_tensor("opt_shape") for spec in input_specs]
299-
mod = self.generate_graph(mod, inputs, expected_ops, unexpected_ops, None)
313+
mod = self.generate_graph(
314+
mod,
315+
inputs,
316+
expected_ops,
317+
unexpected_ops,
318+
None,
319+
disable_passes=disable_passes,
320+
)
300321

301322
interp = TRTInterpreter(
302323
mod,

0 commit comments

Comments
 (0)