@@ -134,100 +134,6 @@ def layer_norm(
     return layer_norm.get_output(0)
 
 
-# def native_group_norm(
-#     ctx: ConversionContext,
-#     target: Target,
-#     source_ir: Optional[SourceIR],
-#     name: str,
-#     input: TRTTensor,
-#     weight: Optional[Union[torch.Tensor, np.ndarray]],
-#     bias: Optional[Union[torch.Tensor, np.ndarray]],
-#     N: int,
-#     C: int,
-#     HxW: int,
-#     group: int,
-#     eps: float,
-#     return_mean_rstd: bool = True,
-# ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-#     assert (
-#         len(input.shape) >= 3
-#     ), f"The input dimension should not be less than 3, got {len(input.shape)}!"
-
-#     B = input.shape[0]
-#     # if C is provided, it must be as same as the channel from the input shape,
-#     # else if C is zero, we should get the channel from the input shape
-#     if C == 0:
-#         C = input.shape[1]
-#     assert (
-#         C == input.shape[1]
-#     ), f"The number of Channel={C} must be equal to the number of channels in the input shape={input.shape[1]}"
-#     # Groups are a subdivision of the channel dimension.
-#     assert (
-#         C % group == 0
-#     ), f"The num of channels ({C}) should be divisible by num_groups ({group})!"
-#     input = get_trt_tensor(ctx, input, f"{name}_input")
-
-#     shape = list(input.shape)
-
-#     for i, s in enumerate(shape):
-#         if i == 0 and s > 0:
-#             shape[i] = B * group
-#         elif i == 1:
-#             shape[i] = C // group
-#         elif i > 1 and s == -1:
-#             shape[i] = 0
-
-#     # Normalize every group.
-#     reshaped_input = impl.shuffle.reshape(
-#         ctx,
-#         target,
-#         source_ir,
-#         f"{name}_reshape_input",
-#         input,
-#         shape,
-#     )
-
-#     weight = get_trt_tensor(ctx, weight, f"{name}_weight")
-#     bias = get_trt_tensor(ctx, bias, f"{name}_bias")
-#     if tuple(reshaped_input.shape) != tuple(weight.shape):
-#         weight = impl.slice.expand(
-#             ctx,
-#             target,
-#             source_ir,
-#             f"{name}_expand_weight",
-#             weight,
-#             reshaped_input.shape,
-#         )
-#     if tuple(reshaped_input.shape) != tuple(bias.shape):
-#         bias = impl.slice.expand(
-#             ctx, target, source_ir, f"{name}_expand_bias", bias, reshaped_input.shape
-#         )
-#     dims = list(range(1, len(input.shape)))
-#     axes = get_axes_for_reduce_op(dims)
-#     group_norm = ctx.net.add_normalization(reshaped_input, weight, bias, axes)
-#     group_norm.epsilon = eps
-#     group_norm.compute_precision = input.dtype
-#     set_layer_name(group_norm, target, f"{name}_group_norm", source_ir)
-#     output = group_norm.get_output(0)
-
-#     shape = list(output.shape)
-#     for i, s in enumerate(shape):
-#         if i == 0 and s > 0:
-#             shape[i] = B
-#         elif i == 1:
-#             shape[i] = C
-#         elif i > 1 and s == -1:
-#             shape[i] = 0
-
-#     reshaped_output = impl.shuffle.reshape(
-#         ctx, target, source_ir, f"{name}_reshape_output", output, shape
-#     )
-#     if return_mean_rstd:
-#         # return fake mean and rstd for now
-#         return reshaped_output, None, None
-#     return reshaped_output
-
-
 def native_group_norm(
     ctx: ConversionContext,
     target: Target,
@@ -243,6 +149,8 @@ def native_group_norm(
     eps: float,
     return_mean_rstd: bool = True,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    # TODO: Ask the TRT team how INormalizationLayer should be used with num_groups,
+    # and update this implementation to use INormalizationLayer.
     assert (
         len(input.shape) >= 3
     ), f"The input dimension should not be less than 3, got {len(input.shape)}!"
@@ -286,36 +194,94 @@ def native_group_norm(
     weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2)
 
     dims = list(range(1, len(input.shape)))
-    axes = get_axes_for_reduce_op(dims)
-    # Use dummy weight since the normalization layer cannot well handle the scale and shift of group norm due to shape mismatch
-    # TODO: check with TRT the correct way to use 'num_groups' to implement group norm
-    dummy_weight = get_trt_tensor(
-        ctx, np.array([1.0]), f"{name}_dummy_weight", input.dtype
+
+    # E[X]
+    mean_trt = impl.reduce.mean(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_mean",
+        reshaped_input,
+        dims,
+        True,
     )
-    dummy_weight = impl.slice.expand(
+
+    mean_trt = impl.slice.expand(
         ctx,
         target,
         source_ir,
-        f"{name}_expand_dummy_weight",
-        dummy_weight,
+        f"{name}_expand_mean_trt",
+        mean_trt,
         reshaped_input.shape,
     )
-    dummy_bias = get_trt_tensor(ctx, np.array([0.0]), f"{name}_dummy_bias", input.dtype)
-    dummy_bias = impl.slice.expand(
+
+    # X - E[X]
+    sub_trt = impl.elementwise.sub(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_sub",
+        reshaped_input,
+        mean_trt,
+    )
+
+    # variance
+    pow_trt = get_trt_tensor(ctx, 2, f"{name}_power", np.float32)
+    pow_var = impl.elementwise.pow(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_pow",
+        sub_trt,
+        pow_trt,
+    )
+
+    var_trt = impl.reduce.mean(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_mean_var",
+        pow_var,
+        dims,
+        True,
+    )
+
+    var_trt = impl.slice.expand(
         ctx,
         target,
         source_ir,
-        f"{name}_expand_dummy_bias",
-        dummy_bias,
+        f"{name}_expand_var_trt",
+        var_trt,
         reshaped_input.shape,
     )
-    group_norm = ctx.net.add_normalization(
-        reshaped_input, dummy_weight, dummy_bias, axes
+
+    eps_trt = get_trt_tensor(ctx, eps, f"{name}_eps", np.float32)
+    add_trt = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_add",
+        var_trt,
+        eps_trt,
+    )
+
+    sqrt_trt = impl.unary.sqrt(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_sqrt",
+        add_trt,
+    )
+
+    # y = (X - E[X]) / sqrt(var + eps)
+    output = impl.elementwise.div(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_div",
+        sub_trt,
+        sqrt_trt,
     )
-    group_norm.epsilon = eps
-    group_norm.compute_precision = input.dtype
-    set_layer_name(group_norm, target, f"{name}_group_norm", source_ir)
-    output = group_norm.get_output(0)
 
     shape = list(output.shape)
     for i, s in enumerate(shape):
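
Note: the op sequence added above is the standard standardization formula y = (X - E[X]) / sqrt(Var[X] + eps), expressed as explicit TRT reduce and elementwise layers rather than a single INormalization layer. A minimal NumPy sketch of the same math, for reference only (the `standardize` helper is illustrative and not part of this PR):

```python
import numpy as np

def standardize(reshaped_input: np.ndarray, eps: float) -> np.ndarray:
    # Reduce over every axis except 0, mirroring
    # dims = list(range(1, len(input.shape))) in the converter.
    dims = tuple(range(1, reshaped_input.ndim))
    mean = reshaped_input.mean(axis=dims, keepdims=True)                 # E[X]
    var = ((reshaped_input - mean) ** 2).mean(axis=dims, keepdims=True)  # Var[X]
    return (reshaped_input - mean) / np.sqrt(var + eps)                  # y
```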
@@ -329,12 +295,11 @@ def native_group_norm(
     reshaped_output = impl.shuffle.reshape(
         ctx, target, source_ir, f"{name}_reshape_output", output, shape
     )
-
-    weight = impl.shuffle.reshape(
+    reshaped_gamma = impl.shuffle.reshape(
         ctx,
         target,
         source_ir,
-        f"{name}_weight",
+        f"{name}_reshape_gamma",
         weight,
         weight_bias_shape,
     )
@@ -343,26 +308,26 @@ def native_group_norm(
         ctx,
         target,
         source_ir,
-        f"{name}_mul_weight",
+        f"{name}_mul_gamma",
         reshaped_output,
-        weight,
+        reshaped_gamma,
     )
 
-    bias = impl.shuffle.reshape(
+    reshaped_bias = impl.shuffle.reshape(
         ctx,
         target,
         source_ir,
-        f"{name}_reshape_bias",
+        f"{name}_reshape_beta",
         bias,
         weight_bias_shape,
     )
     reshaped_output = impl.elementwise.add(
         ctx,
         target,
         source_ir,
-        f"{name}_add_bias",
+        f"{name}_add_beta",
         reshaped_output,
-        bias,
+        reshaped_bias,
     )
     if return_mean_rstd:
         # return fake mean and rstd for now
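
Taken together, the new lowering mirrors the eager definition of group norm: reshape to (B * group, C // group, ...), standardize each row, reshape back, then apply the per-channel gamma/beta. A quick sanity check of that decomposition against PyTorch's reference implementation (a sketch with made-up shapes, not a test from this PR):

```python
import torch
import torch.nn.functional as F

B, C, H, W, groups, eps = 2, 6, 4, 4, 3, 1e-5
x = torch.randn(B, C, H, W)
gamma, beta = torch.randn(C), torch.randn(C)

# Normalize every group: fold each (batch, group) pair into the leading dim.
rx = x.reshape(B * groups, C // groups, H, W)
dims = tuple(range(1, rx.dim()))
mean = rx.mean(dim=dims, keepdim=True)
var = ((rx - mean) ** 2).mean(dim=dims, keepdim=True)
y = ((rx - mean) / torch.sqrt(var + eps)).reshape(B, C, H, W)

# Affine step, matching weight_bias_shape = (1, C, 1, 1).
y = y * gamma.view(1, C, 1, 1) + beta.view(1, C, 1, 1)

print(torch.allclose(y, F.group_norm(x, groups, gamma, beta, eps), atol=1e-5))
```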