
Commit 83e4f3c

Addressing review comments

1 parent 759b260 commit 83e4f3c

File tree

1 file changed: 14 additions & 9 deletions

  • py/torch_tensorrt/dynamo/conversion/impl

@@ -1,11 +1,16 @@
 from typing import Dict, Optional, Sequence, Union
 
+import numpy as np
 import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
-from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    SourceIR,
+    get_positive_dim,
+    get_trt_tensor,
+)
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 
 
@@ -14,16 +19,16 @@ def cat(
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
-    input: Union[TRTTensor, Sequence[TRTTensor]],
+    input: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]],
     dim: int,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    trt_inputs = []
     for each_input in input:
-        if(not isinstance(each_input, TRTTensor)):
-            each_input = get_trt_tensor(each_input)
-    concat_layer = ctx.net.add_concatenation(input)
-    if dim < 0:
-        dim = len(input[0].shape) + dim
-
+        if not isinstance(each_input, TRTTensor):
+            each_input = get_trt_tensor(ctx, each_input, name + "_tensor_{i}")
+        trt_inputs.append(each_input)
+    concat_layer = ctx.net.add_concatenation(trt_inputs)
+    dim = get_positive_dim(dim, len(input[0].shape))
     concat_layer.axis = dim
     set_layer_name(concat_layer, target, name + "_gather", source_ir)
     return concat_layer.get_output(0)
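
For reference, below is a minimal, self-contained sketch of the control flow this commit introduces: non-TRT inputs (torch or numpy constants) are promoted before concatenation, and get_positive_dim replaces the manual negative-dim arithmetic. FakeTRTTensor, fake_get_trt_tensor, and cat_sketch are hypothetical stand-ins for TensorRT's ITensor, the library's get_trt_tensor, and the cat converter; np.concatenate stands in for ctx.net.add_concatenation so the example runs without TensorRT installed.

# Minimal sketch of the updated cat() control flow.
# FakeTRTTensor / fake_get_trt_tensor / cat_sketch are hypothetical stand-ins,
# not torch_tensorrt API.
from typing import Sequence, Union

import numpy as np
import torch


class FakeTRTTensor:
    """Stand-in for a TRT ITensor: just wraps an ndarray and exposes .shape."""

    def __init__(self, array: np.ndarray) -> None:
        self.array = array
        self.shape = array.shape


def fake_get_trt_tensor(value: Union[torch.Tensor, np.ndarray]) -> FakeTRTTensor:
    # Mirrors the role of get_trt_tensor: promote a torch/numpy constant
    # into a "TRT tensor" so every concat input has the same type.
    if isinstance(value, torch.Tensor):
        value = value.numpy()
    return FakeTRTTensor(np.asarray(value))


def get_positive_dim(dim: int, rank: int) -> int:
    # Simplified stand-in for converter_utils.get_positive_dim:
    # map a possibly negative dim into [0, rank).
    return dim % rank


def cat_sketch(
    inputs: Sequence[Union[FakeTRTTensor, torch.Tensor, np.ndarray]], dim: int
) -> np.ndarray:
    trt_inputs = []
    for each_input in inputs:
        if not isinstance(each_input, FakeTRTTensor):
            each_input = fake_get_trt_tensor(each_input)
        trt_inputs.append(each_input)
    dim = get_positive_dim(dim, len(inputs[0].shape))
    # Stand-in for ctx.net.add_concatenation(trt_inputs) + concat_layer.axis = dim.
    return np.concatenate([t.array for t in trt_inputs], axis=dim)


if __name__ == "__main__":
    a = FakeTRTTensor(np.ones((2, 3)))
    b = torch.zeros(2, 3)
    c = np.full((2, 3), 2.0)
    print(cat_sketch([a, b, c], dim=-1).shape)  # (2, 9)

The point of the accumulating trt_inputs list in the real diff is that the converted tensors are actually passed to add_concatenation; the previous version converted each_input inside the loop but still handed the original, unconverted input sequence to the layer.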
