@@ -227,6 +227,143 @@ def layer_norm(
 #         return reshaped_output, None, None
 #     return reshaped_output
 
+# def native_group_norm(
+#     ctx: ConversionContext,
+#     target: Target,
+#     source_ir: Optional[SourceIR],
+#     name: str,
+#     input: TRTTensor,
+#     weight: Optional[Union[torch.Tensor, np.ndarray]],
+#     bias: Optional[Union[torch.Tensor, np.ndarray]],
+#     N: int,
+#     C: int,
+#     HxW: int,
+#     group: int,
+#     eps: float,
+#     return_mean_rstd: bool = True,
+# ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+#     assert (
+#         len(input.shape) >= 3
+#     ), f"The input dimension should not be less than 3, got {len(input.shape)}!"
+
+#     B = input.shape[0]
+#     # if C is provided, it must be the same as the channel count in the input shape;
+#     # if C is zero, take the channel count from the input shape
+#     if C == 0:
+#         C = input.shape[1]
+#     assert (
+#         C == input.shape[1]
+#     ), f"The number of Channel={C} must be equal to the number of channels in the input shape={input.shape[1]}"
+#     # Groups are a subdivision of the channel dimension.
+#     assert (
+#         C % group == 0
+#     ), f"The num of channels ({C}) should be divisible by num_groups ({group})!"
+#     input = get_trt_tensor(ctx, input, f"{name}_input")
+
+#     shape = list(input.shape)
+
+#     for i, s in enumerate(shape):
+#         if i == 0 and s > 0:
+#             shape[i] = B * group
+#         elif i == 1:
+#             shape[i] = C // group
+#         elif i > 1 and s == -1:
+#             shape[i] = 0
+
+#     # Normalize every group.
+#     reshaped_input = impl.shuffle.reshape(
+#         ctx,
+#         target,
+#         source_ir,
+#         f"{name}_reshape_input",
+#         input,
+#         shape,
+#     )
+
+#     weight = get_trt_tensor(ctx, weight, f"{name}_weight")
+#     bias = get_trt_tensor(ctx, bias, f"{name}_bias")
+#     weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2)
+
+#     dims = list(range(1, len(input.shape)))
+#     axes = get_axes_for_reduce_op(dims)
+#     dummy_weight = get_trt_tensor(ctx, np.array([1.0]), f"{name}_dummy_weight")
+#     dummy_weight = impl.slice.expand(
+#         ctx, target, source_ir, f"{name}_expand_dummy_weight", dummy_weight, reshaped_input.shape
+#     )
+#     dummy_bias = get_trt_tensor(ctx, np.array([0.0]), f"{name}_dummy_bias")
+#     dummy_bias = impl.slice.expand(
+#         ctx, target, source_ir, f"{name}_expand_dummy_bias", dummy_bias, reshaped_input.shape
+#     )
+#     group_norm = ctx.net.add_normalization(reshaped_input, dummy_weight, dummy_bias, axes)
+#     group_norm.epsilon = eps
+#     group_norm.compute_precision = input.dtype
+#     set_layer_name(group_norm, target, f"{name}_group_norm", source_ir)
+#     output = group_norm.get_output(0)
+
+#     shape = list(output.shape)
+#     for i, s in enumerate(shape):
+#         if i == 0 and s > 0:
+#             shape[i] = B
+#         elif i == 1:
+#             shape[i] = C
+#         elif i > 1 and s == -1:
+#             shape[i] = 0
+
+#     reshaped_output = impl.shuffle.reshape(
+#         ctx, target, source_ir, f"{name}_reshape_output", output, shape
+#     )
+
+#     # weight = impl.unsqueeze.unsqueeze(ctx, target, source_ir, f"{name}_weight_unsqueeze1", weight, (0))
+#     # weight = impl.unsqueeze.unsqueeze(ctx, target, source_ir, f"{name}_weight_unsqueeze2", weight, (2))
+#     # weight = impl.slice.expand(
+#     #     ctx, target, source_ir, f"{name}_expand_weight", weight, shape
+#     # )
+#     # bias = impl.unsqueeze.unsqueeze(ctx, target, source_ir, f"{name}_bias_unsqueeze1", bias, (0))
+#     # bias = impl.unsqueeze.unsqueeze(ctx, target, source_ir, f"{name}_bias_unsqueeze2", bias, (2))
+#     # bias = impl.slice.expand(
+#     #     ctx, target, source_ir, f"{name}_expand_bias", bias, shape
+#     # )
+
+#     reshaped_gamma = impl.shuffle.reshape(
+#         ctx,
+#         target,
+#         source_ir,
+#         f"{name}_reshape_gamma",
+#         weight,
+#         weight_bias_shape,
+#     )
+
+#     reshaped_output = impl.elementwise.mul(
+#         ctx,
+#         target,
+#         source_ir,
+#         f"{name}_mul_gamma",
+#         reshaped_output,
+#         reshaped_gamma,
+#     )
+
+#     reshaped_bias = impl.shuffle.reshape(
+#         ctx,
+#         target,
+#         source_ir,
+#         f"{name}_reshape_beta",
+#         bias,
+#         weight_bias_shape,
+#     )
+#     reshaped_output = impl.elementwise.add(
+#         ctx,
+#         target,
+#         source_ir,
+#         f"{name}_add_beta",
+#         reshaped_output,
+#         reshaped_bias,
+#     )
+#     if return_mean_rstd:
+#         # return fake mean and rstd for now
+#         return reshaped_output, None, None
+#     return reshaped_output
+
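+# NOTE: the commented-out block above is the earlier INormalization-based attempt,
+# apparently kept for reference; it needed dummy scale/shift tensors because the layer
+# could not apply per-channel gamma/beta on the group-reshaped tensor.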
 
 def native_group_norm(
     ctx: ConversionContext,
@@ -243,6 +380,8 @@ def native_group_norm(
     eps: float,
     return_mean_rstd: bool = True,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    # TODO: Ask the TRT team how to use the INormalization layer with num_groups, and
+    # update this implementation to use INormalization
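+    # Until then, group norm is decomposed below into explicit mean / variance /
+    # normalize / affine steps built from reduce and elementwise ops.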
     assert (
         len(input.shape) >= 3
     ), f"The input dimension should not be less than 3, got {len(input.shape)}!"
@@ -286,36 +425,94 @@ def native_group_norm(
     weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2)
 
     dims = list(range(1, len(input.shape)))
-    axes = get_axes_for_reduce_op(dims)
-    # Use dummy weight since the normalization layer cannot well handle the scale and shift of group norm due to shape mismatch
-    # TODO: check with TRT the correct way to use 'num_groups' to implement group norm
-    dummy_weight = get_trt_tensor(
-        ctx, np.array([1.0]), f"{name}_dummy_weight", input.dtype
+
+    # E[X]
+    mean_trt = impl.reduce.mean(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_mean",
+        reshaped_input,
+        dims,
+        True,
     )
-    dummy_weight = impl.slice.expand(
+
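+    # broadcast the per-group mean back to the reshaped input shape for the subtraction below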
+    mean_trt = impl.slice.expand(
         ctx,
         target,
         source_ir,
-        f"{name}_expand_dummy_weight",
-        dummy_weight,
+        f"{name}_expand_mean_trt",
+        mean_trt,
         reshaped_input.shape,
     )
-    dummy_bias = get_trt_tensor(ctx, np.array([0.0]), f"{name}_dummy_bias", input.dtype)
-    dummy_bias = impl.slice.expand(
+
+    # X - E[X]
+    sub_trt = impl.elementwise.sub(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_sub",
+        reshaped_input,
+        mean_trt,
+    )
+
+    # variance: Var[X] = E[(X - E[X])^2], the biased estimate used by group norm
+    pow_trt = get_trt_tensor(ctx, 2, f"{name}_power", np.float32)
+    pow_var = impl.elementwise.pow(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_pow",
+        sub_trt,
+        pow_trt,
+    )
+
+    var_trt = impl.reduce.mean(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_mean_var",
+        pow_var,
+        dims,
+        True,
+    )
+
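+    # broadcast the per-group variance back to the reshaped input shape, mirroring the mean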
+    var_trt = impl.slice.expand(
         ctx,
         target,
         source_ir,
-        f"{name}_expand_dummy_bias",
-        dummy_bias,
+        f"{name}_expand_var_trt",
+        var_trt,
         reshaped_input.shape,
     )
-    group_norm = ctx.net.add_normalization(
-        reshaped_input, dummy_weight, dummy_bias, axes
+
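+    # add eps to the variance before the square root so the denominator stays nonzero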
+    eps_trt = get_trt_tensor(ctx, eps, f"{name}_eps", np.float32)
+    add_trt = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_add",
+        var_trt,
+        eps_trt,
+    )
+
+    sqrt_trt = impl.unary.sqrt(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_sqrt",
+        add_trt,
+    )
+
+    # y = (X - E[X]) / sqrt(Var[X] + eps)
+    output = impl.elementwise.div(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_div",
+        sub_trt,
+        sqrt_trt,
     )
-    group_norm.epsilon = eps
-    group_norm.compute_precision = input.dtype
-    set_layer_name(group_norm, target, f"{name}_group_norm", source_ir)
-    output = group_norm.get_output(0)
 
     shape = list(output.shape)
     for i, s in enumerate(shape):
@@ -329,12 +526,11 @@ def native_group_norm(
     reshaped_output = impl.shuffle.reshape(
         ctx, target, source_ir, f"{name}_reshape_output", output, shape
     )
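+    # gamma/beta are reshaped to (1, C, 1, ...) so they broadcast over the batch and spatial dims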
-
-    weight = impl.shuffle.reshape(
+    reshaped_gamma = impl.shuffle.reshape(
         ctx,
         target,
         source_ir,
-        f"{name}_weight",
+        f"{name}_reshape_gamma",
         weight,
         weight_bias_shape,
     )
@@ -343,26 +539,26 @@ def native_group_norm(
         ctx,
         target,
         source_ir,
-        f"{name}_mul_weight",
+        f"{name}_mul_gamma",
         reshaped_output,
-        weight,
+        reshaped_gamma,
     )
 
-    bias = impl.shuffle.reshape(
+    reshaped_bias = impl.shuffle.reshape(
         ctx,
         target,
         source_ir,
-        f"{name}_reshape_bias",
+        f"{name}_reshape_beta",
         bias,
         weight_bias_shape,
     )
     reshaped_output = impl.elementwise.add(
         ctx,
         target,
         source_ir,
-        f"{name}_add_bias",
+        f"{name}_add_beta",
         reshaped_output,
-        bias,
+        reshaped_bias,
     )
     if return_mean_rstd:
         # return fake mean and rstd for now
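
A quick sanity check for the decomposition above (a minimal sketch, not part of the
diff: it assumes PyTorch is installed and uses hypothetical local names to mirror the
converter's reshape -> normalize -> affine steps):

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 6, 4, 4)
    weight, bias = torch.randn(6), torch.randn(6)
    group, eps = 3, 1e-5

    # Reference result from eager PyTorch.
    expected = F.group_norm(x, group, weight, bias, eps)

    # Same decomposition as the converter: reshape to (B*group, C//group, H, W),
    # normalize with the biased variance, reshape back, then apply gamma/beta.
    B, C = x.shape[0], x.shape[1]
    reshaped = x.reshape(B * group, C // group, *x.shape[2:])
    dims = tuple(range(1, reshaped.dim()))
    mean = reshaped.mean(dims, keepdim=True)
    var = ((reshaped - mean) ** 2).mean(dims, keepdim=True)
    normalized = ((reshaped - mean) / torch.sqrt(var + eps)).reshape_as(x)
    wb_shape = (1, C) + (1,) * (x.dim() - 2)
    actual = normalized * weight.reshape(wb_shape) + bias.reshape(wb_shape)

    assert torch.allclose(expected, actual, atol=1e-5)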