
Commit 57c66bf

Use INormalizationLayer for GroupNorm
1 parent 8e2c82d commit 57c66bf

3 files changed: +73 -369 lines changed
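For context, torch.nn.GroupNorm layers reach the converter touched by this commit when a model is compiled through the Torch-TensorRT dynamo frontend. A minimal sketch of such a run (not part of this commit; the module, shapes, and values are illustrative only):

import torch
import torch_tensorrt

class TinyNet(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # 6 channels split into 3 groups
        self.gn = torch.nn.GroupNorm(num_groups=3, num_channels=6)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.gn(x)

model = TinyNet().eval().cuda()
inputs = [torch.randn(2, 6, 4, 4, device="cuda")]

# Compiling via the dynamo IR should route the group-norm op through the
# converter modified in this commit.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs).shape)  # torch.Size([2, 6, 4, 4])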

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 2 additions & 36 deletions
@@ -175,14 +175,14 @@ def aten_ops_layer_norm(
         0: (TRTTensor,),
     }
 )
-def aten_ops_native_group_norm(
+def aten_ops_group_norm(
     ctx: ConversionContext,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.normalization.native_group_norm(
+    return impl.normalization.group_norm(
         ctx,
         target,
         SourceIR.ATEN,
@@ -198,40 +198,6 @@ def aten_ops_native_group_norm(
     )


-@dynamo_tensorrt_converter(
-    torch.ops.aten.group_norm.default,
-    supports_dynamic_shapes=True,
-)
-@dynamo_tensorrt_converter(
-    torch.ops.aten.group_norm,
-    supports_dynamic_shapes=True,
-)
-@enforce_tensor_types(
-    {
-        0: (TRTTensor,),
-    }
-)
-def aten_ops_group_norm(
-    ctx: ConversionContext,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.normalization.group_norm(
-        ctx,
-        target,
-        SourceIR.ATEN,
-        name,
-        input=args[0],
-        num_groups=args[1],
-        weight=args_bounds_check(args, 2, None),
-        bias=args_bounds_check(args, 3, None),
-        eps=args_bounds_check(args, 4, 1e-05),
-        cudnn_enabled=args_bounds_check(args, 5, True),
-    )
-
-
 @dynamo_tensorrt_converter(torch.ops.aten.cat.default, supports_dynamic_shapes=True)
 def aten_ops_cat(
     ctx: ConversionContext,
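After this change only one group-norm converter remains; judging by the removed registrations for torch.ops.aten.group_norm, the surviving one handles the aten.native_group_norm variant, whose schema carries the flattened sizes and returns an (output, mean, rstd) triple. A quick eager-mode sketch of that op (not part of this commit; values are illustrative):

import torch

x = torch.randn(2, 6, 4, 4)   # N=2, C=6, HxW=16
weight = torch.ones(6)        # per-channel gamma
bias = torch.zeros(6)         # per-channel beta

# Schema (positional): input, weight, bias, N, C, HxW, group, eps
out, mean, rstd = torch.ops.aten.native_group_norm(
    x, weight, bias, 2, 6, 16, 3, 1e-5
)
print(out.shape, mean.shape, rstd.shape)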

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 48 additions & 209 deletions
@@ -1,5 +1,5 @@
 import logging
-from typing import Any, List, Optional, Sequence, Tuple, Union, cast
+from typing import List, Optional, Sequence, Tuple, Union

 import numpy as np
 import tensorrt as trt
@@ -16,7 +16,6 @@
     get_trt_tensor,
     has_dynamic_shape,
     set_layer_name,
-    to_numpy,
 )
 from torch_tensorrt.dynamo.conversion.impl.cat import cat
 from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import ge
@@ -204,240 +203,80 @@ def layer_norm(
     return layer_norm.get_output(0)


-def native_group_norm(
+def group_norm(
     ctx: ConversionContext,
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
-    weight: Optional[Union[torch.Tensor, np.ndarray]],
-    bias: Optional[Union[torch.Tensor, np.ndarray]],
+    weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
     N: int,
     C: int,
     HxW: int,
     group: int,
     eps: float,
-    return_mean_rstd: bool = True,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    # TODO: Ask TRT team about the usage of INormalization Layer usage with num_groups and update the implementation
-    # with INormalization Layer
+) -> Tuple[TRTTensor, torch.Tensor, torch.Tensor]:
     assert (
         len(input.shape) >= 3
-    ), f"The input dimension should not be less than 3, got {len(input.shape)}!"
+    ), f"Expected at least 3 dimensions for input tensor but got {len(input.shape)}"

-    B = input.shape[0]
-    # if C is provided, it must be as same as the channel from the input shape,
-    # else if C is zero, we should get the channel from the input shape
-    if C == 0:
-        C = input.shape[1]
     assert (
         C == input.shape[1]
-    ), f"The number of Channel={C} must be equal to the number of channels in the input shape={input.shape[1]}"
-    # Groups are a subdivision of the channel dimension.
-    assert (
-        C % group == 0
-    ), f"The num of channels ({C}) should be divisible by num_groups ({group})!"
-    input = get_trt_tensor(ctx, input, f"{name}_input")
-
-    shape = list(input.shape)
-
-    for i, s in enumerate(shape):
-        if i == 0 and s > 0:
-            shape[i] = B * group
-        elif i == 1:
-            shape[i] = C // group
-        elif i > 1 and s == -1:
-            shape[i] = 0
-
-    # Normalize every group.
-    reshaped_input = impl.shuffle.reshape(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_reshape_input",
-        input,
-        shape,
-    )
-
-    if weight is None:
-        weight = to_numpy(1.0)
+    ), f"num_channels ({C}) must be equal to number of channels in input ({input.shape[1]})"

-    if bias is None:
-        bias = to_numpy(0.0)
+    weight_one = get_trt_tensor(ctx, 1.0, f"{name}_weight_one", input.dtype)
+    bias_zero = get_trt_tensor(ctx, 0.0, f"{name}_bias_zero", input.dtype)

-    weight = get_trt_tensor(ctx, weight, f"{name}_weight")
-    bias = get_trt_tensor(ctx, bias, f"{name}_bias")
-    weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2)
-
-    dims = list(range(1, len(input.shape)))
-
-    # E[X]
-    mean_trt = impl.reduce.mean(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_mean",
-        reshaped_input,
-        dims,
-        True,
-    )
-
-    mean_trt = impl.slice.expand(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_expand_mean_trt",
-        mean_trt,
-        reshaped_input.shape,
-    )
-
-    # X - E[X]
-    sub_trt = impl.elementwise.sub(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_sub",
-        reshaped_input,
-        mean_trt,
-    )
-
-    # variance
-    pow_trt = get_trt_tensor(ctx, 2, f"{name}_power", np.float32)
-    pow_var = impl.elementwise.pow(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_pow",
-        sub_trt,
-        pow_trt,
-    )
-
-    var_trt = impl.reduce.mean(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_mean_var",
-        pow_var,
-        dims,
-        True,
-    )
-
-    var_trt = impl.slice.expand(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_expand_var_trt",
-        var_trt,
-        reshaped_input.shape,
-    )
-
-    eps_trt = get_trt_tensor(ctx, eps, f"{name}_eps", np.float32)
-    add_trt = impl.elementwise.add(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_add",
-        var_trt,
-        eps_trt,
-    )
+    shape = [1, group] + [1] * (len(input.shape) - 2)

-    sqrt_trt = impl.unary.sqrt(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_sqrt",
-        add_trt,
+    expanded_weight_one = impl.slice.expand(
+        ctx, target, source_ir, f"{name}_expand_weight_one", weight_one, shape
     )
-
-    # y = (X - E[X]) / sqrt((var + eps))
-    output = impl.elementwise.div(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_div",
-        sub_trt,
-        sqrt_trt,
+    expanded_bias_zero = impl.slice.expand(
+        ctx, target, source_ir, f"{name}_expand_bias_zero", bias_zero, shape
     )

-    shape = list(output.shape)
-    for i, s in enumerate(shape):
-        if i == 0 and s > 0:
-            shape[i] = B
-        elif i == 1:
-            shape[i] = C
-        elif i > 1 and s == -1:
-            shape[i] = 0
+    axes = get_axes_for_reduce_op([i for i in range(2, len(input.shape))])

-    reshaped_output = impl.shuffle.reshape(
-        ctx, target, source_ir, f"{name}_reshape_output", output, shape
-    )
-    reshaped_gamma = impl.shuffle.reshape(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_reshape_gamma",
-        weight,
-        weight_bias_shape,
+    # INormalizationLayer scales the normalized output per-group, but PyTorch scales the normalized output per-channel,
+    # hence causing diverse result. Let TensorRT does no-op for scaling here, and do scaling ourselves later
+    layer = ctx.net.add_normalization(
+        input, expanded_weight_one, expanded_bias_zero, axes
     )
+    layer.epsilon = eps
+    layer.num_groups = group
+    set_layer_name(layer, target, name, source_ir)
+    output = layer.get_output(0)

-    reshaped_output = impl.elementwise.mul(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_mul_gamma",
-        reshaped_output,
-        reshaped_gamma,
-    )
+    shape[1] = C

-    reshaped_bias = impl.shuffle.reshape(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_reshape_beta",
-        bias,
-        weight_bias_shape,
-    )
-    reshaped_output = impl.elementwise.add(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_add_beta",
-        reshaped_output,
-        reshaped_bias,
-    )
-    if return_mean_rstd:
-        # return fake mean and rstd for now
-        return reshaped_output, None, None
-    return reshaped_output
+    if weight is not None:
+        weight = get_trt_tensor(ctx, weight, f"{name}_weight")
+        weight = cast_trt_tensor(
+            ctx, weight, input.dtype, f"{name}_cast_weight", target, source_ir
+        )
+        weight = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_reshape_weight", weight, shape
+        )
+        output = impl.elementwise.mul(
+            ctx, target, source_ir, f"{name}_mul_weight", output, weight
+        )

+    if bias is not None:
+        bias = get_trt_tensor(ctx, bias, f"{name}_bias")
+        bias = cast_trt_tensor(
+            ctx, bias, input.dtype, f"{name}_cast_bias", target, source_ir
+        )
+        bias = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_reshape_bias", bias, shape
+        )
+        output = impl.elementwise.add(
+            ctx, target, source_ir, f"{name}_add_bias", output, bias
+        )

-def group_norm(
-    ctx: ConversionContext,
-    target: Target,
-    source_ir: Optional[SourceIR],
-    name: str,
-    input: TRTTensor,
-    num_groups: int,
-    weight: Optional[Union[torch.Tensor, np.ndarray]],
-    bias: Optional[Union[torch.Tensor, np.ndarray]],
-    eps: float,
-    cudnn_enabled: bool,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return native_group_norm(
-        ctx,
-        target,
-        source_ir,
-        name,
-        input,
-        weight,
-        bias,
-        0,
-        0,
-        0,
-        num_groups,
-        eps,
-        return_mean_rstd=False,
-    )
+    # return fake mean and rstd for now
+    return output, None, None


 def softmax(
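
The comment in the new implementation about per-group versus per-channel scaling is the crux of the change: the INormalizationLayer is fed expanded ones and zeros so it only normalizes over each group, and the learned per-channel weight and bias are then applied as explicit elementwise layers. A small eager-mode sketch of that same two-step structure (not part of this commit; names and shapes are illustrative):

import torch
import torch.nn.functional as F

def group_norm_two_step(x, num_groups, weight, bias, eps=1e-5):
    # Step 1: per-group normalization with no affine transform, mirroring the
    # unit scale / zero bias handed to add_normalization() in the converter.
    normalized = F.group_norm(x, num_groups, weight=None, bias=None, eps=eps)
    # Step 2: per-channel scale and shift, mirroring the explicit mul/add
    # layers appended after the normalization layer.
    shape = [1, x.shape[1]] + [1] * (x.dim() - 2)
    return normalized * weight.reshape(shape) + bias.reshape(shape)

x = torch.randn(2, 6, 4, 4)
w, b = torch.randn(6), torch.randn(6)
reference = F.group_norm(x, 3, w, b, eps=1e-5)
print(torch.allclose(reference, group_norm_two_step(x, 3, w, b), atol=1e-5))  # expect True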
