@@ -149,6 +149,8 @@ def native_group_norm(
149
149
eps : float ,
150
150
return_mean_rstd : bool = True ,
151
151
) -> Union [TRTTensor , Sequence [TRTTensor ]]:
152
+ # TODO: Ask TRT team about the usage of INormalization Layer usage with num_groups and update the implementation
153
+ # with INormalization Layer
152
154
assert (
153
155
len (input .shape ) >= 3
154
156
), f"The input dimension should not be less than 3, got { len (input .shape )} !"
@@ -187,28 +189,105 @@ def native_group_norm(
187
189
shape ,
188
190
)
189
191
192
+ if weight is None :
193
+ weight = to_numpy (1.0 )
194
+
195
+ if bias is None :
196
+ bias = to_numpy (0.0 )
197
+
190
198
weight = get_trt_tensor (ctx , weight , f"{ name } _weight" )
191
199
bias = get_trt_tensor (ctx , bias , f"{ name } _bias" )
192
- if tuple (reshaped_input .shape ) != tuple (weight .shape ):
193
- weight = impl .slice .expand (
194
- ctx ,
195
- target ,
196
- source_ir ,
197
- f"{ name } _expand_weight" ,
198
- weight ,
199
- reshaped_input .shape ,
200
- )
201
- if tuple (reshaped_input .shape ) != tuple (bias .shape ):
202
- bias = impl .slice .expand (
203
- ctx , target , source_ir , f"{ name } _expand_bias" , bias , reshaped_input .shape
204
- )
200
+ weight_bias_shape = (1 , C ) + (1 ,) * (len (input .shape ) - 2 )
201
+
205
202
dims = list (range (1 , len (input .shape )))
206
- axes = get_axes_for_reduce_op (dims )
207
- group_norm = ctx .net .add_normalization (reshaped_input , weight , bias , axes )
208
- group_norm .epsilon = eps
209
- group_norm .compute_precision = input .dtype
210
- set_layer_name (group_norm , target , f"{ name } _group_norm" , source_ir )
211
- output = group_norm .get_output (0 )
203
+
204
+ # E[X]
205
+ mean_trt = impl .reduce .mean (
206
+ ctx ,
207
+ target ,
208
+ source_ir ,
209
+ f"{ name } _mean" ,
210
+ reshaped_input ,
211
+ dims ,
212
+ True ,
213
+ )
214
+
215
+ mean_trt = impl .slice .expand (
216
+ ctx ,
217
+ target ,
218
+ source_ir ,
219
+ f"{ name } _expand_mean_trt" ,
220
+ mean_trt ,
221
+ reshaped_input .shape ,
222
+ )
223
+
224
+ # X - E[X]
225
+ sub_trt = impl .elementwise .sub (
226
+ ctx ,
227
+ target ,
228
+ source_ir ,
229
+ f"{ name } _sub" ,
230
+ reshaped_input ,
231
+ mean_trt ,
232
+ )
233
+
234
+ # variance
235
+ pow_trt = get_trt_tensor (ctx , 2 , f"{ name } _power" , np .float32 )
236
+ pow_var = impl .elementwise .pow (
237
+ ctx ,
238
+ target ,
239
+ source_ir ,
240
+ f"{ name } _pow" ,
241
+ sub_trt ,
242
+ pow_trt ,
243
+ )
244
+
245
+ var_trt = impl .reduce .mean (
246
+ ctx ,
247
+ target ,
248
+ source_ir ,
249
+ f"{ name } _mean_var" ,
250
+ pow_var ,
251
+ dims ,
252
+ True ,
253
+ )
254
+
255
+ var_trt = impl .slice .expand (
256
+ ctx ,
257
+ target ,
258
+ source_ir ,
259
+ f"{ name } _expand_var_trt" ,
260
+ var_trt ,
261
+ reshaped_input .shape ,
262
+ )
263
+
264
+ eps_trt = get_trt_tensor (ctx , eps , f"{ name } _eps" , np .float32 )
265
+ add_trt = impl .elementwise .add (
266
+ ctx ,
267
+ target ,
268
+ source_ir ,
269
+ f"{ name } _add" ,
270
+ var_trt ,
271
+ eps_trt ,
272
+ )
273
+
274
+ sqrt_trt = impl .unary .sqrt (
275
+ ctx ,
276
+ target ,
277
+ source_ir ,
278
+ f"{ name } _sqrt" ,
279
+ add_trt ,
280
+ )
281
+
282
+ # y = (X - E[X]) / sqrt((var + eps))
283
+ output = impl .elementwise .div (
284
+ ctx ,
285
+ target ,
286
+ source_ir ,
287
+ f"{ name } _div" ,
288
+ sub_trt ,
289
+ sqrt_trt ,
290
+ )
212
291
213
292
shape = list (output .shape )
214
293
for i , s in enumerate (shape ):
@@ -222,6 +301,40 @@ def native_group_norm(
222
301
reshaped_output = impl .shuffle .reshape (
223
302
ctx , target , source_ir , f"{ name } _reshape_output" , output , shape
224
303
)
304
+ reshaped_gamma = impl .shuffle .reshape (
305
+ ctx ,
306
+ target ,
307
+ source_ir ,
308
+ f"{ name } _reshape_gamma" ,
309
+ weight ,
310
+ weight_bias_shape ,
311
+ )
312
+
313
+ reshaped_output = impl .elementwise .mul (
314
+ ctx ,
315
+ target ,
316
+ source_ir ,
317
+ f"{ name } _mul_gamma" ,
318
+ reshaped_output ,
319
+ reshaped_gamma ,
320
+ )
321
+
322
+ reshaped_bias = impl .shuffle .reshape (
323
+ ctx ,
324
+ target ,
325
+ source_ir ,
326
+ f"{ name } _reshape_beta" ,
327
+ bias ,
328
+ weight_bias_shape ,
329
+ )
330
+ reshaped_output = impl .elementwise .add (
331
+ ctx ,
332
+ target ,
333
+ source_ir ,
334
+ f"{ name } _add_beta" ,
335
+ reshaped_output ,
336
+ reshaped_bias ,
337
+ )
225
338
if return_mean_rstd :
226
339
# return fake mean and rstd for now
227
340
return reshaped_output , None , None
0 commit comments