Commit 03aca50

Previous commit still failed on Stable Diffusion 1.5. Changed to use decomposed ops instead of the INormalization layer. Added dynamic shape support.
1 parent a8e2618 commit 03aca50
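
For reference, the new converter builds group norm out of elementary ops (reduce-mean, subtract, square, mean, add-epsilon, sqrt, divide, then a per-channel scale and shift) instead of a single INormalization layer. The sketch below shows the same arithmetic in eager PyTorch; it is only an illustration of what the decomposed graph computes, not code from this commit, and the helper name group_norm_reference is made up.

import torch

def group_norm_reference(x, gamma, beta, num_groups, eps=1e-5):
    # x: (N, C, H, W); gamma/beta: (C,)
    n, c, h, w = x.shape
    # Each group of C // num_groups channels is normalized as one unit.
    xg = x.reshape(n * num_groups, c // num_groups, h, w)
    dims = [1, 2, 3]
    mean = xg.mean(dim=dims, keepdim=True)                  # E[X]
    var = ((xg - mean) ** 2).mean(dim=dims, keepdim=True)   # Var[X]
    y = (xg - mean) / torch.sqrt(var + eps)                 # (X - E[X]) / sqrt(Var[X] + eps)
    y = y.reshape(n, c, h, w)
    # Per-channel gamma/beta, broadcast over the spatial dims.
    return y * gamma.view(1, c, 1, 1) + beta.view(1, c, 1, 1)

# Stable Diffusion 1.5 uses 32 groups over 320 channels, as in the re-enabled test.
x = torch.randn(2, 320, 64, 64)
ref = torch.nn.functional.group_norm(x, 32, torch.ones(320), torch.zeros(320))
print(torch.allclose(group_norm_reference(x, torch.ones(320), torch.zeros(320), 32), ref, atol=1e-4))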

File tree

2 files changed: +106 −142 lines

2 files changed

+106
-142
lines changed

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 86 additions & 121 deletions
@@ -134,100 +134,6 @@ def layer_norm(
     return layer_norm.get_output(0)
 
 
-# def native_group_norm(
-#     ctx: ConversionContext,
-#     target: Target,
-#     source_ir: Optional[SourceIR],
-#     name: str,
-#     input: TRTTensor,
-#     weight: Optional[Union[torch.Tensor, np.ndarray]],
-#     bias: Optional[Union[torch.Tensor, np.ndarray]],
-#     N: int,
-#     C: int,
-#     HxW: int,
-#     group: int,
-#     eps: float,
-#     return_mean_rstd: bool = True,
-# ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-#     assert (
-#         len(input.shape) >= 3
-#     ), f"The input dimension should not be less than 3, got {len(input.shape)}!"
-
-#     B = input.shape[0]
-#     # if C is provided, it must be as same as the channel from the input shape,
-#     # else if C is zero, we should get the channel from the input shape
-#     if C == 0:
-#         C = input.shape[1]
-#     assert (
-#         C == input.shape[1]
-#     ), f"The number of Channel={C} must be equal to the number of channels in the input shape={input.shape[1]}"
-#     # Groups are a subdivision of the channel dimension.
-#     assert (
-#         C % group == 0
-#     ), f"The num of channels ({C}) should be divisible by num_groups ({group})!"
-#     input = get_trt_tensor(ctx, input, f"{name}_input")
-
-#     shape = list(input.shape)
-
-#     for i, s in enumerate(shape):
-#         if i == 0 and s > 0:
-#             shape[i] = B * group
-#         elif i == 1:
-#             shape[i] = C // group
-#         elif i > 1 and s == -1:
-#             shape[i] = 0
-
-#     # Normalize every group.
-#     reshaped_input = impl.shuffle.reshape(
-#         ctx,
-#         target,
-#         source_ir,
-#         f"{name}_reshape_input",
-#         input,
-#         shape,
-#     )
-
-#     weight = get_trt_tensor(ctx, weight, f"{name}_weight")
-#     bias = get_trt_tensor(ctx, bias, f"{name}_bias")
-#     if tuple(reshaped_input.shape) != tuple(weight.shape):
-#         weight = impl.slice.expand(
-#             ctx,
-#             target,
-#             source_ir,
-#             f"{name}_expand_weight",
-#             weight,
-#             reshaped_input.shape,
-#         )
-#     if tuple(reshaped_input.shape) != tuple(bias.shape):
-#         bias = impl.slice.expand(
-#             ctx, target, source_ir, f"{name}_expand_bias", bias, reshaped_input.shape
-#         )
-#     dims = list(range(1, len(input.shape)))
-#     axes = get_axes_for_reduce_op(dims)
-#     group_norm = ctx.net.add_normalization(reshaped_input, weight, bias, axes)
-#     group_norm.epsilon = eps
-#     group_norm.compute_precision = input.dtype
-#     set_layer_name(group_norm, target, f"{name}_group_norm", source_ir)
-#     output = group_norm.get_output(0)
-
-#     shape = list(output.shape)
-#     for i, s in enumerate(shape):
-#         if i == 0 and s > 0:
-#             shape[i] = B
-#         elif i == 1:
-#             shape[i] = C
-#         elif i > 1 and s == -1:
-#             shape[i] = 0
-
-#     reshaped_output = impl.shuffle.reshape(
-#         ctx, target, source_ir, f"{name}_reshape_output", output, shape
-#     )
-#     if return_mean_rstd:
-#         # return fake mean and rstd for now
-#         return reshaped_output, None, None
-#     return reshaped_output
-
-
 def native_group_norm(
     ctx: ConversionContext,
     target: Target,
@@ -243,6 +149,8 @@ def native_group_norm(
     eps: float,
     return_mean_rstd: bool = True,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    # TODO: Ask TRT team about the usage of INormalization Layer usage with num_groups and update the implementation
+    # with INormalization Layer
     assert (
         len(input.shape) >= 3
     ), f"The input dimension should not be less than 3, got {len(input.shape)}!"
@@ -286,36 +194,94 @@ def native_group_norm(
     weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2)
 
     dims = list(range(1, len(input.shape)))
-    axes = get_axes_for_reduce_op(dims)
-    # Use dummy weight since the normalization layer cannot well handle the scale and shift of group norm due to shape mismatch
-    # TODO: check with TRT the correct way to use 'num_groups' to implement group norm
-    dummy_weight = get_trt_tensor(
-        ctx, np.array([1.0]), f"{name}_dummy_weight", input.dtype
+
+    # E[X]
+    mean_trt = impl.reduce.mean(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_mean",
+        reshaped_input,
+        dims,
+        True,
     )
-    dummy_weight = impl.slice.expand(
+
+    mean_trt = impl.slice.expand(
         ctx,
         target,
         source_ir,
-        f"{name}_expand_dummy_weight",
-        dummy_weight,
+        f"{name}_expand_mean_trt",
+        mean_trt,
         reshaped_input.shape,
     )
-    dummy_bias = get_trt_tensor(ctx, np.array([0.0]), f"{name}_dummy_bias", input.dtype)
-    dummy_bias = impl.slice.expand(
+
+    # X - E[X]
+    sub_trt = impl.elementwise.sub(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_sub",
+        reshaped_input,
+        mean_trt,
+    )
+
+    # variance
+    pow_trt = get_trt_tensor(ctx, 2, f"{name}_power", np.float32)
+    pow_var = impl.elementwise.pow(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_pow",
+        sub_trt,
+        pow_trt,
+    )
+
+    var_trt = impl.reduce.mean(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_mean_var",
+        pow_var,
+        dims,
+        True,
+    )
+
+    var_trt = impl.slice.expand(
         ctx,
         target,
         source_ir,
-        f"{name}_expand_dummy_bias",
-        dummy_bias,
+        f"{name}_expand_var_trt",
+        var_trt,
         reshaped_input.shape,
     )
-    group_norm = ctx.net.add_normalization(
-        reshaped_input, dummy_weight, dummy_bias, axes
+
+    eps_trt = get_trt_tensor(ctx, eps, f"{name}_eps", np.float32)
+    add_trt = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_add",
+        var_trt,
+        eps_trt,
+    )
+
+    sqrt_trt = impl.unary.sqrt(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_sqrt",
+        add_trt,
+    )
+
+    # y = (X - E[X]) / sqrt((var + eps))
+    output = impl.elementwise.div(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_div",
+        sub_trt,
+        sqrt_trt,
     )
-    group_norm.epsilon = eps
-    group_norm.compute_precision = input.dtype
-    set_layer_name(group_norm, target, f"{name}_group_norm", source_ir)
-    output = group_norm.get_output(0)
 
     shape = list(output.shape)
     for i, s in enumerate(shape):
@@ -329,12 +295,11 @@ def native_group_norm(
     reshaped_output = impl.shuffle.reshape(
         ctx, target, source_ir, f"{name}_reshape_output", output, shape
     )
-
-    weight = impl.shuffle.reshape(
+    reshaped_gamma = impl.shuffle.reshape(
         ctx,
         target,
         source_ir,
-        f"{name}_weight",
+        f"{name}_reshape_gamma",
         weight,
         weight_bias_shape,
     )
@@ -343,26 +308,26 @@ def native_group_norm(
         ctx,
         target,
         source_ir,
-        f"{name}_mul_weight",
+        f"{name}_mul_gamma",
         reshaped_output,
-        weight,
+        reshaped_gamma,
     )
 
-    bias = impl.shuffle.reshape(
+    reshaped_bias = impl.shuffle.reshape(
         ctx,
         target,
         source_ir,
-        f"{name}_reshape_bias",
+        f"{name}_reshape_beta",
         bias,
         weight_bias_shape,
     )
     reshaped_output = impl.elementwise.add(
         ctx,
         target,
         source_ir,
-        f"{name}_add_bias",
+        f"{name}_add_beta",
         reshaped_output,
-        bias,
+        reshaped_bias,
     )
     if return_mean_rstd:
         # return fake mean and rstd for now

tests/py/dynamo/conversion/test_group_norm_aten.py

Lines changed: 20 additions & 21 deletions
@@ -112,27 +112,26 @@ def forward(self, x):
             inputs,
         )
 
-    # TODO: Half precision has accuracy issue for now.
-    # def test_groupnorm_sd(self):
-    #     class GroupNorm(torch.nn.Module):
-    #         def forward(self, x):
-    #             return torch.ops.aten.native_group_norm.default(
-    #                 x,
-    #                 torch.randn((320,)).half(),
-    #                 torch.randn((320,)).half(),
-    #                 2,
-    #                 320,
-    #                 4096,
-    #                 32,
-    #                 1e-05,
-    #             )[0]
-
-    #     inputs = [torch.randn(2, 320, 64, 64).half()]
-    #     with torch.no_grad():
-    #         self.run_test(
-    #             GroupNorm(),
-    #             inputs,
-    #         )
+    def test_groupnorm_sd(self):
+        class GroupNorm(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.native_group_norm.default(
+                    x,
+                    torch.randn((320,)).half(),
+                    torch.randn((320,)).half(),
+                    2,
+                    320,
+                    4096,
+                    32,
+                    1e-05,
+                )[0]
+
+        inputs = [torch.randn(2, 320, 64, 64).half()]
+        with torch.no_grad():
+            self.run_test(
+                GroupNorm(),
+                inputs,
+            )
 
     @parameterized.expand(
         [

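Since the commit message also claims dynamic shape support, one way to exercise the new converter end to end is to compile a GroupNorm module through the dynamo path with a dynamic batch axis. Below is a minimal sketch, assuming the public torch_tensorrt.compile / torch_tensorrt.Input API and a CUDA device; it is not part of this commit's test suite.

import torch
import torch_tensorrt

# 32 groups over 320 channels, the configuration used in Stable Diffusion 1.5 UNet blocks.
model = torch.nn.GroupNorm(32, 320).cuda().eval()

# Dynamic batch dimension expressed as an optimization profile.
inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 320, 64, 64),
        opt_shape=(2, 320, 64, 64),
        max_shape=(4, 320, 64, 64),
        dtype=torch.float32,
    )
]

trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(torch.randn(2, 320, 64, 64).cuda()).shape)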