File tree Expand file tree Collapse file tree 2 files changed +12
-11
lines changed
py/torch_tensorrt/dynamo/conversion Expand file tree Collapse file tree 2 files changed +12
-11
lines changed Original file line number Diff line number Diff line change @@ -53,19 +53,19 @@ def aten_ops_batch_norm(
53
53
54
54
@dynamo_tensorrt_converter (torch .ops .aten .cat .default )
55
55
def aten_ops_cat (
56
- network : TRTNetwork ,
56
+ ctx : ConversionContext ,
57
57
target : Target ,
58
58
args : Tuple [Argument , ...],
59
59
kwargs : Dict [str , Argument ],
60
60
name : str ,
61
61
) -> Union [TRTTensor , Sequence [TRTTensor ]]:
62
62
return impl .cat .cat (
63
- network ,
63
+ ctx ,
64
64
target ,
65
65
SourceIR .ATEN ,
66
66
name ,
67
- tensors = args [0 ],
68
- dim = args_bounds_check (args , 2 , 1 ),
67
+ input = args [0 ],
68
+ dim = args_bounds_check (args , 1 , 0 ),
69
69
)
70
70
71
71
Original file line number Diff line number Diff line change 4
4
from torch .fx .node import Target
5
5
from torch_tensorrt .dynamo ._SourceIR import SourceIR
6
6
from torch_tensorrt .fx .converters .converter_utils import set_layer_name
7
+ from torch_tensorrt .dynamo .conversion ._ConversionContext import ConversionContext
8
+ from torch_tensorrt .dynamo .conversion .converter_utils import get_trt_tensor
7
9
from torch_tensorrt .fx .types import TRTNetwork , TRTTensor
8
10
9
11
10
12
def cat (
11
- network : TRTNetwork ,
13
+ ctx : ConversionContext ,
12
14
target : Target ,
13
15
source_ir : Optional [SourceIR ],
14
16
name : str ,
15
- input : TRTNetwork ,
17
+ input : Union [ TRTTensor , Sequence [ TRTTensor ]] ,
16
18
dim : int ,
17
19
) -> Union [TRTTensor , Sequence [TRTTensor ]]:
18
- if any (not isinstance (t , TRTTensor ) for t in input ): # type: ignore[union-attr]
19
- raise RuntimeError (
20
- f"cat received inputs { input } that is not part " "of the TensorRT region!"
21
- )
22
- concat_layer = network .add_concatenation (input )
20
+ for each_input in input :
21
+ if (not isinstance (each_input , TRTTensor )):
22
+ each_input = get_trt_tensor (each_input )
23
+ concat_layer = ctx .net .add_concatenation (input )
23
24
if dim < 0 :
24
25
dim = len (input [0 ].shape ) + dim
25
26
You can’t perform that action at this time.
0 commit comments