Skip to content

Commit 6ac2ec8

Browse files
authored
Group norm bug fix (#3014)
1 parent abf3370 commit 6ac2ec8

File tree

2 files changed

+157
-23
lines changed

2 files changed

+157
-23
lines changed

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 132 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ def native_group_norm(
149149
eps: float,
150150
return_mean_rstd: bool = True,
151151
) -> Union[TRTTensor, Sequence[TRTTensor]]:
152+
# TODO: Ask TRT team about the usage of INormalization Layer usage with num_groups and update the implementation
153+
# with INormalization Layer
152154
assert (
153155
len(input.shape) >= 3
154156
), f"The input dimension should not be less than 3, got {len(input.shape)}!"
@@ -187,28 +189,105 @@ def native_group_norm(
187189
shape,
188190
)
189191

192+
if weight is None:
193+
weight = to_numpy(1.0)
194+
195+
if bias is None:
196+
bias = to_numpy(0.0)
197+
190198
weight = get_trt_tensor(ctx, weight, f"{name}_weight")
191199
bias = get_trt_tensor(ctx, bias, f"{name}_bias")
192-
if tuple(reshaped_input.shape) != tuple(weight.shape):
193-
weight = impl.slice.expand(
194-
ctx,
195-
target,
196-
source_ir,
197-
f"{name}_expand_weight",
198-
weight,
199-
reshaped_input.shape,
200-
)
201-
if tuple(reshaped_input.shape) != tuple(bias.shape):
202-
bias = impl.slice.expand(
203-
ctx, target, source_ir, f"{name}_expand_bias", bias, reshaped_input.shape
204-
)
200+
weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2)
201+
205202
dims = list(range(1, len(input.shape)))
206-
axes = get_axes_for_reduce_op(dims)
207-
group_norm = ctx.net.add_normalization(reshaped_input, weight, bias, axes)
208-
group_norm.epsilon = eps
209-
group_norm.compute_precision = input.dtype
210-
set_layer_name(group_norm, target, f"{name}_group_norm", source_ir)
211-
output = group_norm.get_output(0)
203+
204+
# E[X]
205+
mean_trt = impl.reduce.mean(
206+
ctx,
207+
target,
208+
source_ir,
209+
f"{name}_mean",
210+
reshaped_input,
211+
dims,
212+
True,
213+
)
214+
215+
mean_trt = impl.slice.expand(
216+
ctx,
217+
target,
218+
source_ir,
219+
f"{name}_expand_mean_trt",
220+
mean_trt,
221+
reshaped_input.shape,
222+
)
223+
224+
# X - E[X]
225+
sub_trt = impl.elementwise.sub(
226+
ctx,
227+
target,
228+
source_ir,
229+
f"{name}_sub",
230+
reshaped_input,
231+
mean_trt,
232+
)
233+
234+
# variance
235+
pow_trt = get_trt_tensor(ctx, 2, f"{name}_power", np.float32)
236+
pow_var = impl.elementwise.pow(
237+
ctx,
238+
target,
239+
source_ir,
240+
f"{name}_pow",
241+
sub_trt,
242+
pow_trt,
243+
)
244+
245+
var_trt = impl.reduce.mean(
246+
ctx,
247+
target,
248+
source_ir,
249+
f"{name}_mean_var",
250+
pow_var,
251+
dims,
252+
True,
253+
)
254+
255+
var_trt = impl.slice.expand(
256+
ctx,
257+
target,
258+
source_ir,
259+
f"{name}_expand_var_trt",
260+
var_trt,
261+
reshaped_input.shape,
262+
)
263+
264+
eps_trt = get_trt_tensor(ctx, eps, f"{name}_eps", np.float32)
265+
add_trt = impl.elementwise.add(
266+
ctx,
267+
target,
268+
source_ir,
269+
f"{name}_add",
270+
var_trt,
271+
eps_trt,
272+
)
273+
274+
sqrt_trt = impl.unary.sqrt(
275+
ctx,
276+
target,
277+
source_ir,
278+
f"{name}_sqrt",
279+
add_trt,
280+
)
281+
282+
# y = (X - E[X]) / sqrt((var + eps))
283+
output = impl.elementwise.div(
284+
ctx,
285+
target,
286+
source_ir,
287+
f"{name}_div",
288+
sub_trt,
289+
sqrt_trt,
290+
)
212291

213292
shape = list(output.shape)
214293
for i, s in enumerate(shape):
@@ -222,6 +301,40 @@ def native_group_norm(
222301
reshaped_output = impl.shuffle.reshape(
223302
ctx, target, source_ir, f"{name}_reshape_output", output, shape
224303
)
304+
reshaped_gamma = impl.shuffle.reshape(
305+
ctx,
306+
target,
307+
source_ir,
308+
f"{name}_reshape_gamma",
309+
weight,
310+
weight_bias_shape,
311+
)
312+
313+
reshaped_output = impl.elementwise.mul(
314+
ctx,
315+
target,
316+
source_ir,
317+
f"{name}_mul_gamma",
318+
reshaped_output,
319+
reshaped_gamma,
320+
)
321+
322+
reshaped_bias = impl.shuffle.reshape(
323+
ctx,
324+
target,
325+
source_ir,
326+
f"{name}_reshape_beta",
327+
bias,
328+
weight_bias_shape,
329+
)
330+
reshaped_output = impl.elementwise.add(
331+
ctx,
332+
target,
333+
source_ir,
334+
f"{name}_add_beta",
335+
reshaped_output,
336+
reshaped_bias,
337+
)
225338
if return_mean_rstd:
226339
# return fake mean and rstd for now
227340
return reshaped_output, None, None

tests/py/dynamo/conversion/test_group_norm_aten.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ def forward(self, x):
3131
return torch.ops.aten.group_norm.default(
3232
x,
3333
2,
34-
torch.ones((6,)),
35-
torch.zeros((6,)),
34+
torch.randn((6,)),
35+
torch.randn((6,)),
3636
1e-05,
3737
True,
3838
)
@@ -50,8 +50,8 @@ def forward(self, x):
5050
return torch.ops.aten.group_norm.default(
5151
x,
5252
2,
53-
torch.ones((6,)),
54-
torch.zeros((6,)),
53+
torch.randn((6,)),
54+
torch.randn((6,)),
5555
1e-05,
5656
True,
5757
)
@@ -112,6 +112,27 @@ def forward(self, x):
112112
inputs,
113113
)
114114

115+
def test_groupnorm_sd(self):
116+
class GroupNorm(torch.nn.Module):
117+
def forward(self, x):
118+
return torch.ops.aten.native_group_norm.default(
119+
x,
120+
torch.randn((320,)).half(),
121+
torch.randn((320,)).half(),
122+
2,
123+
320,
124+
4096,
125+
32,
126+
1e-05,
127+
)[0]
128+
129+
inputs = [torch.randn(2, 320, 64, 64).half()]
130+
with torch.no_grad():
131+
self.run_test(
132+
GroupNorm(),
133+
inputs,
134+
)
135+
115136
@parameterized.expand(
116137
[
117138
(5, 4, 4, 2, (2, 4, 2), (3, 4, 2), (5, 4, 4)),

0 commit comments

Comments (0)