1
1
from typing import Dict , Optional , Sequence , Union
2
2
3
+ import numpy as np
3
4
import torch
4
5
from torch .fx .node import Target
5
6
from torch_tensorrt .dynamo ._SourceIR import SourceIR
6
- from torch_tensorrt .fx .converters .converter_utils import set_layer_name
7
7
from torch_tensorrt .dynamo .conversion ._ConversionContext import ConversionContext
8
- from torch_tensorrt .dynamo .conversion .converter_utils import get_trt_tensor
8
+ from torch_tensorrt .dynamo .conversion .converter_utils import (
9
+ SourceIR ,
10
+ get_positive_dim ,
11
+ get_trt_tensor ,
12
+ )
13
+ from torch_tensorrt .fx .converters .converter_utils import set_layer_name
9
14
from torch_tensorrt .fx .types import TRTNetwork , TRTTensor
10
15
11
16
@@ -14,15 +19,15 @@ def cat(
14
19
target : Target ,
15
20
source_ir : Optional [SourceIR ],
16
21
name : str ,
17
- input : Union [TRTTensor , Sequence [ TRTTensor ]],
22
+ input : Sequence [ Union [TRTTensor , torch . Tensor , np . ndarray ]],
18
23
dim : int ,
19
24
) -> Union [TRTTensor , Sequence [TRTTensor ]]:
20
25
for each_input in input :
21
- if ( not isinstance (each_input , TRTTensor ) ):
26
+ if not isinstance (each_input , TRTTensor ):
22
27
each_input = get_trt_tensor (each_input )
23
28
concat_layer = ctx .net .add_concatenation (input )
24
29
if dim < 0 :
25
- dim = len (input [0 ].shape ) + dim
30
+ dim = get_positive_dim ( dim , len (input [0 ].shape ))
26
31
27
32
concat_layer .axis = dim
28
33
set_layer_name (concat_layer , target , name + "_gather" , source_ir )
0 commit comments