Skip to content

Commit 6cf4d6c

Browse files
committed
aten::cat rebase changes
1 parent 79e7e13 commit 6cf4d6c

File tree

2 files changed

+12
-11
lines changed

2 files changed

+12
-11
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,19 +53,19 @@ def aten_ops_batch_norm(
5353

5454
@dynamo_tensorrt_converter(torch.ops.aten.cat.default)
5555
def aten_ops_cat(
56-
network: TRTNetwork,
56+
ctx: ConversionContext,
5757
target: Target,
5858
args: Tuple[Argument, ...],
5959
kwargs: Dict[str, Argument],
6060
name: str,
6161
) -> Union[TRTTensor, Sequence[TRTTensor]]:
6262
return impl.cat.cat(
63-
network,
63+
ctx,
6464
target,
6565
SourceIR.ATEN,
6666
name,
67-
tensors=args[0],
68-
dim=args_bounds_check(args, 2, 1),
67+
input=args[0],
68+
dim=args_bounds_check(args, 1, 0),
6969
)
7070

7171

py/torch_tensorrt/dynamo/conversion/impl/cat.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,23 @@
44
from torch.fx.node import Target
55
from torch_tensorrt.dynamo._SourceIR import SourceIR
66
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
7+
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
8+
from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
79
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
810

911

1012
def cat(
11-
network: TRTNetwork,
13+
ctx: ConversionContext,
1214
target: Target,
1315
source_ir: Optional[SourceIR],
1416
name: str,
15-
input: TRTNetwork,
17+
input: Union[TRTTensor, Sequence[TRTTensor]],
1618
dim: int,
1719
) -> Union[TRTTensor, Sequence[TRTTensor]]:
18-
if any(not isinstance(t, TRTTensor) for t in input): # type: ignore[union-attr]
19-
raise RuntimeError(
20-
f"cat received inputs {input} that is not part " "of the TensorRT region!"
21-
)
22-
concat_layer = network.add_concatenation(input)
20+
for each_input in input:
21+
if(not isinstance(each_input, TRTTensor)):
22+
each_input = get_trt_tensor(each_input)
23+
concat_layer = ctx.net.add_concatenation(input)
2324
if dim < 0:
2425
dim = len(input[0].shape) + dim
2526

0 commit comments

Comments
 (0)