
Commit 937d342

Still fails on Stable Diffusion 1.5. Used decomposed ops instead of the INormalization layer. Added support for dynamic shapes.
1 parent a8e2618 commit 937d342

2 files changed: +243 -48 lines changed


py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 223 additions & 27 deletions
@@ -227,6 +227,143 @@ def layer_norm(
     #     return reshaped_output, None, None
     # return reshaped_output
 
+# def native_group_norm(
+#     ctx: ConversionContext,
+#     target: Target,
+#     source_ir: Optional[SourceIR],
+#     name: str,
+#     input: TRTTensor,
+#     weight: Optional[Union[torch.Tensor, np.ndarray]],
+#     bias: Optional[Union[torch.Tensor, np.ndarray]],
+#     N: int,
+#     C: int,
+#     HxW: int,
+#     group: int,
+#     eps: float,
+#     return_mean_rstd: bool = True,
+# ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+#     assert (
+#         len(input.shape) >= 3
+#     ), f"The input dimension should not be less than 3, got {len(input.shape)}!"
+
+#     B = input.shape[0]
+#     # if C is provided, it must be as same as the channel from the input shape,
+#     # else if C is zero, we should get the channel from the input shape
+#     if C == 0:
+#         C = input.shape[1]
+#     assert (
+#         C == input.shape[1]
+#     ), f"The number of Channel={C} must be equal to the number of channels in the input shape={input.shape[1]}"
+#     # Groups are a subdivision of the channel dimension.
+#     assert (
+#         C % group == 0
+#     ), f"The num of channels ({C}) should be divisible by num_groups ({group})!"
+#     input = get_trt_tensor(ctx, input, f"{name}_input")
+
+#     shape = list(input.shape)
+
+#     for i, s in enumerate(shape):
+#         if i == 0 and s > 0:
+#             shape[i] = B * group
+#         elif i == 1:
+#             shape[i] = C // group
+#         elif i > 1 and s == -1:
+#             shape[i] = 0
+
+#     # Normalize every group.
+#     reshaped_input = impl.shuffle.reshape(
+#         ctx,
+#         target,
+#         source_ir,
+#         f"{name}_reshape_input",
+#         input,
+#         shape,
+#     )
+
+#     weight = get_trt_tensor(ctx, weight, f"{name}_weight")
+#     bias = get_trt_tensor(ctx, bias, f"{name}_bias")
+#     weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2)
+
+#     dims = list(range(1, len(input.shape)))
+#     axes = get_axes_for_reduce_op(dims)
+#     dummy_weight = get_trt_tensor(ctx, np.array([1.0]), f"{name}_dummy_weight")
+#     dummy_weight = impl.slice.expand(
+#         ctx, target, source_ir, f"{name}_expand_dummy_weight", dummy_weight, reshaped_input.shape
+#     )
+#     dummy_bias = get_trt_tensor(ctx, np.array([0.0]), f"{name}_dummy_bias")
+#     dummy_bias = impl.slice.expand(
+#         ctx, target, source_ir, f"{name}_expand_dummy_bias", dummy_bias, reshaped_input.shape
+#     )
+#     group_norm = ctx.net.add_normalization(reshaped_input, dummy_weight, dummy_bias, axes)
+#     group_norm.epsilon = eps
+#     group_norm.compute_precision = input.dtype
+#     set_layer_name(group_norm, target, f"{name}_group_norm", source_ir)
+#     output = group_norm.get_output(0)
+
+#     shape = list(output.shape)
+#     for i, s in enumerate(shape):
+#         if i == 0 and s > 0:
+#             shape[i] = B
+#         elif i == 1:
+#             shape[i] = C
+#         elif i > 1 and s == -1:
+#             shape[i] = 0
+
+#     reshaped_output = impl.shuffle.reshape(
+#         ctx, target, source_ir, f"{name}_reshape_output", output, shape
+#     )
+
+
+#     # weight = impl.unsqueeze.unsqueeze(ctx, target, source_ir, f"{name}_weight_unsqueeze1", weight, (0))
+#     # weight = impl.unsqueeze.unsqueeze(ctx, target, source_ir, f"{name}_weight_unsqueeze2", weight, (2))
+#     # weight = impl.slice.expand(
+#     #     ctx, target, source_ir, f"{name}_expand_weight", weight, shape
+#     # )
+#     # bias = impl.unsqueeze.unsqueeze(ctx, target, source_ir, f"{name}_bias_unsqueeze1", bias, (0))
+#     # bias = impl.unsqueeze.unsqueeze(ctx, target, source_ir, f"{name}_bias_unsqueeze2", bias, (2))
+#     # bias = impl.slice.expand(
+#     #     ctx, target, source_ir, f"{name}_expand_bias", bias, shape
+#     # )
+
+#     reshaped_gamma = impl.shuffle.reshape(
+#         ctx,
+#         target,
+#         source_ir,
+#         f"{name}_reshape_gamma",
+#         weight,
+#         weight_bias_shape,
+#     )
+
+#     reshaped_output = impl.elementwise.mul(
+#         ctx,
+#         target,
+#         source_ir,
+#         f"{name}_mul_gamma",
+#         reshaped_output,
+#         reshaped_gamma,
+#     )
+
+#     reshaped_bias = impl.shuffle.reshape(
+#         ctx,
+#         target,
+#         source_ir,
+#         f"{name}_reshape_beta",
+#         bias,
+#         weight_bias_shape,
+#     )
+#     reshaped_output = impl.elementwise.add(
+#         ctx,
+#         target,
+#         source_ir,
+#         f"{name}_add_beta",
+#         reshaped_output,
+#         reshaped_bias,
+#     )
+#     if return_mean_rstd:
+#         # return fake mean and rstd for now
+#         return reshaped_output, None, None
+#     return reshaped_output
+
 
 def native_group_norm(
     ctx: ConversionContext,
@@ -243,6 +380,8 @@ def native_group_norm(
     eps: float,
     return_mean_rstd: bool = True,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    # TODO: Ask the TRT team about the usage of INormalization Layer with num_groups and update the implementation
+    # with INormalization Layer
     assert (
         len(input.shape) >= 3
     ), f"The input dimension should not be less than 3, got {len(input.shape)}!"
@@ -286,36 +425,94 @@ def native_group_norm(
     weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2)
 
     dims = list(range(1, len(input.shape)))
-    axes = get_axes_for_reduce_op(dims)
-    # Use dummy weight since the normalization layer cannot well handle the scale and shift of group norm due to shape mismatch
-    # TODO: check with TRT the correct way to use 'num_groups' to implement group norm
-    dummy_weight = get_trt_tensor(
-        ctx, np.array([1.0]), f"{name}_dummy_weight", input.dtype
+
+    # E[X]
+    mean_trt = impl.reduce.mean(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_mean",
+        reshaped_input,
+        dims,
+        True,
     )
-    dummy_weight = impl.slice.expand(
+
+    mean_trt = impl.slice.expand(
         ctx,
         target,
         source_ir,
-        f"{name}_expand_dummy_weight",
-        dummy_weight,
+        f"{name}_expand_mean_trt",
+        mean_trt,
         reshaped_input.shape,
     )
-    dummy_bias = get_trt_tensor(ctx, np.array([0.0]), f"{name}_dummy_bias", input.dtype)
-    dummy_bias = impl.slice.expand(
+
+    # X - E[X]
+    sub_trt = impl.elementwise.sub(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_sub",
+        reshaped_input,
+        mean_trt,
+    )
+
+    # variance
+    pow_trt = get_trt_tensor(ctx, 2, f"{name}_power", np.float32)
+    pow_var = impl.elementwise.pow(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_pow",
+        sub_trt,
+        pow_trt,
+    )
+
+    var_trt = impl.reduce.mean(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_mean_var",
+        pow_var,
+        dims,
+        True,
+    )
+
+    var_trt = impl.slice.expand(
         ctx,
         target,
         source_ir,
-        f"{name}_expand_dummy_bias",
-        dummy_bias,
+        f"{name}_expand_var_trt",
+        var_trt,
         reshaped_input.shape,
     )
-    group_norm = ctx.net.add_normalization(
-        reshaped_input, dummy_weight, dummy_bias, axes
+
+    eps_trt = get_trt_tensor(ctx, eps, f"{name}_eps", np.float32)
+    add_trt = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_add",
+        var_trt,
+        eps_trt,
+    )
+
+    sqrt_trt = impl.unary.sqrt(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_sqrt",
+        add_trt,
+    )
+
+    # y = (X - E[X]) / sqrt((var + eps))
+    output = impl.elementwise.div(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_div",
+        sub_trt,
+        sqrt_trt,
     )
-    group_norm.epsilon = eps
-    group_norm.compute_precision = input.dtype
-    set_layer_name(group_norm, target, f"{name}_group_norm", source_ir)
-    output = group_norm.get_output(0)
 
     shape = list(output.shape)
     for i, s in enumerate(shape):
@@ -329,12 +526,11 @@ def native_group_norm(
     reshaped_output = impl.shuffle.reshape(
         ctx, target, source_ir, f"{name}_reshape_output", output, shape
     )
-
-    weight = impl.shuffle.reshape(
+    reshaped_gamma = impl.shuffle.reshape(
         ctx,
         target,
         source_ir,
-        f"{name}_weight",
+        f"{name}_reshape_gamma",
         weight,
         weight_bias_shape,
     )
@@ -343,26 +539,26 @@ def native_group_norm(
         ctx,
         target,
         source_ir,
-        f"{name}_mul_weight",
+        f"{name}_mul_gamma",
         reshaped_output,
-        weight,
+        reshaped_gamma,
     )
 
-    bias = impl.shuffle.reshape(
+    reshaped_bias = impl.shuffle.reshape(
         ctx,
         target,
         source_ir,
-        f"{name}_reshape_bias",
+        f"{name}_reshape_beta",
         bias,
         weight_bias_shape,
     )
     reshaped_output = impl.elementwise.add(
         ctx,
         target,
         source_ir,
-        f"{name}_add_bias",
+        f"{name}_add_beta",
         reshaped_output,
-        reshaped_bias,
+        reshaped_bias,
     )
     if return_mean_rstd:
         # return fake mean and rstd for now
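
For reference, the decomposition that the converter now builds out of TensorRT reduce/elementwise layers corresponds to the following eager-mode sketch. This is illustrative only and not part of the commit; the helper name group_norm_decomposed and the final assert_close check are assumptions.

import torch

def group_norm_decomposed(x, weight, bias, num_groups, eps=1e-5):
    # Reshape (B, C, *spatial) -> (B * num_groups, C // num_groups, *spatial)
    # so that each group is normalized independently (impl.shuffle.reshape above).
    B, C = x.shape[0], x.shape[1]
    grouped = x.reshape(B * num_groups, C // num_groups, *x.shape[2:])

    # E[X] and Var[X] over all non-batch dims of the grouped tensor
    # (impl.reduce.mean with keepdim, then pow/mean for the variance).
    dims = tuple(range(1, grouped.dim()))
    mean = grouped.mean(dim=dims, keepdim=True)
    var = ((grouped - mean) ** 2).mean(dim=dims, keepdim=True)

    # y = (X - E[X]) / sqrt(Var[X] + eps)
    normalized = (grouped - mean) / torch.sqrt(var + eps)

    # Reshape back and apply per-channel gamma/beta broadcast as (1, C, 1, ...).
    out = normalized.reshape(x.shape)
    shape = (1, C) + (1,) * (x.dim() - 2)
    return out * weight.reshape(shape) + bias.reshape(shape)

# Example with the Stable Diffusion-like shapes used in the test below:
# x = torch.randn(2, 320, 64, 64); w = torch.randn(320); b = torch.randn(320)
# torch.testing.assert_close(
#     group_norm_decomposed(x, w, b, num_groups=32),
#     torch.nn.functional.group_norm(x, 32, w, b),
# )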

tests/py/dynamo/conversion/test_group_norm_aten.py

Lines changed: 20 additions & 21 deletions
@@ -112,27 +112,26 @@ def forward(self, x):
             inputs,
         )
 
-    # TODO: Half precision has accuracy issue for now.
-    # def test_groupnorm_sd(self):
-    #     class GroupNorm(torch.nn.Module):
-    #         def forward(self, x):
-    #             return torch.ops.aten.native_group_norm.default(
-    #                 x,
-    #                 torch.randn((320,)).half(),
-    #                 torch.randn((320,)).half(),
-    #                 2,
-    #                 320,
-    #                 4096,
-    #                 32,
-    #                 1e-05,
-    #             )[0]
-
-    # inputs = [torch.randn(2, 320, 64, 64).half()]
-    # with torch.no_grad():
-    #     self.run_test(
-    #         GroupNorm(),
-    #         inputs,
-    #     )
+    def test_groupnorm_sd(self):
+        class GroupNorm(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.native_group_norm.default(
+                    x,
+                    torch.randn((320,)).half(),
+                    torch.randn((320,)).half(),
+                    2,
+                    320,
+                    4096,
+                    32,
+                    1e-05,
+                )[0]
+
+        inputs = [torch.randn(2, 320, 64, 64).half()]
+        with torch.no_grad():
+            self.run_test(
+                GroupNorm(),
+                inputs,
+            )
 
     @parameterized.expand(
         [
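
To exercise just the re-enabled half-precision case locally, something like the following should work, assuming the repository's usual pytest-based test flow:

pytest tests/py/dynamo/conversion/test_group_norm_aten.py -k test_groupnorm_sd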
