Skip to content

Commit afc9b5d

Browse files
committed
aten::cat converter moving to impl
1 parent 7d0d3f6 commit afc9b5d

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def aten_ops_cat(
6565
SourceIR.ATEN,
6666
name,
6767
tensors=args[0],
68-
dim = args_bounds_check(args, 2, 1),
68+
dim=args_bounds_check(args, 2, 1),
6969
)
7070

7171

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,20 @@
1-
from typing import Optional, Union, Sequence, Dict
1+
from typing import Dict, Optional, Sequence, Union
22

33
import torch
44
from torch.fx.node import Target
55
from torch_tensorrt.dynamo._SourceIR import SourceIR
66
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
77
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
88

9+
910
def cat(
1011
network: TRTNetwork,
1112
target: Target,
1213
source_ir: Optional[SourceIR],
1314
name: str,
1415
input: TRTNetwork,
1516
dim: int,
16-
1717
) -> Union[TRTTensor, Sequence[TRTTensor]]:
18-
1918
if any(not isinstance(t, TRTTensor) for t in input): # type: ignore[union-attr]
2019
raise RuntimeError(
2120
f"cat received inputs {input} that is not part " "of the TensorRT region!"
@@ -26,4 +25,4 @@ def cat(
2625

2726
concat_layer.axis = dim
2827
set_layer_name(concat_layer, target, name + "_gather", source_ir)
29-
return concat_layer.get_output(0)
28+
return concat_layer.get_output(0)

0 commit comments

Comments
 (0)