Skip to content

Commit 1c24432

Browse files
authored
feat: support group_norm, batch_norm, and layer_norm (#2330)
1 parent cb2aee0 commit 1c24432

File tree

8 files changed

+731
-293
lines changed

8 files changed

+731
-293
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 113 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,23 @@ def get_ir(target: Target) -> SourceIR:
4646
return SourceIR.UNKNOWN
4747

4848

49+
def one_user_validator(node: Node) -> bool:
50+
# Validate only one user, which is a getitem node that accesses the first element in the list
51+
return (
52+
len(node.users) == 1
53+
and list(node.users)[0].target == operator.getitem
54+
and list(node.users)[0].args[1] == 0
55+
)
56+
57+
58+
@dynamo_tensorrt_converter(torch.ops.aten.native_batch_norm.default, capability_validator=one_user_validator)  # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.batch_norm.default)  # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.batch_norm)  # type: ignore[misc]
@enforce_tensor_types(
    {
        0: (TRTTensor,),
    }
)  # type: ignore[misc]
def aten_ops_batch_norm(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.batch_norm`` / ``aten.native_batch_norm`` to TensorRT.

    ``native_batch_norm`` is accepted only via ``one_user_validator`` (its
    sole consumer must be ``getitem(node, 0)``), since only the normalized
    output is materialized by the TRT implementation.
    """
    # NOTE(review): the callee line sits just outside the visible diff hunk;
    # `impl.normalization.batch_norm` is reconstructed from the function name
    # and the sibling converters below — confirm against the full file.
    return impl.normalization.batch_norm(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        input=args[0],
        weight=args[1],
        bias=args[2],
        running_mean=args[3],
        running_var=args[4],
        training=args[5],
        momentum=args[6],
        eps=args[7],
        # aten.batch_norm carries cudnn_enabled as args[8];
        # native_batch_norm has only 8 positional args, hence the default.
        cudnn_enabled=args_bounds_check(args, 8, True),
        # native_batch_norm additionally returns (mean, rstd).
        return_mean_rstd=(target == torch.ops.aten.native_batch_norm.default),
    )
@dynamo_tensorrt_converter(torch.ops.aten.native_layer_norm.default, capability_validator=one_user_validator)  # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default)  # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.layer_norm)  # type: ignore[misc]
@enforce_tensor_types(
    {
        0: (TRTTensor,),
    }
)  # type: ignore[misc]
def aten_ops_layer_norm(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.layer_norm`` / ``aten.native_layer_norm`` to TensorRT.

    ``native_layer_norm`` is gated by ``one_user_validator`` so that only
    graphs consuming the first (normalized) output are converted.
    """
    return impl.normalization.layer_norm(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        input=args[0],
        normalized_shape=args[1],
        # weight/bias are optional trailing positionals in aten.layer_norm.
        weight=args_bounds_check(args, 2),
        bias=args_bounds_check(args, 3),
        eps=args_bounds_check(args, 4, 1e-05),
        cudnn_enable=args_bounds_check(args, 5, True),
        # native_layer_norm additionally returns (mean, rstd).
        return_mean_rstd=(target == torch.ops.aten.native_layer_norm.default),
    )
@dynamo_tensorrt_converter(torch.ops.aten.native_group_norm.default, capability_validator=one_user_validator)  # type: ignore[misc]
@enforce_tensor_types(
    {
        0: (TRTTensor,),
    }
)  # type: ignore[misc]
def aten_ops_native_group_norm(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.native_group_norm`` to TensorRT.

    Gated by ``one_user_validator``: the op's only consumer must be
    ``getitem(node, 0)``, since only the normalized tensor is produced.
    """
    # native_group_norm's positional schema:
    # (input, weight, bias, N, C, HxW, group, eps)
    return impl.normalization.native_group_norm(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        input=args[0],
        weight=args[1],
        bias=args[2],
        N=args[3],
        C=args[4],
        HxW=args[5],
        group=args[6],
        eps=args[7],
    )
@dynamo_tensorrt_converter(torch.ops.aten.group_norm.default)  # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.group_norm)  # type: ignore[misc]
@enforce_tensor_types(
    {
        0: (TRTTensor,),
    }
)  # type: ignore[misc]
def aten_ops_group_norm(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.group_norm`` to TensorRT.

    Unlike ``native_group_norm``, this variant is single-output, so no
    capability validator is needed.
    """
    return impl.normalization.group_norm(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        input=args[0],
        num_groups=args[1],
        # weight/bias/eps/cudnn_enabled are optional trailing positionals.
        weight=args_bounds_check(args, 2, None),
        bias=args_bounds_check(args, 3, None),
        eps=args_bounds_check(args, 4, 1e-05),
        cudnn_enabled=args_bounds_check(args, 5, True),
    )
71176

72177

@@ -328,27 +433,6 @@ def aten_ops_matmul(
328433
)
329434

330435

331-
@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default) # type: ignore[misc]
332-
def aten_ops_layernorm(
333-
ctx: ConversionContext,
334-
target: Target,
335-
args: Tuple[Argument, ...],
336-
kwargs: Dict[str, Argument],
337-
name: str,
338-
) -> Union[TRTTensor, Sequence[TRTTensor]]:
339-
return impl.normalization.layer_norm(
340-
ctx,
341-
target,
342-
SourceIR.ATEN,
343-
name,
344-
args[0],
345-
args[1],
346-
args[2],
347-
args[3],
348-
args[4],
349-
)
350-
351-
352436
@dynamo_tensorrt_converter(torch.ops.aten.rsqrt.default) # type: ignore[misc]
353437
def aten_ops_rsqrt(
354438
ctx: ConversionContext,
@@ -763,15 +847,6 @@ def aten_ops_prod(
763847
)
764848

765849

766-
def one_user_validator(node: Node) -> bool:
767-
# Validate only one user, which is a getitem node that accesses the first element in the list
768-
return (
769-
len(node.users) == 1
770-
and list(node.users)[0].target == operator.getitem
771-
and list(node.users)[0].args[1] == 0
772-
)
773-
774-
775850
@dynamo_tensorrt_converter(torch.ops.aten.max.default) # type: ignore[misc]
776851
@dynamo_tensorrt_converter(torch.ops.aten.max.dim, capability_validator=one_user_validator) # type: ignore[misc]
777852
def aten_ops_max(

0 commit comments

Comments
 (0)