- import copy
from typing import Optional, Sequence, Union

- import torch
- import torch_tensorrt.dynamo.conversion.impl as impl
+ import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
- from torch_tensorrt.fx.converters.converter_utils import has_dynamic_shape
+ from torch_tensorrt.fx.converters.converter_utils import (
+     get_trt_tensor,
+     has_dynamic_shape,
+     set_layer_name,
+ )
from torch_tensorrt.fx.types import TRTTensor
+ """
+ Note: IPaddingLayer is deprecated in TensorRT 8.2 and will be removed in TensorRT 10.0.
+ Use ISliceLayer to pad the tensor instead; it supports non-constant fill values, the
+ reflect and clamp padding modes, and padding outputs with dynamic shape.
+ """
+


def constant_padNd(
    ctx: ConversionContext,
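All four converters in this diff compute the same slice parameters and differ only in the `trt.SliceMode` they set. As a reading aid, here is a minimal standalone sketch of that shared arithmetic (the `slice_params` helper is illustrative only, not part of the change); `pad` follows `torch.nn.functional.pad`'s ordering, innermost dimension first:

```python
# Illustrative sketch only (not part of this diff): the start/shape arithmetic
# shared by all four converters below. `pad` follows torch.nn.functional.pad's
# ordering: (last_pre, last_post, second_to_last_pre, second_to_last_post, ...).
def slice_params(input_shape, pad):
    start = [0] * len(input_shape)
    shape = list(input_shape)
    for i in range(len(pad) // 2):
        start[-i - 1] = -pad[2 * i]  # negative start reads "before" index 0
        shape[-i - 1] += pad[2 * i] + pad[2 * i + 1]  # widen by both pad amounts
    return start, shape


# Padding a (1, 2, 3) tensor by (1, 2) on the last dimension:
assert slice_params([1, 2, 3], [1, 2]) == ([0, 0, -1], [1, 2, 6])
```

A negative `start` makes the slice read before the first element, and the enlarged `shape` makes it read past the last; what those out-of-bounds reads return is decided by the layer mode: FILL substitutes a constant, REFLECT mirrors about the edges, CLAMP repeats the edge value, and WRAP indexes modulo the dimension size.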
@@ -19,43 +27,36 @@ def constant_padNd(
    pad: Sequence[int],
    value: Union[int, float] = 0,
) -> TRTTensor:
-     """
-     Note: IPaddingLayer is deprecated in TensorRT 8.2 and will be removed in TensorRT 10.0.
-     Use ISliceLayer to pad the tensor, which supports new non-constant, reflects padding
-     mode and clamp, and supports padding output with dynamic shape.
-     """
    if has_dynamic_shape(input.shape):
        assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."

-     # Implement constant padding via concat
-     curr_dim = len(input.shape) - 1
-
-     for i in range(0, len(pad), 2):
-         input_shape = list(input.shape)
-
-         pre_pad = pad[i]
-         post_pad = pad[i + 1]
-         pre_pad_shape = copy.deepcopy(input_shape)
-         pre_pad_shape[curr_dim] = pre_pad
-         pre_pad_tensor = torch.full(pre_pad_shape, float(value))
-         if pre_pad == post_pad:
-             post_pad_tensor = pre_pad_tensor
-         else:
-             post_pad_shape = copy.deepcopy(input_shape)
-             post_pad_shape[curr_dim] = post_pad
-             post_pad_tensor = torch.full(post_pad_shape, float(value))
-         output = impl.cat.cat(
-             ctx,
-             target,
-             source_ir,
-             f"{name}_concat{curr_dim}",
-             input=(pre_pad_tensor, input, post_pad_tensor),
-             dim=curr_dim,
+     rank = len(input.shape)
+
+     if len(pad) // 2 > rank:
+         raise RuntimeError(
+             f"Trying to pad the last {len(pad) // 2} dimensions but the input only has {rank} dimensions."
        )
-         curr_dim -= 1
-         input = output

-     return output
+     start_list = [0] * len(input.shape)
+     new_shape = input.shape
+
+     for i in range(0, len(pad) // 2):
+         start_list[-i - 1] = -pad[i * 2]
+         new_shape[-i - 1] += pad[i * 2] + pad[i * 2 + 1]
+
+     stride_list = [1] * len(new_shape)
+     layer = ctx.net.add_slice(
+         input,
+         start=tuple(start_list),
+         shape=tuple(new_shape),
+         stride=tuple(stride_list),
+     )
+     value_const = get_trt_tensor(ctx.net, value, f"{name}_value", input.dtype)
+     layer.set_input(4, value_const)
+     layer.mode = trt.SliceMode.FILL
+
+     set_layer_name(layer, target, name, source_ir)
+     return layer.get_output(0)


def reflection_padNd(
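In the constant case above, the fill value is attached as the slice layer's fifth input (`layer.set_input(4, value_const)`) and `trt.SliceMode.FILL` substitutes it for every out-of-bounds element. PyTorch can serve as a numeric oracle for what the converter should reproduce; a sketch with arbitrary test values:

```python
import torch
import torch.nn.functional as F

x = torch.arange(6.0).reshape(1, 2, 3)
# What constant_padNd with pad=(1, 2), value=0.5 should reproduce:
print(F.pad(x, (1, 2), mode="constant", value=0.5))
# tensor([[[0.5000, 0.0000, 1.0000, 2.0000, 0.5000, 0.5000],
#          [0.5000, 3.0000, 4.0000, 5.0000, 0.5000, 0.5000]]])
```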
@@ -69,53 +70,32 @@ def reflection_padNd(
    if has_dynamic_shape(input.shape):
        assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."

-     padding_dims = len(padding) // 2
-
-     if padding_dims == 1 or padding_dims == 2 or padding_dims == 3:
-         for i in range(padding_dims):
-             dim = -1 - i
-             pre_pad, post_pad = padding[2 * i], padding[2 * i + 1]
-             pre_pad_tensor = impl.slice.slice_op(
-                 ctx,
-                 target,
-                 source_ir,
-                 f"{name}_slice_pre{i}",
-                 input,
-                 dim=dim,
-                 start=pre_pad,
-                 stop=0,
-                 step=-1,
-             )
-
-             post_pad_tensor = impl.slice.slice_op(
-                 ctx,
-                 target,
-                 source_ir,
-                 f"{name}_slice_post{i}",
-                 input,
-                 dim=dim,
-                 start=input.shape[dim] - 2,
-                 stop=input.shape[dim] - post_pad - 2,
-                 step=-1,
-             )
-
-             output = impl.cat.cat(
-                 ctx,
-                 target,
-                 source_ir,
-                 f"{name}_concat_dim{dim}",
-                 input=(pre_pad_tensor, input, post_pad_tensor),
-                 dim=dim,
-             )
-             input = output
-
-         return output
+     rank = len(input.shape)

-     else:
+     if len(padding) // 2 > rank:
        raise RuntimeError(
-             f"We currently only support for padding 1D, 2D, and 3D, but got {padding_dims}D "
+             f"Trying to pad the last {len(padding) // 2} dimensions but the input only has {rank} dimensions."
        )

+     start_list = [0] * len(input.shape)
+     new_shape = input.shape
+
+     for i in range(0, len(padding) // 2):
+         start_list[-i - 1] = -padding[i * 2]
+         new_shape[-i - 1] += padding[i * 2] + padding[i * 2 + 1]
+
+     stride_list = [1] * len(new_shape)
+     layer = ctx.net.add_slice(
+         input,
+         start=tuple(start_list),
+         shape=tuple(new_shape),
+         stride=tuple(stride_list),
+     )
+     layer.mode = trt.SliceMode.REFLECT
+
+     set_layer_name(layer, target, name, source_ir)
+     return layer.get_output(0)
+

def replication_padNd(
    ctx: ConversionContext,
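For reflection padding, `trt.SliceMode.REFLECT` mirrors out-of-bounds reads about the tensor edges without repeating the edge element, replacing the reversed-slice-and-concat emulation deleted above. PyTorch's `reflect` mode gives the expected numerics (note that `F.pad` itself requires the pad amount to be smaller than the padded dimension):

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[[1.0, 2.0, 3.0]]])
print(F.pad(x, (2, 2), mode="reflect"))
# tensor([[[3., 2., 1., 2., 3., 2., 1.]]])
```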
@@ -128,71 +108,32 @@ def replication_padNd(
    if has_dynamic_shape(input.shape):
        assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."

-     padding_dims = len(padding) // 2
-
-     if padding_dims == 1 or padding_dims == 2 or padding_dims == 3:
-         for i in range(padding_dims):
-             dim = -1 - i
-             pre_pad, post_pad = padding[2 * i], padding[2 * i + 1]
-             pre_pad_tensor = impl.slice.slice_op(
-                 ctx,
-                 target,
-                 source_ir,
-                 f"{name}_slice_pre{i}",
-                 input,
-                 dim=dim,
-                 start=0,
-                 stop=1,
-                 step=1,
-             )
-             new_shape = input.shape
-             new_shape[dim] = pre_pad
-             pre_pad_tensor = impl.slice.expand(
-                 ctx,
-                 target,
-                 source_ir,
-                 f"{name}_expand_pre{i}",
-                 pre_pad_tensor,
-                 new_shape,
-             )
-
-             post_pad_tensor = impl.slice.slice_op(
-                 ctx,
-                 target,
-                 source_ir,
-                 f"{name}_slice_post{i}",
-                 input,
-                 dim=dim,
-                 start=input.shape[dim] - 1,
-                 stop=input.shape[dim],
-                 step=1,
-             )
-             new_shape[dim] = post_pad
-             post_pad_tensor = impl.slice.expand(
-                 ctx,
-                 target,
-                 source_ir,
-                 f"{name}_expand_post{i}",
-                 post_pad_tensor,
-                 new_shape,
-             )
-             output = impl.cat.cat(
-                 ctx,
-                 target,
-                 source_ir,
-                 f"{name}_concat_dim{dim}",
-                 input=(pre_pad_tensor, input, post_pad_tensor),
-                 dim=dim,
-             )
-             input = output
-
-         return output
+     rank = len(input.shape)

-     else:
+     if len(padding) // 2 > rank:
        raise RuntimeError(
-             f"We currently only support for padding 1D, 2D, and 3D, but got {padding_dims}D "
+             f"Trying to pad the last {len(padding) // 2} dimensions but the input only has {rank} dimensions."
        )

+     start_list = [0] * len(input.shape)
+     new_shape = input.shape
+
+     for i in range(0, len(padding) // 2):
+         start_list[-i - 1] = -padding[i * 2]
+         new_shape[-i - 1] += padding[i * 2] + padding[i * 2 + 1]
+
+     stride_list = [1] * len(new_shape)
+     layer = ctx.net.add_slice(
+         input,
+         start=tuple(start_list),
+         shape=tuple(new_shape),
+         stride=tuple(stride_list),
+     )
+     layer.mode = trt.SliceMode.CLAMP
+
+     set_layer_name(layer, target, name, source_ir)
+     return layer.get_output(0)
+

def circular_padNd(
    ctx: ConversionContext,
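For replication padding, `trt.SliceMode.CLAMP` clamps out-of-bounds indices to the nearest edge index, which is exactly edge replication and replaces the slice-then-expand construction deleted above. The PyTorch reference:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[[1.0, 2.0, 3.0]]])
print(F.pad(x, (2, 1), mode="replicate"))
# tensor([[[1., 1., 1., 2., 3., 3.]]])
```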
@@ -205,53 +146,32 @@ def circular_padNd(
    if has_dynamic_shape(input.shape):
        assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."

-     padding_dims = len(pad) // 2
-
-     if padding_dims == 1 or padding_dims == 2 or padding_dims == 3:
-         for i in range(padding_dims):
-             dim = -1 - i
-             pre_pad, post_pad = pad[2 * i], pad[2 * i + 1]
-             pre_pad_tensor = impl.slice.slice_op(
-                 ctx,
-                 target,
-                 source_ir,
-                 f"{name}_slice_pre{i}",
-                 input,
-                 dim=dim,
-                 start=input.shape[dim] - pre_pad,
-                 stop=input.shape[dim],
-                 step=1,
-             )
-
-             post_pad_tensor = impl.slice.slice_op(
-                 ctx,
-                 target,
-                 source_ir,
-                 f"{name}_slice_post{i}",
-                 input,
-                 dim=dim,
-                 start=0,
-                 stop=post_pad,
-                 step=1,
-             )
-
-             output = impl.cat.cat(
-                 ctx,
-                 target,
-                 source_ir,
-                 f"{name}_concat_dim{dim}",
-                 input=(pre_pad_tensor, input, post_pad_tensor),
-                 dim=dim,
-             )
-             input = output
-
-         return output
+     rank = len(input.shape)

-     else:
+     if len(pad) // 2 > rank:
        raise RuntimeError(
-             f"We currently only support for padding 1D, 2D, and 3D, but got {padding_dims}D "
+             f"Trying to pad the last {len(pad) // 2} dimensions but the input only has {rank} dimensions."
        )

+     start_list = [0] * len(input.shape)
+     new_shape = input.shape
+
+     for i in range(0, len(pad) // 2):
+         start_list[-i - 1] = -pad[i * 2]
+         new_shape[-i - 1] += pad[i * 2] + pad[i * 2 + 1]
+
+     stride_list = [1] * len(new_shape)
+     layer = ctx.net.add_slice(
+         input,
+         start=tuple(start_list),
+         shape=tuple(new_shape),
+         stride=tuple(stride_list),
+     )
+     layer.mode = trt.SliceMode.WRAP
+
+     set_layer_name(layer, target, name, source_ir)
+     return layer.get_output(0)
+

def pad(
    ctx: ConversionContext,
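Finally, for `circular_padNd` above, `trt.SliceMode.WRAP` wraps out-of-bounds indices modulo the dimension size, i.e. circular padding. The PyTorch reference (`circular` mode requires the pad amount to be at most the dimension size):

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[[1.0, 2.0, 3.0]]])
print(F.pad(x, (2, 2), mode="circular"))
# tensor([[[2., 3., 1., 2., 3., 1., 2.]]])
```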