Skip to content

Commit dce9526

Browse files
committed
update batch_norm and layer_norm
1 parent 8ebb599 commit dce9526

File tree

3 files changed

+38
-34
lines changed

3 files changed

+38
-34
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 31 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -31,14 +31,37 @@ def aten_ops_batch_norm(
3131
target,
3232
SourceIR.ATEN,
3333
name,
34-
args[0],
35-
args[1],
36-
args[2],
37-
args[3],
38-
args[4],
39-
args[5],
40-
args[6],
41-
args[7],
34+
input=args[0],
35+
weight=args_bounds_check(args, 1, replacement=1),
36+
bias=args_bounds_check(args, 2, replacement=0),
37+
running_mean=args_bounds_check(args, 3),
38+
running_var=args_bounds_check(args, 4),
39+
training=args_bounds_check(args, 5),
40+
momentum=args_bounds_check(args, 6, replacement=0.1),
41+
eps=args_bounds_check(args, 7, replacement=1e-05),
42+
cudnn_enabled=args_bounds_check(args, 8, replacement=False),
43+
)
44+
45+
46+
@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default) # type: ignore[misc]
47+
def aten_ops_layer_norm(
48+
network: TRTNetwork,
49+
target: Target,
50+
args: Tuple[Argument, ...],
51+
kwargs: Dict[str, Argument],
52+
name: str,
53+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
54+
return impl.normalization.layer_norm(
55+
network,
56+
target,
57+
SourceIR.ATEN,
58+
name,
59+
input=args[0],
60+
normalized_shape=args[1],
61+
weight=args_bounds_check(args, 2, replacement=1),
62+
bias=args_bounds_check(args, 3, replacement=0),
63+
eps=args_bounds_check(args, 4, replacement=1e-05),
64+
cudnn_enable=args_bounds_check(args, 5, replacement=True),
4265
)
4366

4467

@@ -258,27 +281,6 @@ def aten_ops_matmul(
258281
)
259282

260283

261-
@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default) # type: ignore[misc]
262-
def aten_ops_layernorm(
263-
network: TRTNetwork,
264-
target: Target,
265-
args: Tuple[Argument, ...],
266-
kwargs: Dict[str, Argument],
267-
name: str,
268-
) -> Union[TRTTensor, Sequence[TRTTensor]]:
269-
return impl.normalization.layer_norm(
270-
network,
271-
target,
272-
SourceIR.ATEN,
273-
name,
274-
args[0],
275-
args[1],
276-
args[2],
277-
args[3],
278-
args[4],
279-
)
280-
281-
282284
@dynamo_tensorrt_converter(torch.ops.aten.rsqrt.default) # type: ignore[misc]
283285
def aten_ops_rsqrt(
284286
network: TRTNetwork,

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@ def batch_norm(
3636
training: torch.Tensor,
3737
momentum: torch.Tensor,
3838
eps: List[float],
39+
cudnn_enabled: bool,
3940
) -> Union[TRTTensor, Sequence[TRTTensor]]:
4041
if not isinstance(input, TRTTensor):
4142
raise RuntimeError(
@@ -69,16 +70,16 @@ def batch_norm(
6970
input.shape[2],
7071
1,
7172
)
72-
set_layer_name(reshape_layer, target, f"{name}_reshape_2d")
73+
set_layer_name(reshape_layer, target, f"{name}_reshape_2d", source_ir)
7374
input = reshape_layer.get_output(0)
7475
layer = network.add_scale(input, trt.ScaleMode.CHANNEL, bias, scale, power)
75-
set_layer_name(layer, target, name)
76+
set_layer_name(layer, target, name, source_ir)
7677

7778
# For BatchNorm1d, reshape output back to 1d
7879
if not network.has_implicit_batch_dimension and len(output_shape) < 4:
7980
reshape_output_layer = network.add_shuffle(layer.get_output(0))
8081
reshape_output_layer.reshape_dims = tuple(output_shape)
81-
set_layer_name(reshape_output_layer, target, f"{name}_reshape_1d")
82+
set_layer_name(reshape_output_layer, target, f"{name}_reshape_1d", source_ir)
8283
layer = reshape_output_layer
8384
return layer.get_output(0)
8485

@@ -93,6 +94,7 @@ def layer_norm(
9394
weight: torch.Tensor,
9495
bias: torch.Tensor,
9596
eps: List[float],
97+
cudnn_enable: bool,
9698
) -> Union[TRTTensor, Sequence[TRTTensor]]:
9799
if not isinstance(input, trt.tensorrt.ITensor):
98100
raise RuntimeError(
@@ -173,7 +175,7 @@ def layer_norm_no_plugin(
173175
mean_expected_layer = network.add_reduce(
174176
input, trt.ReduceOperation.AVG, axes, keep_dims=True
175177
)
176-
set_layer_name(mean_expected_layer, target, f"{name}_mean_expected")
178+
set_layer_name(mean_expected_layer, target, f"{name}_mean_expected", source_ir)
177179

178180
# X-E[x]
179181
sub_trt = convert_binary_elementwise(
@@ -203,7 +205,7 @@ def layer_norm_no_plugin(
203205
mean_trt_layer = network.add_reduce(
204206
pow_var, trt.ReduceOperation.AVG, axes, keep_dims=True
205207
)
206-
set_layer_name(mean_trt_layer, target, f"{name}_mean")
208+
set_layer_name(mean_trt_layer, target, f"{name}_mean", source_ir)
207209
# Variance + eps
208210
eps_tensor = network.add_constant(
209211
(1,) * len(input.shape),

0 commit comments

Comments (0)