Skip to content

Commit 4f585d8

Browse files
committed
rebase and update three norms
1 parent 2da25ad commit 4f585d8

File tree

5 files changed

+451
-347
lines changed

5 files changed

+451
-347
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 33 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -46,31 +46,22 @@ def get_ir(target: Target) -> SourceIR:
4646
return SourceIR.UNKNOWN
4747

4848

49-
@dynamo_tensorrt_converter(torch.ops.aten.native_batch_norm.default) # type: ignore[misc]
50-
def aten_ops_native_batch_norm(
51-
ctx: ConversionContext,
52-
target: Target,
53-
args: Tuple[Argument, ...],
54-
kwargs: Dict[str, Argument],
55-
name: str,
56-
) -> Union[TRTTensor, Sequence[TRTTensor]]:
57-
return impl.normalization.native_batch_norm(
58-
ctx,
59-
target,
60-
SourceIR.ATEN,
61-
name,
62-
input=args[0],
63-
weight=args[1],
64-
bias=args[2],
65-
running_mean=args[3],
66-
running_var=args[4],
67-
training=args[5],
68-
momentum=args[6],
69-
eps=args[7],
49+
def one_user_validator(node: Node) -> bool:
50+
# Validate only one user, which is a getitem node that accesses the first element in the list
51+
return (
52+
len(node.users) == 1
53+
and list(node.users)[0].target == operator.getitem
54+
and list(node.users)[0].args[1] == 0
7055
)
7156

7257

73-
@dynamo_tensorrt_converter(torch.ops.aten.batch_norm) # type: ignore[misc]
58+
@dynamo_tensorrt_converter(torch.ops.aten.native_batch_norm.default, capability_validator=one_user_validator) # type: ignore[misc]
59+
@dynamo_tensorrt_converter(torch.ops.aten.batch_norm.default) # type: ignore[misc]
60+
@enforce_tensor_types(
61+
{
62+
0: (TRTTensor,),
63+
}
64+
) # type: ignore[misc]
7465
def aten_ops_batch_norm(
7566
ctx: ConversionContext,
7667
target: Target,
@@ -91,32 +82,18 @@ def aten_ops_batch_norm(
9182
training=args[5],
9283
momentum=args[6],
9384
eps=args[7],
94-
cudnn_enabled=args[8],
95-
)
96-
97-
98-
@dynamo_tensorrt_converter(torch.ops.aten.native_layer_norm.default) # type: ignore[misc]
99-
def aten_ops_native_layer_norm(
100-
ctx: ConversionContext,
101-
target: Target,
102-
args: Tuple[Argument, ...],
103-
kwargs: Dict[str, Argument],
104-
name: str,
105-
) -> Union[TRTTensor, Sequence[TRTTensor]]:
106-
return impl.normalization.native_layer_norm(
107-
ctx,
108-
target,
109-
SourceIR.ATEN,
110-
name,
111-
input=args[0],
112-
normalized_shape=args[1],
113-
weight=args[2],
114-
bias=args[3],
115-
eps=args[4],
85+
cudnn_enabled=args_bounds_check(args, 8, True),
86+
return_mean_rstd=(target == torch.ops.aten.native_batch_norm.default),
11687
)
11788

11889

90+
@dynamo_tensorrt_converter(torch.ops.aten.native_layer_norm.default, capability_validator=one_user_validator) # type: ignore[misc]
11991
@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default) # type: ignore[misc]
92+
@enforce_tensor_types(
93+
{
94+
0: (TRTTensor,),
95+
}
96+
) # type: ignore[misc]
12097
def aten_ops_layer_norm(
12198
ctx: ConversionContext,
12299
target: Target,
@@ -135,10 +112,16 @@ def aten_ops_layer_norm(
135112
bias=args_bounds_check(args, 3),
136113
eps=args_bounds_check(args, 4, 1e-05),
137114
cudnn_enable=args_bounds_check(args, 5, True),
115+
return_mean_rstd=(target == torch.ops.aten.native_layer_norm.default),
138116
)
139117

140118

141-
@dynamo_tensorrt_converter(torch.ops.aten.native_group_norm.default) # type: ignore[misc]
119+
@dynamo_tensorrt_converter(torch.ops.aten.native_group_norm.default, capability_validator=one_user_validator) # type: ignore[misc]
120+
@enforce_tensor_types(
121+
{
122+
0: (TRTTensor,),
123+
}
124+
) # type: ignore[misc]
142125
def aten_ops_native_group_norm(
143126
ctx: ConversionContext,
144127
target: Target,
@@ -163,6 +146,11 @@ def aten_ops_native_group_norm(
163146

164147

165148
@dynamo_tensorrt_converter(torch.ops.aten.group_norm.default) # type: ignore[misc]
149+
@enforce_tensor_types(
150+
{
151+
0: (TRTTensor,),
152+
}
153+
) # type: ignore[misc]
166154
def aten_ops_group_norm(
167155
ctx: ConversionContext,
168156
target: Target,
@@ -838,15 +826,6 @@ def aten_ops_prod(
838826
)
839827

840828

841-
def one_user_validator(node: Node) -> bool:
842-
# Validate only one user, which is a getitem node that accesses the first element in the list
843-
return (
844-
len(node.users) == 1
845-
and list(node.users)[0].target == operator.getitem
846-
and list(node.users)[0].args[1] == 0
847-
)
848-
849-
850829
@dynamo_tensorrt_converter(torch.ops.aten.max.default) # type: ignore[misc]
851830
@dynamo_tensorrt_converter(torch.ops.aten.max.dim, capability_validator=one_user_validator) # type: ignore[misc]
852831
def aten_ops_max(

0 commit comments

Comments
 (0)