1
1
from typing import Optional , Sequence , Union
2
2
3
3
import numpy as np
4
+ import tensorrt as trt
4
5
import torch_tensorrt .dynamo .conversion .impl as impl
5
6
from torch .fx .node import Target
6
7
from torch_tensorrt import _enums
12
13
get_trt_tensor ,
13
14
set_layer_name ,
14
15
)
16
+ from torch_tensorrt .dynamo .conversion .impl .shape import get_shape_with_dynamic_shape
15
17
from torch_tensorrt .fx .types import TRTTensor
16
18
17
- import tensorrt as trt
18
-
19
19
20
20
def reshape (
21
21
ctx : ConversionContext ,
@@ -61,35 +61,106 @@ def pixel_shuffle(
61
61
input : TRTTensor ,
62
62
upscale_factor : int ,
63
63
) -> TRTTensor :
64
- shape = input .shape
65
- in_channels , in_height , in_width = shape [- 3 :]
66
- out_channels = in_channels // (upscale_factor ** 2 )
67
- out_height = in_height * upscale_factor
68
- out_width = in_width * upscale_factor
69
- new_shape = shape [:- 3 ] + (
70
- out_channels ,
64
+ # Get input shape tensor
65
+ input_shape_tensor = get_shape_with_dynamic_shape (
66
+ ctx ,
67
+ target ,
68
+ source_ir ,
69
+ name + "_shape" ,
70
+ input .shape ,
71
+ input ,
72
+ )
73
+
74
+ # Extract in_channels, in_height, and in_width from the input shape tensor
75
+ in_channels_tensor = ctx .net .add_slice (
76
+ input_shape_tensor , start = (len (input .shape ) - 3 ,), shape = (1 ,), stride = (1 ,)
77
+ ).get_output (0 )
78
+ in_height_tensor = ctx .net .add_slice (
79
+ input_shape_tensor , start = (len (input .shape ) - 2 ,), shape = (1 ,), stride = (1 ,)
80
+ ).get_output (0 )
81
+ in_width_tensor = ctx .net .add_slice (
82
+ input_shape_tensor , start = (len (input .shape ) - 1 ,), shape = (1 ,), stride = (1 ,)
83
+ ).get_output (0 )
84
+
85
+ # Calculate out_channels, out_height, and out_width as tensors
86
+ upscale_factor_sq = upscale_factor * upscale_factor
87
+ upscale_factor_tensor = get_trt_tensor (
88
+ ctx , upscale_factor , f"{ name } _upscale_factor"
89
+ )
90
+ upscale_factor_sq_tensor = get_trt_tensor (
91
+ ctx , upscale_factor_sq , f"{ name } _upscale_factor_sq"
92
+ )
93
+
94
+ out_channels_tensor = impl .elementwise .floor_divide (
95
+ ctx ,
96
+ target ,
97
+ source_ir ,
98
+ f"{ name } _out_channels_tensor" ,
99
+ in_channels_tensor ,
100
+ upscale_factor_sq_tensor ,
101
+ )
102
+ out_height_tensor = impl .elementwise .mul (
103
+ ctx ,
104
+ target ,
105
+ source_ir ,
106
+ f"{ name } _out_height_tensor" ,
107
+ in_height_tensor ,
71
108
upscale_factor ,
109
+ )
110
+ out_width_tensor = impl .elementwise .mul (
111
+ ctx ,
112
+ target ,
113
+ source_ir ,
114
+ f"{ name } _out_width_tensor" ,
115
+ in_width_tensor ,
72
116
upscale_factor ,
73
- in_height ,
74
- in_width ,
75
117
)
118
+
119
+ # Construct new shape tensor
120
+ new_shape_tensors = [
121
+ ctx .net .add_slice (
122
+ input_shape_tensor , start = (i ,), shape = (1 ,), stride = (1 ,)
123
+ ).get_output (0 )
124
+ for i in range (len (input .shape ) - 3 )
125
+ ]
126
+ new_shape_tensors += [
127
+ out_channels_tensor ,
128
+ upscale_factor_tensor ,
129
+ upscale_factor_tensor ,
130
+ in_height_tensor ,
131
+ in_width_tensor ,
132
+ ]
133
+
134
+ # Reshape tensor
76
135
reshaped_tensor = reshape (
77
- ctx , target , source_ir , f"{ name } _reshape1 " , input , new_shape
136
+ ctx , target , source_ir , f"{ name } _reshape " , input , new_shape_tensors
78
137
)
79
- rank = len (shape )
138
+
139
+ # Permute shape
140
+ rank = len (input .shape )
80
141
permute_shape = list (range (rank ))
81
142
permute_shape .insert (- 2 , rank )
82
143
permute_shape .insert (- 1 , rank + 1 )
83
144
permuted_tensor = impl .permutation .permute (
84
145
ctx , target , source_ir , f"{ name } _permute" , reshaped_tensor , permute_shape
85
146
)
147
+
148
+ # Construct output shape tensor
149
+ out_shape_tensors = [
150
+ ctx .net .add_slice (
151
+ input_shape_tensor , start = (i ,), shape = (1 ,), stride = (1 ,)
152
+ ).get_output (0 )
153
+ for i in range (len (input .shape ) - 3 )
154
+ ]
155
+ out_shape_tensors += [out_channels_tensor , out_height_tensor , out_width_tensor ]
156
+
86
157
return reshape (
87
158
ctx ,
88
159
target ,
89
160
source_ir ,
90
- f"{ name } _reshape2 " ,
161
+ f"{ name } _reshape_out " ,
91
162
permuted_tensor ,
92
- shape [: - 3 ] + ( out_channels , out_height , out_width ) ,
163
+ out_shape_tensors ,
93
164
)
94
165
95
166
@@ -101,39 +172,109 @@ def pixel_unshuffle(
101
172
input : TRTTensor ,
102
173
downscale_factor : int ,
103
174
) -> TRTTensor :
104
- shape = input .shape
105
- in_channels , in_height , in_width = shape [- 3 :]
106
- out_channels = in_channels * (downscale_factor ** 2 )
107
- out_height = in_height // downscale_factor
108
- out_width = in_width // downscale_factor
109
- new_shape = shape [:- 3 ] + (
110
- in_channels ,
111
- out_height ,
112
- downscale_factor ,
113
- out_width ,
114
- downscale_factor ,
175
+ # Get input shape tensor
176
+ input_shape_tensor = get_shape_with_dynamic_shape (
177
+ ctx ,
178
+ target ,
179
+ source_ir ,
180
+ name + "_shape" ,
181
+ input .shape ,
182
+ input ,
183
+ )
184
+
185
+ # Extract in_channels, in_height, and in_width from the input shape tensor
186
+ in_channels_tensor = ctx .net .add_slice (
187
+ input_shape_tensor , start = (len (input .shape ) - 3 ,), shape = (1 ,), stride = (1 ,)
188
+ ).get_output (0 )
189
+ in_height_tensor = ctx .net .add_slice (
190
+ input_shape_tensor , start = (len (input .shape ) - 2 ,), shape = (1 ,), stride = (1 ,)
191
+ ).get_output (0 )
192
+ in_width_tensor = ctx .net .add_slice (
193
+ input_shape_tensor , start = (len (input .shape ) - 1 ,), shape = (1 ,), stride = (1 ,)
194
+ ).get_output (0 )
195
+
196
+ # Calculate out_channels, out_height, and out_width as tensors
197
+ downscale_factor_sq = downscale_factor * downscale_factor
198
+ downscale_factor_tensor = get_trt_tensor (
199
+ ctx , downscale_factor , f"{ name } _downscale_factor"
200
+ )
201
+ downscale_factor_sq_tensor = get_trt_tensor (
202
+ ctx , downscale_factor_sq , f"{ name } _downscale_factor_sq"
115
203
)
204
+
205
+ out_channels_tensor = impl .elementwise .mul (
206
+ ctx ,
207
+ target ,
208
+ source_ir ,
209
+ f"{ name } _out_channels_tensor" ,
210
+ in_channels_tensor ,
211
+ downscale_factor_sq_tensor ,
212
+ )
213
+ out_height_tensor = impl .elementwise .floor_divide (
214
+ ctx ,
215
+ target ,
216
+ source_ir ,
217
+ f"{ name } _out_height_tensor" ,
218
+ in_height_tensor ,
219
+ downscale_factor_tensor ,
220
+ )
221
+ out_width_tensor = impl .elementwise .floor_divide (
222
+ ctx ,
223
+ target ,
224
+ source_ir ,
225
+ f"{ name } _out_width_tensor" ,
226
+ in_width_tensor ,
227
+ downscale_factor_tensor ,
228
+ )
229
+
230
+ # Construct new shape tensor
231
+ new_shape_tensors = [
232
+ ctx .net .add_slice (
233
+ input_shape_tensor , start = (i ,), shape = (1 ,), stride = (1 ,)
234
+ ).get_output (0 )
235
+ for i in range (len (input .shape ) - 3 )
236
+ ]
237
+ new_shape_tensors += [
238
+ in_channels_tensor ,
239
+ out_height_tensor ,
240
+ downscale_factor_tensor ,
241
+ out_width_tensor ,
242
+ downscale_factor_tensor ,
243
+ ]
244
+
116
245
reshaped_tensor = reshape (
117
- ctx , target , source_ir , f"{ name } _reshape1 " , input , new_shape
246
+ ctx , target , source_ir , f"{ name } _reshape " , input , new_shape_tensors
118
247
)
119
- rank = len (new_shape )
120
- permute_shape = tuple (range (rank - 5 )) + (
248
+
249
+ # Permute shape
250
+ rank = len (new_shape_tensors )
251
+ permute_shape = list (range (rank - 5 )) + [
121
252
rank - 5 , # in_channels
122
253
rank - 3 , # downscale_factor
123
254
rank - 1 , # downscale_factor
124
255
rank - 4 , # out_height
125
256
rank - 2 , # out_width
126
- )
257
+ ]
127
258
permuted_tensor = impl .permutation .permute (
128
259
ctx , target , source_ir , f"{ name } _permute" , reshaped_tensor , permute_shape
129
260
)
261
+
262
+ # Construct output shape tensor
263
+ out_shape_tensors = [
264
+ ctx .net .add_slice (
265
+ input_shape_tensor , start = (i ,), shape = (1 ,), stride = (1 ,)
266
+ ).get_output (0 )
267
+ for i in range (len (input .shape ) - 3 )
268
+ ]
269
+ out_shape_tensors += [out_channels_tensor , out_height_tensor , out_width_tensor ]
270
+
130
271
return reshape (
131
272
ctx ,
132
273
target ,
133
274
source_ir ,
134
- f"{ name } _reshape2 " ,
275
+ f"{ name } _reshape_out " ,
135
276
permuted_tensor ,
136
- shape [: - 3 ] + ( out_channels , out_height , out_width ) ,
277
+ out_shape_tensors ,
137
278
)
138
279
139
280
0 commit comments