@@ -227,6 +227,143 @@ def layer_norm(
 #         return reshaped_output, None, None
 #     return reshaped_output
 
+# def native_group_norm(
+#     ctx: ConversionContext,
+#     target: Target,
+#     source_ir: Optional[SourceIR],
+#     name: str,
+#     input: TRTTensor,
+#     weight: Optional[Union[torch.Tensor, np.ndarray]],
+#     bias: Optional[Union[torch.Tensor, np.ndarray]],
+#     N: int,
+#     C: int,
+#     HxW: int,
+#     group: int,
+#     eps: float,
+#     return_mean_rstd: bool = True,
+# ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+#     assert (
+#         len(input.shape) >= 3
+#     ), f"The input dimension should not be less than 3, got {len(input.shape)}!"
+
+#     B = input.shape[0]
+#     # if C is provided, it must be the same as the channel from the input shape;
+#     # else if C is zero, we should get the channel from the input shape
+#     if C == 0:
+#         C = input.shape[1]
+#     assert (
+#         C == input.shape[1]
+#     ), f"The number of Channel={C} must be equal to the number of channels in the input shape={input.shape[1]}"
+#     # Groups are a subdivision of the channel dimension.
+#     assert (
+#         C % group == 0
+#     ), f"The num of channels ({C}) should be divisible by num_groups ({group})!"
+#     input = get_trt_tensor(ctx, input, f"{name}_input")
+
+#     shape = list(input.shape)
+
+#     for i, s in enumerate(shape):
+#         if i == 0 and s > 0:
+#             shape[i] = B * group
+#         elif i == 1:
+#             shape[i] = C // group
+#         elif i > 1 and s == -1:
+#             shape[i] = 0
+
+#     # Normalize every group.
+#     reshaped_input = impl.shuffle.reshape(
+#         ctx,
+#         target,
+#         source_ir,
+#         f"{name}_reshape_input",
+#         input,
+#         shape,
+#     )
+
+#     weight = get_trt_tensor(ctx, weight, f"{name}_weight")
+#     bias = get_trt_tensor(ctx, bias, f"{name}_bias")
+#     weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2)
+
+#     dims = list(range(1, len(input.shape)))
+#     axes = get_axes_for_reduce_op(dims)
+#     dummy_weight = get_trt_tensor(ctx, np.array([1.0]), f"{name}_dummy_weight")
+#     dummy_weight = impl.slice.expand(
+#         ctx, target, source_ir, f"{name}_expand_dummy_weight", dummy_weight, reshaped_input.shape
+#     )
+#     dummy_bias = get_trt_tensor(ctx, np.array([0.0]), f"{name}_dummy_bias")
+#     dummy_bias = impl.slice.expand(
+#         ctx, target, source_ir, f"{name}_expand_dummy_bias", dummy_bias, reshaped_input.shape
+#     )
+#     group_norm = ctx.net.add_normalization(reshaped_input, dummy_weight, dummy_bias, axes)
+#     group_norm.epsilon = eps
+#     group_norm.compute_precision = input.dtype
+#     set_layer_name(group_norm, target, f"{name}_group_norm", source_ir)
+#     output = group_norm.get_output(0)
+
+#     shape = list(output.shape)
+#     for i, s in enumerate(shape):
+#         if i == 0 and s > 0:
+#             shape[i] = B
+#         elif i == 1:
+#             shape[i] = C
+#         elif i > 1 and s == -1:
+#             shape[i] = 0
+
+#     reshaped_output = impl.shuffle.reshape(
+#         ctx, target, source_ir, f"{name}_reshape_output", output, shape
+#     )
+
+
+#     # weight = impl.unsqueeze.unsqueeze(ctx, target, source_ir, f"{name}_weight_unsqueeze1", weight, (0))
+#     # weight = impl.unsqueeze.unsqueeze(ctx, target, source_ir, f"{name}_weight_unsqueeze2", weight, (2))
+#     # weight = impl.slice.expand(
+#     #     ctx, target, source_ir, f"{name}_expand_weight", weight, shape
+#     # )
+#     # bias = impl.unsqueeze.unsqueeze(ctx, target, source_ir, f"{name}_bias_unsqueeze1", bias, (0))
+#     # bias = impl.unsqueeze.unsqueeze(ctx, target, source_ir, f"{name}_bias_unsqueeze2", bias, (2))
+#     # bias = impl.slice.expand(
+#     #     ctx, target, source_ir, f"{name}_expand_bias", bias, shape
+#     # )
+
+#     reshaped_gamma = impl.shuffle.reshape(
+#         ctx,
+#         target,
+#         source_ir,
+#         f"{name}_reshape_gamma",
+#         weight,
+#         weight_bias_shape,
+#     )
+
+#     reshaped_output = impl.elementwise.mul(
+#         ctx,
+#         target,
+#         source_ir,
+#         f"{name}_mul_gamma",
+#         reshaped_output,
+#         reshaped_gamma,
+#     )
+
+#     reshaped_bias = impl.shuffle.reshape(
+#         ctx,
+#         target,
+#         source_ir,
+#         f"{name}_reshape_beta",
+#         bias,
+#         weight_bias_shape,
+#     )
+#     reshaped_output = impl.elementwise.add(
+#         ctx,
+#         target,
+#         source_ir,
+#         f"{name}_add_beta",
+#         reshaped_output,
+#         reshaped_bias,
+#     )
+#     if return_mean_rstd:
+#         # return fake mean and rstd for now
+#         return reshaped_output, None, None
+#     return reshaped_output
+
 
 def native_group_norm(
     ctx: ConversionContext,
@@ -286,36 +423,94 @@ def native_group_norm(
     weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2)
 
     dims = list(range(1, len(input.shape)))
-    axes = get_axes_for_reduce_op(dims)
-    # Use dummy weight since the normalization layer cannot well handle the scale and shift of group norm due to shape mismatch
-    # TODO: check with TRT the correct way to use 'num_groups' to implement group norm
-    dummy_weight = get_trt_tensor(
-        ctx, np.array([1.0]), f"{name}_dummy_weight", input.dtype
+
+    # E[X]
+    mean_trt = impl.reduce.mean(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_mean",
+        reshaped_input,
+        dims,
+        True,
     )
-    dummy_weight = impl.slice.expand(
+
+    mean_trt = impl.slice.expand(
         ctx,
         target,
         source_ir,
-        f"{name}_expand_dummy_weight",
-        dummy_weight,
+        f"{name}_expand_mean_trt",
+        mean_trt,
         reshaped_input.shape,
     )
-    dummy_bias = get_trt_tensor(ctx, np.array([0.0]), f"{name}_dummy_bias", input.dtype)
-    dummy_bias = impl.slice.expand(
+
+    # X - E[X]
+    sub_trt = impl.elementwise.sub(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_sub",
+        reshaped_input,
+        mean_trt,
+    )
+
+    # variance
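+    # Var[X] = E[(X - E[X])^2], the biased estimator used by group norm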
+    pow_trt = get_trt_tensor(ctx, 2, f"{name}_power", np.float32)
+    pow_var = impl.elementwise.pow(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_pow",
+        sub_trt,
+        pow_trt,
+    )
+
+    var_trt = impl.reduce.mean(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_mean_var",
+        pow_var,
+        dims,
+        True,
+    )
+
+    var_trt = impl.slice.expand(
         ctx,
         target,
         source_ir,
-        f"{name}_expand_dummy_bias",
-        dummy_bias,
+        f"{name}_expand_var_trt",
+        var_trt,
         reshaped_input.shape,
     )
-    group_norm = ctx.net.add_normalization(
-        reshaped_input, dummy_weight, dummy_bias, axes
+
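+    # eps keeps the denominator of the normalization away from zero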
+    eps_trt = get_trt_tensor(ctx, eps, f"{name}_eps", np.float32)
+    add_trt = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_add",
+        var_trt,
+        eps_trt,
+    )
+
+    sqrt_trt = impl.unary.sqrt(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_sqrt",
+        add_trt,
+    )
+
+    # y = (X - E[X]) / sqrt(var + eps)
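+    # equivalently, scale by rstd = 1 / sqrt(Var[X] + eps)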
+    output = impl.elementwise.div(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_div",
+        sub_trt,
+        sqrt_trt,
     )
-    group_norm.epsilon = eps
-    group_norm.compute_precision = input.dtype
-    set_layer_name(group_norm, target, f"{name}_group_norm", source_ir)
-    output = group_norm.get_output(0)
 
     shape = list(output.shape)
     for i, s in enumerate(shape):
@@ -329,12 +524,11 @@ def native_group_norm(
     reshaped_output = impl.shuffle.reshape(
         ctx, target, source_ir, f"{name}_reshape_output", output, shape
     )
-
-    weight = impl.shuffle.reshape(
+    reshaped_gamma = impl.shuffle.reshape(
         ctx,
         target,
         source_ir,
-        f"{name}_weight",
+        f"{name}_reshape_gamma",
         weight,
         weight_bias_shape,
     )
@@ -343,26 +537,26 @@ def native_group_norm(
         ctx,
         target,
         source_ir,
-        f"{name}_mul_weight",
+        f"{name}_mul_gamma",
         reshaped_output,
-        weight,
+        reshaped_gamma,
     )
 
-    bias = impl.shuffle.reshape(
+    reshaped_bias = impl.shuffle.reshape(
         ctx,
         target,
         source_ir,
-        f"{name}_reshape_bias",
+        f"{name}_reshape_beta",
         bias,
         weight_bias_shape,
     )
     reshaped_output = impl.elementwise.add(
         ctx,
         target,
         source_ir,
-        f"{name}_add_bias",
+        f"{name}_add_beta",
         reshaped_output,
-        bias,
+        reshaped_bias,
     )
     if return_mean_rstd:
         # return fake mean and rstd for now
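
For reference, the arithmetic the new converter builds layer by layer is the
standard group-norm definition: fold the groups into the batch dimension,
normalize each group by its biased mean and variance, then apply the
per-channel affine transform. A minimal NumPy sketch of that math (illustrative
only, not part of the commit; group_norm_reference and its argument layout are
assumptions):

    import numpy as np

    def group_norm_reference(x, gamma, beta, num_groups, eps=1e-5):
        """x: (N, C, *spatial); gamma, beta: (C,)."""
        n, c, spatial = x.shape[0], x.shape[1], x.shape[2:]
        # Fold groups into the batch dim, mirroring the converter's reshape
        # to (N * num_groups, C // num_groups, *spatial).
        g = x.reshape(n * num_groups, c // num_groups, *spatial)
        axes = tuple(range(1, g.ndim))
        mean = g.mean(axis=axes, keepdims=True)                 # E[X]
        var = ((g - mean) ** 2).mean(axis=axes, keepdims=True)  # biased Var[X]
        y = (g - mean) / np.sqrt(var + eps)                     # per-group normalize
        y = y.reshape(n, c, *spatial)
        # Per-channel affine, broadcast through weight_bias_shape = (1, C, 1, ..., 1).
        wb_shape = (1, c) + (1,) * len(spatial)
        return y * gamma.reshape(wb_shape) + beta.reshape(wb_shape)

Comparing this against torch.nn.functional.group_norm on random inputs should
agree to within floating-point tolerance, which is a quick way to sanity-check
the layer graph the converter emits.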