Commit 5f47fb1

Still fails on Stable Diffusion 1.5. Revert to decomposed ops instead of using the INormalizationLayer.

1 parent a8e2618 commit 5f47fb1
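
The decomposed path builds group norm from reduce and elementwise layers: y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta, with mean and variance taken per group. For reference, a minimal NumPy sketch of the same decomposition (an illustration only; group_norm_ref is a hypothetical name, not part of this commit):

import numpy as np

def group_norm_ref(x, gamma, beta, num_groups, eps=1e-5):
    """Group norm via decomposed ops, mirroring the converter's
    mean/sub/pow/mean/add/sqrt/div sequence. x: (N, C, H, W)."""
    n, c, h, w = x.shape
    assert c % num_groups == 0, "C must be divisible by num_groups"
    # Fold groups into the batch dimension: (N * G, C // G, H, W)
    xg = x.reshape(n * num_groups, c // num_groups, h, w)
    # E[X] and biased Var[X] over all non-batch axes of the reshaped tensor
    mean = xg.mean(axis=(1, 2, 3), keepdims=True)
    var = ((xg - mean) ** 2).mean(axis=(1, 2, 3), keepdims=True)
    # y = (X - E[X]) / sqrt(var + eps)
    y = (xg - mean) / np.sqrt(var + eps)
    # Restore (N, C, H, W), then apply the per-channel affine gamma/beta
    y = y.reshape(n, c, h, w)
    return y * gamma.reshape(1, c, 1, 1) + beta.reshape(1, c, 1, 1)

Each step of the sketch corresponds to one layer in the diff below: impl.reduce.mean, impl.elementwise.sub, impl.elementwise.pow, impl.reduce.mean again, impl.elementwise.add, impl.unary.sqrt, impl.elementwise.div, then the final gamma/beta reshape, multiply, and add.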

File tree

1 file changed: +221 -27

  • py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
Lines changed: 221 additions & 27 deletions
@@ -227,6 +227,143 @@ def layer_norm(
     # return reshaped_output, None, None
     # return reshaped_output
 
+# def native_group_norm(
+#     ctx: ConversionContext,
+#     target: Target,
+#     source_ir: Optional[SourceIR],
+#     name: str,
+#     input: TRTTensor,
+#     weight: Optional[Union[torch.Tensor, np.ndarray]],
+#     bias: Optional[Union[torch.Tensor, np.ndarray]],
+#     N: int,
+#     C: int,
+#     HxW: int,
+#     group: int,
+#     eps: float,
+#     return_mean_rstd: bool = True,
+# ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+#     assert (
+#         len(input.shape) >= 3
+#     ), f"The input dimension should not be less than 3, got {len(input.shape)}!"
+
+#     B = input.shape[0]
+#     # if C is provided, it must be the same as the channel from the input shape;
+#     # else if C is zero, we should get the channel from the input shape
+#     if C == 0:
+#         C = input.shape[1]
+#     assert (
+#         C == input.shape[1]
+#     ), f"The number of Channel={C} must be equal to the number of channels in the input shape={input.shape[1]}"
+#     # Groups are a subdivision of the channel dimension.
+#     assert (
+#         C % group == 0
+#     ), f"The num of channels ({C}) should be divisible by num_groups ({group})!"
+#     input = get_trt_tensor(ctx, input, f"{name}_input")
+
+#     shape = list(input.shape)
+
+#     for i, s in enumerate(shape):
+#         if i == 0 and s > 0:
+#             shape[i] = B * group
+#         elif i == 1:
+#             shape[i] = C // group
+#         elif i > 1 and s == -1:
+#             shape[i] = 0
+
+#     # Normalize every group.
+#     reshaped_input = impl.shuffle.reshape(
+#         ctx,
+#         target,
+#         source_ir,
+#         f"{name}_reshape_input",
+#         input,
+#         shape,
+#     )
+
+#     weight = get_trt_tensor(ctx, weight, f"{name}_weight")
+#     bias = get_trt_tensor(ctx, bias, f"{name}_bias")
+#     weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2)
+
+#     dims = list(range(1, len(input.shape)))
+#     axes = get_axes_for_reduce_op(dims)
+#     dummy_weight = get_trt_tensor(ctx, np.array([1.0]), f"{name}_dummy_weight")
+#     dummy_weight = impl.slice.expand(
+#         ctx, target, source_ir, f"{name}_expand_dummy_weight", dummy_weight, reshaped_input.shape
+#     )
+#     dummy_bias = get_trt_tensor(ctx, np.array([0.0]), f"{name}_dummy_bias")
+#     dummy_bias = impl.slice.expand(
+#         ctx, target, source_ir, f"{name}_expand_dummy_bias", dummy_bias, reshaped_input.shape
+#     )
+#     group_norm = ctx.net.add_normalization(reshaped_input, dummy_weight, dummy_bias, axes)
+#     group_norm.epsilon = eps
+#     group_norm.compute_precision = input.dtype
+#     set_layer_name(group_norm, target, f"{name}_group_norm", source_ir)
+#     output = group_norm.get_output(0)
+
+#     shape = list(output.shape)
+#     for i, s in enumerate(shape):
+#         if i == 0 and s > 0:
+#             shape[i] = B
+#         elif i == 1:
+#             shape[i] = C
+#         elif i > 1 and s == -1:
+#             shape[i] = 0
+
+#     reshaped_output = impl.shuffle.reshape(
+#         ctx, target, source_ir, f"{name}_reshape_output", output, shape
+#     )
+
+
+#     # weight = impl.unsqueeze.unsqueeze(ctx, target, source_ir, f"{name}_weight_unsqueeze1", weight, (0))
+#     # weight = impl.unsqueeze.unsqueeze(ctx, target, source_ir, f"{name}_weight_unsqueeze2", weight, (2))
+#     # weight = impl.slice.expand(
+#     #     ctx, target, source_ir, f"{name}_expand_weight", weight, shape
+#     # )
+#     # bias = impl.unsqueeze.unsqueeze(ctx, target, source_ir, f"{name}_bias_unsqueeze1", bias, (0))
+#     # bias = impl.unsqueeze.unsqueeze(ctx, target, source_ir, f"{name}_bias_unsqueeze2", bias, (2))
+#     # bias = impl.slice.expand(
+#     #     ctx, target, source_ir, f"{name}_expand_bias", bias, shape
+#     # )
+
+#     reshaped_gamma = impl.shuffle.reshape(
+#         ctx,
+#         target,
+#         source_ir,
+#         f"{name}_reshape_gamma",
+#         weight,
+#         weight_bias_shape,
+#     )
+
+#     reshaped_output = impl.elementwise.mul(
+#         ctx,
+#         target,
+#         source_ir,
+#         f"{name}_mul_gamma",
+#         reshaped_output,
+#         reshaped_gamma,
+#     )
+
+#     reshaped_bias = impl.shuffle.reshape(
+#         ctx,
+#         target,
+#         source_ir,
+#         f"{name}_reshape_beta",
+#         bias,
+#         weight_bias_shape,
+#     )
+#     reshaped_output = impl.elementwise.add(
+#         ctx,
+#         target,
+#         source_ir,
+#         f"{name}_add_beta",
+#         reshaped_output,
+#         reshaped_bias,
+#     )
+#     if return_mean_rstd:
+#         # return fake mean and rstd for now
+#         return reshaped_output, None, None
+#     return reshaped_output
+
 
 def native_group_norm(
     ctx: ConversionContext,
@@ -286,36 +423,94 @@ def native_group_norm(
     weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2)
 
     dims = list(range(1, len(input.shape)))
-    axes = get_axes_for_reduce_op(dims)
-    # Use dummy weight since the normalization layer cannot well handle the scale and shift of group norm due to shape mismatch
-    # TODO: check with TRT the correct way to use 'num_groups' to implement group norm
-    dummy_weight = get_trt_tensor(
-        ctx, np.array([1.0]), f"{name}_dummy_weight", input.dtype
+
+    # E[X]
+    mean_trt = impl.reduce.mean(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_mean",
+        reshaped_input,
+        dims,
+        True,
     )
-    dummy_weight = impl.slice.expand(
+
+    mean_trt = impl.slice.expand(
         ctx,
         target,
         source_ir,
-        f"{name}_expand_dummy_weight",
-        dummy_weight,
+        f"{name}_expand_mean_trt",
+        mean_trt,
         reshaped_input.shape,
     )
-    dummy_bias = get_trt_tensor(ctx, np.array([0.0]), f"{name}_dummy_bias", input.dtype)
-    dummy_bias = impl.slice.expand(
+
+    # X - E[X]
+    sub_trt = impl.elementwise.sub(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_sub",
+        reshaped_input,
+        mean_trt,
+    )
+
+    # variance
+    pow_trt = get_trt_tensor(ctx, 2, f"{name}_power", np.float32)
+    pow_var = impl.elementwise.pow(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_pow",
+        sub_trt,
+        pow_trt,
+    )
+
+    var_trt = impl.reduce.mean(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_mean_var",
+        pow_var,
+        dims,
+        True,
+    )
+
+    var_trt = impl.slice.expand(
         ctx,
         target,
        source_ir,
-        f"{name}_expand_dummy_bias",
-        dummy_bias,
+        f"{name}_expand_var_trt",
+        var_trt,
        reshaped_input.shape,
     )
-    group_norm = ctx.net.add_normalization(
-        reshaped_input, dummy_weight, dummy_bias, axes
+
+    eps_trt = get_trt_tensor(ctx, eps, f"{name}_eps", np.float32)
+    add_trt = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_add",
+        var_trt,
+        eps_trt,
+    )
+
+    sqrt_trt = impl.unary.sqrt(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_sqrt",
+        add_trt,
+    )
+
+    # y = (X - E[X]) / sqrt(var + eps)
+    output = impl.elementwise.div(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_div",
+        sub_trt,
+        sqrt_trt,
     )
-    group_norm.epsilon = eps
-    group_norm.compute_precision = input.dtype
-    set_layer_name(group_norm, target, f"{name}_group_norm", source_ir)
-    output = group_norm.get_output(0)
 
     shape = list(output.shape)
     for i, s in enumerate(shape):
@@ -329,12 +524,11 @@ def native_group_norm(
     reshaped_output = impl.shuffle.reshape(
         ctx, target, source_ir, f"{name}_reshape_output", output, shape
     )
-
-    weight = impl.shuffle.reshape(
+    reshaped_gamma = impl.shuffle.reshape(
         ctx,
         target,
         source_ir,
-        f"{name}_weight",
+        f"{name}_reshape_gamma",
         weight,
         weight_bias_shape,
     )
@@ -343,26 +537,26 @@ def native_group_norm(
         ctx,
         target,
         source_ir,
-        f"{name}_mul_weight",
+        f"{name}_mul_gamma",
        reshaped_output,
-        weight,
+        reshaped_gamma,
     )
 
-    bias = impl.shuffle.reshape(
+    reshaped_bias = impl.shuffle.reshape(
         ctx,
         target,
         source_ir,
-        f"{name}_reshape_bias",
+        f"{name}_reshape_beta",
         bias,
         weight_bias_shape,
     )
     reshaped_output = impl.elementwise.add(
         ctx,
         target,
         source_ir,
-        f"{name}_add_bias",
+        f"{name}_add_beta",
         reshaped_output,
-        bias,
+        reshaped_bias,
     )
     if return_mean_rstd:
         # return fake mean and rstd for now
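
As a sanity check of the decomposition (a hypothetical test using the group_norm_ref sketch above, not part of this commit), the result can be compared against PyTorch's reference op:

import torch
import torch.nn.functional as F

x = torch.randn(2, 32, 16, 16)
gamma, beta = torch.ones(32), torch.zeros(32)
expected = F.group_norm(x, num_groups=8, weight=gamma, bias=beta, eps=1e-5)
actual = torch.from_numpy(
    group_norm_ref(x.numpy(), gamma.numpy(), beta.numpy(), num_groups=8)
)
print(torch.allclose(expected, actual, atol=1e-4))  # expect True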

0 commit comments