Skip to content

Commit 903802b

Browse files
committed
update batch_norm and layer_norm
1 parent 80bbd8b commit 903802b

File tree

3 files changed: +37 additions, −33 deletions

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,37 @@ def aten_ops_batch_norm(
5959
target,
6060
SourceIR.ATEN,
6161
name,
62-
args[0],
63-
args[1],
64-
args[2],
65-
args[3],
66-
args[4],
67-
args[5],
68-
args[6],
69-
args[7],
62+
input=args[0],
63+
weight=args_bounds_check(args, 1, replacement=1),
64+
bias=args_bounds_check(args, 2, replacement=0),
65+
running_mean=args_bounds_check(args, 3),
66+
running_var=args_bounds_check(args, 4),
67+
training=args_bounds_check(args, 5),
68+
momentum=args_bounds_check(args, 6, replacement=0.1),
69+
eps=args_bounds_check(args, 7, replacement=1e-05),
70+
cudnn_enabled=args_bounds_check(args, 8, replacement=False),
71+
)
72+
73+
74+
@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default) # type: ignore[misc]
75+
def aten_ops_layer_norm(
76+
ctx: ConversionContext,
77+
target: Target,
78+
args: Tuple[Argument, ...],
79+
kwargs: Dict[str, Argument],
80+
name: str,
81+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
82+
return impl.normalization.layer_norm(
83+
ctx,
84+
target,
85+
SourceIR.ATEN,
86+
name,
87+
input=args[0],
88+
normalized_shape=args[1],
89+
weight=args_bounds_check(args, 2, replacement=1),
90+
bias=args_bounds_check(args, 3, replacement=0),
91+
eps=args_bounds_check(args, 4, replacement=1e-05),
92+
cudnn_enable=args_bounds_check(args, 5, replacement=True),
7093
)
7194

7295

@@ -310,27 +333,6 @@ def aten_ops_matmul(
310333
)
311334

312335

313-
@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default) # type: ignore[misc]
314-
def aten_ops_layernorm(
315-
ctx: ConversionContext,
316-
target: Target,
317-
args: Tuple[Argument, ...],
318-
kwargs: Dict[str, Argument],
319-
name: str,
320-
) -> Union[TRTTensor, Sequence[TRTTensor]]:
321-
return impl.normalization.layer_norm(
322-
ctx,
323-
target,
324-
SourceIR.ATEN,
325-
name,
326-
args[0],
327-
args[1],
328-
args[2],
329-
args[3],
330-
args[4],
331-
)
332-
333-
334336
@dynamo_tensorrt_converter(torch.ops.aten.rsqrt.default) # type: ignore[misc]
335337
def aten_ops_rsqrt(
336338
ctx: ConversionContext,

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def batch_norm(
3636
training: torch.Tensor,
3737
momentum: torch.Tensor,
3838
eps: List[float],
39+
cudnn_enabled: bool,
3940
) -> Union[TRTTensor, Sequence[TRTTensor]]:
4041
if not isinstance(input, TRTTensor):
4142
raise RuntimeError(
@@ -69,7 +70,7 @@ def batch_norm(
6970
input.shape[2],
7071
1,
7172
)
72-
set_layer_name(reshape_layer, target, f"{name}_reshape_2d")
73+
set_layer_name(reshape_layer, target, f"{name}_reshape_2d", source_ir)
7374
input = reshape_layer.get_output(0)
7475
layer = ctx.net.add_scale(input, trt.ScaleMode.CHANNEL, bias, scale, power)
7576
set_layer_name(layer, target, name)
@@ -78,7 +79,7 @@ def batch_norm(
7879
if not ctx.net.has_implicit_batch_dimension and len(output_shape) < 4:
7980
reshape_output_layer = ctx.net.add_shuffle(layer.get_output(0))
8081
reshape_output_layer.reshape_dims = tuple(output_shape)
81-
set_layer_name(reshape_output_layer, target, f"{name}_reshape_1d")
82+
set_layer_name(reshape_output_layer, target, f"{name}_reshape_1d", source_ir)
8283
layer = reshape_output_layer
8384
return layer.get_output(0)
8485

@@ -93,6 +94,7 @@ def layer_norm(
9394
weight: torch.Tensor,
9495
bias: torch.Tensor,
9596
eps: List[float],
97+
cudnn_enable: bool,
9698
) -> Union[TRTTensor, Sequence[TRTTensor]]:
9799
if not isinstance(input, trt.tensorrt.ITensor):
98100
raise RuntimeError(
@@ -173,7 +175,7 @@ def layer_norm_no_plugin(
173175
mean_expected_layer = ctx.net.add_reduce(
174176
input, trt.ReduceOperation.AVG, axes, keep_dims=True
175177
)
176-
set_layer_name(mean_expected_layer, target, f"{name}_mean_expected")
178+
set_layer_name(mean_expected_layer, target, f"{name}_mean_expected", source_ir)
177179

178180
# X-E[x]
179181
sub_trt = convert_binary_elementwise(
@@ -203,7 +205,7 @@ def layer_norm_no_plugin(
203205
mean_trt_layer = ctx.net.add_reduce(
204206
pow_var, trt.ReduceOperation.AVG, axes, keep_dims=True
205207
)
206-
set_layer_name(mean_trt_layer, target, f"{name}_mean")
208+
set_layer_name(mean_trt_layer, target, f"{name}_mean", source_ir)
207209
# Variance + eps
208210
eps_tensor = ctx.net.add_constant(
209211
(1,) * len(input.shape),

0 commit comments

Comments (0)