@@ -26,7 +26,7 @@ def get_padded_shape_tensors(
26
26
source_ir : Optional [SourceIR ],
27
27
name : str ,
28
28
input : TRTTensor ,
29
- pad : Sequence [int ],
29
+ pad : Sequence [Union [ int , TRTTensor ] ],
30
30
) -> TRTTensor :
31
31
rank = len (input .shape )
32
32
if len (pad ) // 2 > rank :
@@ -47,11 +47,11 @@ def get_padded_shape_tensors(
47
47
start_list = [0 ] * rank
48
48
for i in range (len (pad ) // 2 ):
49
49
dim_index = rank - (i + 1 )
50
- pad_before = pad [i * 2 ]
51
- pad_after = pad [i * 2 + 1 ]
50
+ pad_before = get_trt_tensor ( ctx , pad [i * 2 ], f" { name } _pad_before_ { i } " )
51
+ pad_after = get_trt_tensor ( ctx , pad [i * 2 + 1 ], f" { name } _pad_after_ { i } " )
52
52
53
- pad_sum = get_trt_tensor (
54
- ctx , pad_before + pad_after , f"{ name } _pad_sum_{ i } " , dtype = np . int32
53
+ pad_sum = impl . elementwise . add (
54
+ ctx , target , source_ir , f"{ name } _pad_sum_{ i } " , pad_before , pad_after
55
55
)
56
56
dim_shape = ctx .net .add_slice (
57
57
input_shape_tensor ,
@@ -63,7 +63,9 @@ def get_padded_shape_tensors(
63
63
new_dim_shape = impl .elementwise .add (
64
64
ctx , target , source_ir , f"{ name } _shape_dim_{ i } " , dim_shape , pad_sum
65
65
)
66
- start_list [dim_index ] = - pad_before
66
+ start_list [dim_index ] = impl .elementwise .sub (
67
+ ctx , target , source_ir , f"{ name } _pad_before_neg_{ i } " , 0 , pad_before
68
+ )
67
69
68
70
slices = []
69
71
for j in range (rank ):
@@ -79,14 +81,23 @@ def get_padded_shape_tensors(
79
81
).get_output (0 )
80
82
)
81
83
padded_shape_tensor = impl .cat .cat (
82
- ctx , target , source_ir , f"{ name } _cat_dim_{ i } " , slices , 0
84
+ ctx ,
85
+ target ,
86
+ source_ir ,
87
+ f"{ name } _cat_dim_{ i } " ,
88
+ slices ,
89
+ 0 ,
90
+ cast_dtype = padded_shape_tensor .dtype ,
83
91
)
84
92
85
- start_indices_tensor = get_trt_tensor (
93
+ start_indices_tensor = impl . cat . cat (
86
94
ctx ,
87
- np .array (start_list , dtype = np .int32 ),
95
+ target ,
96
+ source_ir ,
88
97
f"{ name } _start_indices_tensor" ,
89
- dtype = np .int32 ,
98
+ start_list ,
99
+ 0 ,
100
+ cast_dtype = padded_shape_tensor .dtype ,
90
101
)
91
102
92
103
return start_indices_tensor , padded_shape_tensor
@@ -98,7 +109,7 @@ def constant_padNd(
98
109
source_ir : Optional [SourceIR ],
99
110
name : str ,
100
111
input : TRTTensor ,
101
- pad : Sequence [int ],
112
+ pad : Sequence [Union [ int , TRTTensor ] ],
102
113
value : Union [int , float ] = 0 ,
103
114
) -> TRTTensor :
104
115
0 commit comments