@@ -134,6 +134,100 @@ def layer_norm(
    return layer_norm.get_output(0)


+# def native_group_norm(
+#     ctx: ConversionContext,
+#     target: Target,
+#     source_ir: Optional[SourceIR],
+#     name: str,
+#     input: TRTTensor,
+#     weight: Optional[Union[torch.Tensor, np.ndarray]],
+#     bias: Optional[Union[torch.Tensor, np.ndarray]],
+#     N: int,
+#     C: int,
+#     HxW: int,
+#     group: int,
+#     eps: float,
+#     return_mean_rstd: bool = True,
+# ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+#     assert (
+#         len(input.shape) >= 3
+#     ), f"The input dimension should not be less than 3, got {len(input.shape)}!"
+
+#     B = input.shape[0]
+#     # If C is provided, it must equal the channel count in the input shape;
+#     # if C is zero, take the channel count from the input shape instead.
+#     if C == 0:
+#         C = input.shape[1]
+#     assert (
+#         C == input.shape[1]
+#     ), f"The number of channels C={C} must be equal to the number of channels in the input shape={input.shape[1]}"
+#     # Groups are a subdivision of the channel dimension.
+#     assert (
+#         C % group == 0
+#     ), f"The number of channels ({C}) should be divisible by num_groups ({group})!"
+#     input = get_trt_tensor(ctx, input, f"{name}_input")
+
+#     shape = list(input.shape)
+
+#     for i, s in enumerate(shape):
+#         if i == 0 and s > 0:
+#             shape[i] = B * group
+#         elif i == 1:
+#             shape[i] = C // group
+#         elif i > 1 and s == -1:
+#             shape[i] = 0
+
+#     # Normalize every group.
+#     reshaped_input = impl.shuffle.reshape(
+#         ctx,
+#         target,
+#         source_ir,
+#         f"{name}_reshape_input",
+#         input,
+#         shape,
+#     )
+
+#     weight = get_trt_tensor(ctx, weight, f"{name}_weight")
+#     bias = get_trt_tensor(ctx, bias, f"{name}_bias")
+#     if tuple(reshaped_input.shape) != tuple(weight.shape):
+#         weight = impl.slice.expand(
+#             ctx,
+#             target,
+#             source_ir,
+#             f"{name}_expand_weight",
+#             weight,
+#             reshaped_input.shape,
+#         )
+#     if tuple(reshaped_input.shape) != tuple(bias.shape):
+#         bias = impl.slice.expand(
+#             ctx, target, source_ir, f"{name}_expand_bias", bias, reshaped_input.shape
+#         )
+#     dims = list(range(1, len(input.shape)))
+#     axes = get_axes_for_reduce_op(dims)
+#     group_norm = ctx.net.add_normalization(reshaped_input, weight, bias, axes)
+#     group_norm.epsilon = eps
+#     group_norm.compute_precision = input.dtype
+#     set_layer_name(group_norm, target, f"{name}_group_norm", source_ir)
+#     output = group_norm.get_output(0)
+
+#     shape = list(output.shape)
+#     for i, s in enumerate(shape):
+#         if i == 0 and s > 0:
+#             shape[i] = B
+#         elif i == 1:
+#             shape[i] = C
+#         elif i > 1 and s == -1:
+#             shape[i] = 0
+
+#     reshaped_output = impl.shuffle.reshape(
+#         ctx, target, source_ir, f"{name}_reshape_output", output, shape
+#     )
+#     if return_mean_rstd:
+#         # return fake mean and rstd for now
+#         return reshaped_output, None, None
+#     return reshaped_output
+
+
def native_group_norm(
    ctx: ConversionContext,
    target: Target,
@@ -189,22 +283,35 @@ def native_group_norm(

    weight = get_trt_tensor(ctx, weight, f"{name}_weight")
    bias = get_trt_tensor(ctx, bias, f"{name}_bias")
-    if tuple(reshaped_input.shape) != tuple(weight.shape):
-        weight = impl.slice.expand(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_expand_weight",
-            weight,
-            reshaped_input.shape,
-        )
-    if tuple(reshaped_input.shape) != tuple(bias.shape):
-        bias = impl.slice.expand(
-            ctx, target, source_ir, f"{name}_expand_bias", bias, reshaped_input.shape
-        )
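+    # Broadcast shape (1, C, 1, ..., 1) for the per-channel affine parameters;
+    # they are applied after normalization via elementwise mul/add (last hunk).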
+    weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2)
+
    dims = list(range(1, len(input.shape)))
    axes = get_axes_for_reduce_op(dims)
-    group_norm = ctx.net.add_normalization(reshaped_input, weight, bias, axes)
+    # Use a dummy weight, since the normalization layer cannot properly handle
+    # the scale and shift of group norm due to a shape mismatch.
+    # TODO: check with TRT the correct way to use 'num_groups' to implement group norm
+    dummy_weight = get_trt_tensor(
+        ctx, np.array([1.0]), f"{name}_dummy_weight", input.dtype
+    )
+    dummy_weight = impl.slice.expand(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_expand_dummy_weight",
+        dummy_weight,
+        reshaped_input.shape,
+    )
+    dummy_bias = get_trt_tensor(ctx, np.array([0.0]), f"{name}_dummy_bias", input.dtype)
+    dummy_bias = impl.slice.expand(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_expand_dummy_bias",
+        dummy_bias,
+        reshaped_input.shape,
+    )
+    group_norm = ctx.net.add_normalization(
+        reshaped_input, dummy_weight, dummy_bias, axes
+    )
    group_norm.epsilon = eps
    group_norm.compute_precision = input.dtype
    set_layer_name(group_norm, target, f"{name}_group_norm", source_ir)
@@ -222,6 +329,41 @@ def native_group_norm(
    reshaped_output = impl.shuffle.reshape(
        ctx, target, source_ir, f"{name}_reshape_output", output, shape
    )
+
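+    # Apply the real affine transform manually: reshape gamma and beta to
+    # weight_bias_shape for a per-channel broadcast, then scale and shift the
+    # normalized output.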
+    weight = impl.shuffle.reshape(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_weight",
+        weight,
+        weight_bias_shape,
+    )
+
+    reshaped_output = impl.elementwise.mul(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_mul_weight",
+        reshaped_output,
+        weight,
+    )
+
+    bias = impl.shuffle.reshape(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_reshape_bias",
+        bias,
+        weight_bias_shape,
+    )
+    reshaped_output = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_add_bias",
+        reshaped_output,
+        bias,
+    )
    if return_mean_rstd:
        # return fake mean and rstd for now
        return reshaped_output, None, None
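
A sanity check, outside the diff: a minimal PyTorch sketch of the identity the rewritten converter relies on. Group norm with affine parameters equals plain per-group normalization followed by a per-channel scale and shift broadcast as (1, C, 1, ..., 1); this mirrors the dummy-weight normalization plus the mul/add above. The shapes and variable names here are illustrative assumptions, not taken from the PR.

    import torch
    import torch.nn.functional as F

    B, C, H, W, groups, eps = 2, 6, 4, 4, 3, 1e-5
    x = torch.randn(B, C, H, W)
    weight, bias = torch.randn(C), torch.randn(C)

    # Reference: fused group norm with affine parameters.
    expected = F.group_norm(x, groups, weight, bias, eps)

    # Converter's approach: normalize with identity scale/shift first ...
    normalized = F.group_norm(x, groups, eps=eps)

    # ... then apply gamma/beta reshaped to (1, C, 1, 1), mirroring
    # weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2).
    wb_shape = (1, C) + (1,) * (x.dim() - 2)
    result = normalized * weight.reshape(wb_shape) + bias.reshape(wb_shape)

    assert torch.allclose(result, expected, atol=1e-5)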