Commit 38b1804

Use INormalizationLayer for GroupNorm (#3273)
1 parent 302d0b8 commit 38b1804
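
This commit replaces the hand-rolled GroupNorm lowering (an explicit mean/variance/normalize subgraph) with TensorRT's INormalizationLayer, applying the per-channel weight and bias after the layer. Below is a minimal usage sketch showing how a model containing nn.GroupNorm would be compiled through the dynamo path so that aten.native_group_norm reaches this converter; the model, input shapes, and compile options are illustrative assumptions, not part of the commit, and a CUDA-capable environment with torch_tensorrt installed is assumed.

import torch
import torch_tensorrt


class GroupNormNet(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # 6 channels split into 2 groups exercises the per-group normalization path
        self.gn = torch.nn.GroupNorm(num_groups=2, num_channels=6)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.gn(x)


model = GroupNormNet().eval().cuda()
inputs = [torch.randn(2, 6, 8, 8, device="cuda")]

# Compile via the dynamo frontend so aten.native_group_norm is lowered to TensorRT
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs).shape)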

File tree

3 files changed, +83 -375 lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 2 additions & 36 deletions
@@ -183,8 +183,8 @@ def aten_ops_native_group_norm(
         SourceIR.ATEN,
         name,
         input=args[0],
-        weight=args[1],
-        bias=args[2],
+        weight=args_bounds_check(args, 1),
+        bias=args_bounds_check(args, 2),
         N=args[3],
         C=args[4],
         HxW=args[5],
@@ -193,40 +193,6 @@ def aten_ops_native_group_norm(
     )
 
 
-@dynamo_tensorrt_converter(
-    torch.ops.aten.group_norm.default,
-    supports_dynamic_shapes=True,
-)
-@dynamo_tensorrt_converter(
-    torch.ops.aten.group_norm,
-    supports_dynamic_shapes=True,
-)
-@enforce_tensor_types(
-    {
-        0: (TRTTensor,),
-    }
-)
-def aten_ops_group_norm(
-    ctx: ConversionContext,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.normalization.group_norm(
-        ctx,
-        target,
-        SourceIR.ATEN,
-        name,
-        input=args[0],
-        num_groups=args[1],
-        weight=args_bounds_check(args, 2, None),
-        bias=args_bounds_check(args, 3, None),
-        eps=args_bounds_check(args, 4, 1e-05),
-        cudnn_enabled=args_bounds_check(args, 5, True),
-    )
-
-
 @dynamo_tensorrt_converter(torch.ops.aten.cat.default, supports_dynamic_shapes=True)
 def aten_ops_cat(
     ctx: ConversionContext,
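
In the converter above, weight and bias are now fetched with args_bounds_check instead of indexing args directly, so traces that omit the optional affine parameters no longer fail with an IndexError. The sketch below illustrates the fallback behaviour this relies on; the helper body shown here is an assumption for illustration (the real definition lives in the repo's converter utilities), with only its call signature taken from the diff.

from typing import Any, Optional, Sequence


def args_bounds_check(
    args: Sequence[Any], i: int, replacement: Optional[Any] = None
) -> Any:
    # Assumed fallback semantics: return args[i] when present, otherwise the default.
    return args[i] if len(args) > i else replacement


# With this guard, a native_group_norm call traced without explicit weight/bias
# yields None for the missing arguments instead of raising:
traced_args = ("input_tensor",)            # weight and bias not provided
weight = args_bounds_check(traced_args, 1)  # -> None
bias = args_bounds_check(traced_args, 2)    # -> None
print(weight, bias)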

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 49 additions & 212 deletions
@@ -1,5 +1,5 @@
 import logging
-from typing import Any, List, Optional, Sequence, Tuple, Union, cast
+from typing import List, Optional, Sequence, Tuple, Union
 
 import numpy as np
 import tensorrt as trt
@@ -16,7 +16,6 @@
     get_trt_tensor,
     has_dynamic_shape,
     set_layer_name,
-    to_numpy,
 )
 from torch_tensorrt.dynamo.conversion.impl.cat import cat
 from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import ge
@@ -203,234 +202,72 @@ def native_group_norm(
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
-    weight: Optional[Union[torch.Tensor, np.ndarray]],
-    bias: Optional[Union[torch.Tensor, np.ndarray]],
+    weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
     N: int,
     C: int,
     HxW: int,
     group: int,
     eps: float,
-    return_mean_rstd: bool = True,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    # TODO: Ask TRT team about the usage of INormalization Layer usage with num_groups and update the implementation
-    # with INormalization Layer
-    assert (
-        len(input.shape) >= 3
-    ), f"The input dimension should not be less than 3, got {len(input.shape)}!"
-
-    B = input.shape[0]
-    # if C is provided, it must be as same as the channel from the input shape,
-    # else if C is zero, we should get the channel from the input shape
-    if C == 0:
-        C = input.shape[1]
-    assert (
-        C == input.shape[1]
-    ), f"The number of Channel={C} must be equal to the number of channels in the input shape={input.shape[1]}"
-    # Groups are a subdivision of the channel dimension.
-    assert (
-        C % group == 0
-    ), f"The num of channels ({C}) should be divisible by num_groups ({group})!"
-    input = get_trt_tensor(ctx, input, f"{name}_input")
-
-    shape = list(input.shape)
-
-    for i, s in enumerate(shape):
-        if i == 0 and s > 0:
-            shape[i] = B * group
-        elif i == 1:
-            shape[i] = C // group
-        elif i > 1 and s == -1:
-            shape[i] = 0
-
-    # Normalize every group.
-    reshaped_input = impl.shuffle.reshape(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_reshape_input",
-        input,
-        shape,
-    )
-
-    if weight is None:
-        weight = to_numpy(1.0)
-
-    if bias is None:
-        bias = to_numpy(0.0)
-
-    weight = get_trt_tensor(ctx, weight, f"{name}_weight")
-    bias = get_trt_tensor(ctx, bias, f"{name}_bias")
-    weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2)
-
-    dims = list(range(1, len(input.shape)))
-
-    # E[X]
-    mean_trt = impl.reduce.mean(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_mean",
-        reshaped_input,
-        dims,
-        True,
-    )
+) -> Tuple[TRTTensor, torch.Tensor, torch.Tensor]:
+    rank = len(input.shape)
 
-    mean_trt = impl.slice.expand(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_expand_mean_trt",
-        mean_trt,
-        reshaped_input.shape,
-    )
+    assert rank >= 3, f"Expected at least 3 dimensions for input tensor but got {rank}"
 
-    # X - E[X]
-    sub_trt = impl.elementwise.sub(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_sub",
-        reshaped_input,
-        mean_trt,
-    )
+    assert (
+        C == input.shape[1]
+    ), f"num_channels ({C}) must be equal to number of channels in input ({input.shape[1]})"
 
-    # variance
-    pow_trt = get_trt_tensor(ctx, 2, f"{name}_power", np.float32)
-    pow_var = impl.elementwise.pow(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_pow",
-        sub_trt,
-        pow_trt,
-    )
+    weight_one = get_trt_tensor(ctx, 1.0, f"{name}_weight_one", input.dtype)
+    bias_zero = get_trt_tensor(ctx, 0.0, f"{name}_bias_zero", input.dtype)
 
-    var_trt = impl.reduce.mean(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_mean_var",
-        pow_var,
-        dims,
-        True,
-    )
+    shape = [1, group] + [1] * (rank - 2)
 
-    var_trt = impl.slice.expand(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_expand_var_trt",
-        var_trt,
-        reshaped_input.shape,
+    weight_one = impl.slice.expand(
+        ctx, target, source_ir, f"{name}_expand_weight_one", weight_one, shape
     )
-
-    eps_trt = get_trt_tensor(ctx, eps, f"{name}_eps", np.float32)
-    add_trt = impl.elementwise.add(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_add",
-        var_trt,
-        eps_trt,
+    bias_zero = impl.slice.expand(
+        ctx, target, source_ir, f"{name}_expand_bias_zero", bias_zero, shape
     )
 
-    sqrt_trt = impl.unary.sqrt(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_sqrt",
-        add_trt,
-    )
+    axes = get_axes_for_reduce_op([i for i in range(1 if group == 1 else 2, rank)])
 
-    # y = (X - E[X]) / sqrt((var + eps))
-    output = impl.elementwise.div(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_div",
-        sub_trt,
-        sqrt_trt,
-    )
+    # INormalizationLayer scales the normalized output per-group, but PyTorch scales the normalized output per-channel,
+    # hence causing diverse result. Let TensorRT does no-op for scaling here, and do scaling ourselves later
+    layer = ctx.net.add_normalization(input, weight_one, bias_zero, axes)
+    layer.epsilon = eps
+    layer.num_groups = group
+    set_layer_name(layer, target, name, source_ir)
+    output = layer.get_output(0)
 
-    shape = list(output.shape)
-    for i, s in enumerate(shape):
-        if i == 0 and s > 0:
-            shape[i] = B
-        elif i == 1:
-            shape[i] = C
-        elif i > 1 and s == -1:
-            shape[i] = 0
+    shape[1] = C
 
-    reshaped_output = impl.shuffle.reshape(
-        ctx, target, source_ir, f"{name}_reshape_output", output, shape
-    )
-    reshaped_gamma = impl.shuffle.reshape(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_reshape_gamma",
-        weight,
-        weight_bias_shape,
-    )
-
-    reshaped_output = impl.elementwise.mul(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_mul_gamma",
-        reshaped_output,
-        reshaped_gamma,
-    )
-
-    reshaped_bias = impl.shuffle.reshape(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_reshape_beta",
-        bias,
-        weight_bias_shape,
-    )
-    reshaped_output = impl.elementwise.add(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_add_beta",
-        reshaped_output,
-        reshaped_bias,
-    )
-    if return_mean_rstd:
-        # return fake mean and rstd for now
-        return reshaped_output, None, None
-    return reshaped_output
+    if weight is not None:
+        weight = get_trt_tensor(ctx, weight, f"{name}_weight")
+        weight = cast_trt_tensor(
+            ctx, weight, input.dtype, f"{name}_cast_weight", target, source_ir
+        )
+        weight = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_reshape_weight", weight, shape
+        )
+        output = impl.elementwise.mul(
+            ctx, target, source_ir, f"{name}_mul_weight", output, weight
+        )
 
+    if bias is not None:
+        bias = get_trt_tensor(ctx, bias, f"{name}_bias")
+        bias = cast_trt_tensor(
+            ctx, bias, input.dtype, f"{name}_cast_bias", target, source_ir
+        )
+        bias = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_reshape_bias", bias, shape
+        )
+        output = impl.elementwise.add(
+            ctx, target, source_ir, f"{name}_add_bias", output, bias
+        )
 
-def group_norm(
-    ctx: ConversionContext,
-    target: Target,
-    source_ir: Optional[SourceIR],
-    name: str,
-    input: TRTTensor,
-    num_groups: int,
-    weight: Optional[Union[torch.Tensor, np.ndarray]],
-    bias: Optional[Union[torch.Tensor, np.ndarray]],
-    eps: float,
-    cudnn_enabled: bool,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return native_group_norm(
-        ctx,
-        target,
-        source_ir,
-        name,
-        input,
-        weight,
-        bias,
-        0,
-        0,
-        0,
-        num_groups,
-        eps,
-        return_mean_rstd=False,
-    )
+    # return fake mean and rstd for now
+    return output, None, None
 
 
 def softmax(
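
The comment in the new native_group_norm body notes that INormalizationLayer applies its scale and bias per group, while PyTorch's GroupNorm affine parameters are per channel; the converter therefore feeds the layer a constant scale of 1 and bias of 0, and applies the per-channel weight/bias itself afterwards. The PyTorch reference below is a hedged sketch of the math being reproduced (normalize within each group, then scale and shift per channel); it is not code from the commit.

import torch
import torch.nn.functional as F


def group_norm_reference(x, num_groups, weight, bias, eps=1e-5):
    n, c = x.shape[0], x.shape[1]
    # Normalize over each group of channels (and all spatial positions).
    grouped = x.reshape(n, num_groups, -1)
    mean = grouped.mean(dim=-1, keepdim=True)
    var = grouped.var(dim=-1, unbiased=False, keepdim=True)
    normalized = ((grouped - mean) / torch.sqrt(var + eps)).reshape(x.shape)
    # Scale and shift per channel, matching the converter's post-normalization
    # multiply/add with the reshaped weight and bias.
    affine_shape = (1, c) + (1,) * (x.dim() - 2)
    return normalized * weight.reshape(affine_shape) + bias.reshape(affine_shape)


x = torch.randn(2, 6, 4, 4)
weight, bias = torch.randn(6), torch.randn(6)
ref = F.group_norm(x, 2, weight, bias)
assert torch.allclose(group_norm_reference(x, 2, weight, bias), ref, atol=1e-5)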
