
Commit 61e716e

feat: Add _to_copy, operator.get and clone
- Add ATen converters for key operators in the pipeline of multiple models
- Add robust testing and patch issues in interpreter
- Add evaluator and casting utilities to the converter utils
1 parent 06e544e commit 61e716e

11 files changed: +334, -28 lines

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 7 additions & 3 deletions
@@ -27,6 +27,10 @@
 ] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
 
 
+class UnsupportedOperatorException(RuntimeError):
+    pass
+
+
 class TRTInterpreterResult(NamedTuple):
     engine: Any
     input_names: Sequence[str]
@@ -301,7 +305,7 @@ def call_module(
         converter = CONVERTERS.get(self._cur_node)
 
         if not converter:
-            raise RuntimeError(
+            raise UnsupportedOperatorException(
                 f"Conversion of module of type {submod_type} not currently supported!"
             )
 
@@ -312,7 +316,7 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
         # TODO: Why is this stateful? We should be able to take in the inputs
         converter = CONVERTERS.get(self._cur_node)
         if not converter:
-            raise RuntimeError(
+            raise UnsupportedOperatorException(
                 f"Conversion of function {torch.typename(target)} not currently supported!"
             )
 
@@ -324,7 +328,7 @@ def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
         converter = CONVERTERS.get(self._cur_node)
 
         if not converter:
-            raise RuntimeError(
+            raise UnsupportedOperatorException(
                 f"Conversion of method {target} not currently supported!"
             )
 
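
The new UnsupportedOperatorException subclass lets callers distinguish "no converter registered for this node" from other interpreter failures. A minimal, hedged sketch of how a caller might react to it (the interpret_or_fallback helper is illustrative and not part of this commit; run() arguments are omitted for brevity):

from torch_tensorrt.dynamo.conversion._TRTInterpreter import (
    UnsupportedOperatorException,
)


def interpret_or_fallback(interpreter):
    """Run an already-constructed TRTInterpreter, deferring to Torch on unsupported ops."""
    try:
        # run() drives call_module/call_function/call_method, where the
        # exception is now raised for nodes without a registered converter
        return interpreter.run()
    except UnsupportedOperatorException as exc:
        print(f"Falling back to Torch execution: {exc}")
        return None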
py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 79 additions & 5 deletions
@@ -1,6 +1,8 @@
 import logging
+import operator
 from typing import Any, Dict, Optional, Sequence, Tuple, Union
 
+import tensorrt as trt
 import torch
 from torch.fx.node import Argument, Node, Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -12,8 +14,6 @@
 from torch_tensorrt.fx.converters import acc_ops_converters
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 
-import tensorrt as trt
-
 from .converter_registry import dynamo_tensorrt_converter
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -76,13 +76,13 @@ def aten_ops_div(
         kwargs_new["input"].dtype == trt.int8 or kwargs_new["input"].dtype == trt.int32
     ):
         kwargs_new["input"] = cast_trt_tensor(
-            network, kwargs_new["input"], trt.float32, name
+            network, kwargs_new["input"], trt.float32, name, target
         )
     elif isinstance(args[1], TRTTensor) and (
         kwargs_new["other"].dtype == trt.int8 or kwargs_new["other"].dtype == trt.int32
     ):
         kwargs_new["other"] = cast_trt_tensor(
-            network, kwargs_new["other"], trt.float32, name
+            network, kwargs_new["other"], trt.float32, name, target
         )
     rounding_mode = kwargs.get("rounding_mode")
     if rounding_mode is None:
@@ -101,7 +101,7 @@ def aten_ops_div(
     )
 
 
-def embedding_param_validator(embedding_node: Node):
+def embedding_param_validator(embedding_node: Node) -> bool:
     scale_grad_by_freq = args_bounds_check(embedding_node.args, 3)
     sparse = args_bounds_check(embedding_node.args, 4)
 
@@ -365,3 +365,77 @@ def aten_ops_permute(
         args[0],
         args[1],
     )
+
+
+def to_copy_dtype_validator(to_copy_node: Node) -> bool:
+    allowed_casts = {torch.float, torch.int32, torch.bool, torch.int8, torch.float16}
+
+    # Validate input node has convertible kwargs
+    if "dtype" in to_copy_node.kwargs:
+        if to_copy_node.kwargs["dtype"] in allowed_casts:
+            return True
+        else:
+            _LOGGER.debug(
+                f"_to_copy converter rejected node {to_copy_node} with dtype {to_copy_node.kwargs['dtype']}"
+            )
+            return False
+    else:
+        _LOGGER.debug(
+            f"_to_copy converter rejected node {to_copy_node} with kwargs {to_copy_node.kwargs}"
+        )
+        return False
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.aten._to_copy.default, capability_validator=to_copy_dtype_validator
+)
+def aten_ops_to_copy_dtype(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.cast.to_copy(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        kwargs["dtype"],
+    )
+
+
+@dynamo_tensorrt_converter(operator.getitem)
+def operator_getitem(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.evaluators.getitem(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.clone.default)
+def aten_ops_clone(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.evaluators.clone(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
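
The capability_validator keeps _to_copy nodes with unsupported target dtypes out of the TensorRT engine so they fall back to Torch execution. A quick illustrative check of that logic, not part of the commit's test suite (a SimpleNamespace stands in for a torch.fx Node, since only its kwargs are inspected):

from types import SimpleNamespace

import torch
from torch_tensorrt.dynamo.conversion.aten_ops_converters import (
    to_copy_dtype_validator,
)

# torch.float16 is in the allowed cast set, so the node is accepted
assert to_copy_dtype_validator(SimpleNamespace(kwargs={"dtype": torch.float16}))

# torch.float64 is not an allowed cast, so the node is rejected (falls back to Torch)
assert not to_copy_dtype_validator(SimpleNamespace(kwargs={"dtype": torch.float64}))

# a _to_copy call without an explicit dtype kwarg is likewise rejected
assert not to_copy_dtype_validator(SimpleNamespace(kwargs={}))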

py/torch_tensorrt/dynamo/conversion/converter_registry.py

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ def dynamo_tensorrt_converter(
     enabled: bool = True,
     capability_validator: Optional[Callable[[Node], bool]] = None,
     priority: ConverterPriority = ConverterPriority.STANDARD,
-) -> Callable[[Any], Any]:
+) -> Callable[[Any], Union[TRTTensor, Sequence[TRTTensor]]]:
     """Decorator for Dynamo TensorRT Converter
 
     Registers the decorated function in the DYNAMO_ATEN_CONVERTERS registry

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 17 additions & 3 deletions
@@ -1,15 +1,18 @@
 import logging
 import re
-from typing import List
+from typing import List, Optional
 
 import tensorrt as trt
 import torch
+from torch.fx.node import Target, _get_qualified_name
 from torch_tensorrt.fx.converters.converter_utils import (
     Frameworks,
     unified_dtype_converter,
 )
 from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor
 
+from .._SourceIR import SourceIR
+
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 
@@ -71,24 +74,35 @@ def cast_trt_tensor(
     input_val: TRTTensor,
     dtype: TRTDataType,
     name: str,
+    target: Target = "",
+    source_ir: Optional[SourceIR] = None,
 ) -> TRTTensor:
     """
     Given a TRT Tensor, convert that Tensor to the specified dtype
     Adds an Identity layer to the network which performs the conversion
     Args:
         network (TRTNetwork): A TensorRT network
         input_val (TRTTensor): A TRT Tensor to cast to a new data type
-        dtype (TRTDataType): The TRTDataType to cast the input Tensor to
+        dtype (TRTDataType, torch.dtype, np.dtype): The data type to cast the input Tensor to
         name (str): Name of the calling layer
+        target (Target): Target of calling node
+        source_ir (SourceIR): SourceIR of calling converter
     Returns:
         A TensorRT ITensor which has been casted to the specified dtype
     """
     trt_dtype = unified_dtype_converter(dtype, Frameworks.TRT)
 
     if input_val.dtype != trt_dtype:
+        source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN
+        target_name = (
+            f"{source_ir}_ops{'.' + target if target else ''}"
+            if (isinstance(target, str))
+            else f"{source_ir}_ops.{_get_qualified_name(target)}"
+        )
+
         identity_layer = network.add_identity(input_val)
         identity_layer.set_output_type(0, trt_dtype)
-        identity_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} - {name}"
+        identity_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} - [{target_name}]-[{name}]"
         return identity_layer.get_output(0)
     else:
         return input_val
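
With target and source_ir threaded through, the Identity layer that performs the cast is named after the op that requested it, which makes the built network easier to inspect. A hedged sketch against a bare TensorRT network (the builder setup below is ordinary TensorRT boilerplate and not part of this commit; the "aten.div" target string is illustrative):

import tensorrt as trt
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
x = network.add_input(name="x", dtype=trt.int32, shape=(2, 3))

# Cast the int32 input to float32; the string target tags the resulting layer
casted = cast_trt_tensor(
    network, x, trt.float32, "cast_x", target="aten.div", source_ir=SourceIR.ATEN
)

# The Identity layer name now embeds both the source/target tag and the caller-supplied name
print(network.get_layer(network.num_layers - 1).name)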

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -2,9 +2,11 @@
 
 from . import (
     activation,
+    cast,
     condition,
     elementwise,
     embedding,
+    evaluators,
     matmul,
     normalization,
     permutation,

py/torch_tensorrt/dynamo/conversion/impl/cast.py

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+from typing import Optional
+
+from torch.fx.node import Target
+from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
+from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor
+
+
+def to_copy(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    dtype: TRTDataType,
+) -> TRTTensor:
+    if not isinstance(input, TRTTensor):
+        raise RuntimeError(
+            f"to_copy received input {input} that is not a TensorRT ITensor"
+        )
+
+    casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir)
+    return casted_tensor
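
to_copy insists that its input is already a TensorRT ITensor rather than a plain torch.Tensor. A brief illustrative check of that guard (not from the commit's test suite; pytest is used only for convenience):

import pytest
import torch
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl

# A plain torch.Tensor is rejected before the network (None here) is ever touched
with pytest.raises(RuntimeError):
    impl.cast.to_copy(
        None, "aten._to_copy", SourceIR.ATEN, "to_copy_0", torch.zeros(2), torch.float16
    )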

py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py

Lines changed: 7 additions & 4 deletions
@@ -2,6 +2,7 @@
 import warnings
 from typing import Any, Callable, Optional, Union
 
+import tensorrt as trt
 import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -15,8 +16,6 @@
 from torch_tensorrt.fx.types import TRTElementWiseOp, TRTNetwork, TRTTensor
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
 
-import tensorrt as trt
-
 
 def get_python_op_from_trt_elementwise_op(
     trt_op: TRTElementWiseOp,
@@ -132,9 +131,13 @@ def convert_binary_elementwise(
     trt_promoted_type = unified_dtype_converter(promoted_type, Frameworks.TRT)
 
     if trt_promoted_type != lhs_val.dtype:
-        lhs_val = cast_trt_tensor(network, lhs_val, trt_promoted_type, name)
+        lhs_val = cast_trt_tensor(
+            network, lhs_val, trt_promoted_type, name, target, source_ir
+        )
     if trt_promoted_type != rhs_val.dtype:
-        rhs_val = cast_trt_tensor(network, rhs_val, trt_promoted_type, name)
+        rhs_val = cast_trt_tensor(
+            network, rhs_val, trt_promoted_type, name, target, source_ir
+        )
 
     # Check the limitation in the doc string.
     if network.has_implicit_batch_dimension:

py/torch_tensorrt/dynamo/conversion/impl/evaluators.py

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+import logging
+import operator
+from typing import Optional, Sequence
+
+from torch.fx.node import Target
+from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+
+LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+def getitem(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: Sequence[TRTTensor],
+    index: int,
+) -> TRTTensor:
+    LOGGER.debug(f"Evaluating getitem on object with name: {name}")
+
+    # Directly index the input sequence and return the value
+    return operator.getitem(input, index)
+
+
+def clone(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+) -> TRTTensor:
+    if not isinstance(input, TRTTensor):
+        raise RuntimeError(
+            f"clone received input {input} that is not a TensorRT ITensor"
+        )
+
+    LOGGER.debug(f"Evaluating clone on object with name: {name}")
+
+    return input
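
Both evaluators are pure Python: getitem simply indexes the sequence produced by a multi-output node, and clone returns its ITensor input unchanged, so neither adds TensorRT layers. A minimal sketch of getitem's behavior (not from the commit; plain strings stand in for trt.ITensor outputs, and the unused network argument is passed as None):

import operator

from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl

# Stand-in for the tuple of outputs a multi-output op would produce
outputs = ("out0", "out1")

picked = impl.evaluators.getitem(
    None, operator.getitem, SourceIR.ATEN, "getitem_1", outputs, 1
)
assert picked == "out1"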
