Skip to content

Commit 809affd

Browse files
committed
aten::cat converter moving to impl
1 parent 4cffd6e commit 809affd

File tree

3 files changed

+48
-0
lines changed

3 files changed

+48
-0
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,24 @@ def aten_ops_batch_norm(
5151
)
5252

5353

54+
@dynamo_tensorrt_converter(torch.ops.aten.cat.default)
55+
def aten_ops_cat(
56+
network: TRTNetwork,
57+
target: Target,
58+
args: Tuple[Argument, ...],
59+
kwargs: Dict[str, Argument],
60+
name: str,
61+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
62+
return impl.cat.cat(
63+
network,
64+
target,
65+
SourceIR.ATEN,
66+
name,
67+
tensors=args[0],
68+
dim = args_bounds_check(args, 2, 1),
69+
)
70+
71+
5472
def embedding_param_validator(embedding_node: Node) -> bool:
5573
scale_grad_by_freq = args_bounds_check(embedding_node.args, 3)
5674
sparse = args_bounds_check(embedding_node.args, 4)

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
activation,
55
attention,
66
cast,
7+
cat,
78
condition,
89
conv,
910
deconv,
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from typing import Optional, Union, Sequence, Dict
2+
3+
import torch
4+
from torch.fx.node import Target
5+
from torch_tensorrt.dynamo._SourceIR import SourceIR
6+
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
7+
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
8+
9+
def cat(
10+
network: TRTNetwork,
11+
target: Target,
12+
source_ir: Optional[SourceIR],
13+
name: str,
14+
input: TRTNetwork,
15+
dim: int,
16+
17+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
18+
19+
if any(not isinstance(t, TRTTensor) for t in input): # type: ignore[union-attr]
20+
raise RuntimeError(
21+
f"cat received inputs {input} that is not part " "of the TensorRT region!"
22+
)
23+
concat_layer = network.add_concatenation(input)
24+
if dim < 0:
25+
dim = len(input[0].shape) + dim
26+
27+
concat_layer.axis = dim
28+
set_layer_name(concat_layer, target, name + "_gather", source_ir)
29+
return concat_layer.get_output(0)

0 commit comments

Comments
 (0)