Skip to content

Commit 359b6b7

Browse files
committed
aten::cat converter moving to impl
1 parent ecdc040 commit 359b6b7

File tree

3 files changed

+48
-0
lines changed

3 files changed

+48
-0
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,24 @@ def aten_ops_batch_norm(
4646
)
4747

4848

49+
@dynamo_tensorrt_converter(torch.ops.aten.cat.default)
50+
def aten_ops_cat(
51+
network: TRTNetwork,
52+
target: Target,
53+
args: Tuple[Argument, ...],
54+
kwargs: Dict[str, Argument],
55+
name: str,
56+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
57+
return impl.cat.cat(
58+
network,
59+
target,
60+
SourceIR.ATEN,
61+
name,
62+
tensors=args[0],
63+
dim = args_bounds_check(args, 2, 1),
64+
)
65+
66+
4967
def embedding_param_validator(embedding_node: Node) -> bool:
5068
scale_grad_by_freq = args_bounds_check(embedding_node.args, 3)
5169
sparse = args_bounds_check(embedding_node.args, 4)

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from . import (
44
activation,
55
cast,
6+
cat,
67
condition,
78
conv,
89
elementwise,
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from typing import Optional, Union, Sequence, Dict
2+
3+
import torch
4+
from torch.fx.node import Target
5+
from torch_tensorrt.dynamo._SourceIR import SourceIR
6+
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
7+
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
8+
9+
def cat(
10+
network: TRTNetwork,
11+
target: Target,
12+
source_ir: Optional[SourceIR],
13+
name: str,
14+
input: TRTNetwork,
15+
dim: int,
16+
17+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
18+
19+
if any(not isinstance(t, TRTTensor) for t in input): # type: ignore[union-attr]
20+
raise RuntimeError(
21+
f"cat received inputs {input} that is not part " "of the TensorRT region!"
22+
)
23+
concat_layer = network.add_concatenation(input)
24+
if dim < 0:
25+
dim = len(input[0].shape) + dim
26+
27+
concat_layer.axis = dim
28+
set_layer_name(concat_layer, target, name + "_gather", source_ir)
29+
return concat_layer.get_output(0)

0 commit comments

Comments
 (0)