Skip to content

Commit 426c2f4

Browse files
committed
Fix a type bug.
Add group-norm support, and improve the batch-norm and layer-norm converters.
1 parent cc967cf commit 426c2f4

File tree

4 files changed

+290
-45
lines changed

4 files changed

+290
-45
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 94 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,29 @@ def get_ir(target: Target) -> SourceIR:
4747

4848

4949
@dynamo_tensorrt_converter(torch.ops.aten.native_batch_norm.default)  # type: ignore[misc]
def aten_ops_native_batch_norm(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.native_batch_norm`` via the TensorRT normalization impl.

    Every argument in this aten overload is required and positional, so each
    one is indexed directly and forwarded to the impl by its schema name.
    """
    # Schema order: (input, weight, bias, running_mean, running_var,
    #                training, momentum, eps)
    forwarded = {
        "input": args[0],
        "weight": args[1],
        "bias": args[2],
        "running_mean": args[3],
        "running_var": args[4],
        "training": args[5],
        "momentum": args[6],
        "eps": args[7],
    }
    return impl.normalization.native_batch_norm(
        ctx, target, SourceIR.ATEN, name, **forwarded
    )
5073
@dynamo_tensorrt_converter(torch.ops.aten.batch_norm) # type: ignore[misc]
5174
def aten_ops_batch_norm(
5275
ctx: ConversionContext,
@@ -68,20 +91,19 @@ def aten_ops_batch_norm(
6891
training=args[5],
6992
momentum=args[6],
7093
eps=args[7],
71-
cudnn_enabled=args_bounds_check(args, 8, replacement=True),
94+
cudnn_enabled=args[8],
7295
)
7396

7497

7598
@dynamo_tensorrt_converter(torch.ops.aten.native_layer_norm.default) # type: ignore[misc]
76-
@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default) # type: ignore[misc]
77-
def aten_ops_layer_norm(
99+
def aten_ops_native_layer_norm(
78100
ctx: ConversionContext,
79101
target: Target,
80102
args: Tuple[Argument, ...],
81103
kwargs: Dict[str, Argument],
82104
name: str,
83105
) -> Union[TRTTensor, Sequence[TRTTensor]]:
84-
return impl.normalization.layer_norm(
106+
return impl.normalization.native_layer_norm(
85107
ctx,
86108
target,
87109
SourceIR.ATEN,
@@ -91,7 +113,74 @@ def aten_ops_layer_norm(
91113
weight=args[2],
92114
bias=args[3],
93115
eps=args[4],
94-
cudnn_enable=args_bounds_check(args, 5, replacement=True),
116+
)
117+
118+
119+
@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default)  # type: ignore[misc]
def aten_ops_layer_norm(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.layer_norm`` via the TensorRT normalization impl.

    ``input`` and ``normalized_shape`` are required; the trailing schema
    arguments are optional, so they are read with bounds-checked lookups
    that supply the schema defaults when absent.
    """
    optional = {
        "weight": args_bounds_check(args, 2),
        "bias": args_bounds_check(args, 3),
        "eps": args_bounds_check(args, 4, 1e-05),
        "cudnn_enable": args_bounds_check(args, 5, True),
    }
    return impl.normalization.layer_norm(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        input=args[0],
        normalized_shape=args[1],
        **optional,
    )
140+
141+
@dynamo_tensorrt_converter(torch.ops.aten.native_group_norm.default)  # type: ignore[misc]
def aten_ops_native_group_norm(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.native_group_norm`` via the TensorRT normalization impl.

    All arguments in this aten overload are required and positional, so each
    is indexed directly.
    """
    # Schema order: (input, weight, bias, N, C, HxW, group, eps)
    keyword_args = dict(
        input=args[0],
        weight=args[1],
        bias=args[2],
        N=args[3],
        C=args[4],
        HxW=args[5],
        group=args[6],
        eps=args[7],
    )
    return impl.normalization.native_group_norm(
        ctx, target, SourceIR.ATEN, name, **keyword_args
    )
164+
165+
@dynamo_tensorrt_converter(torch.ops.aten.group_norm.default)  # type: ignore[misc]
def aten_ops_group_norm(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.group_norm`` via the TensorRT normalization impl.

    ``input`` and ``num_groups`` are required; the remaining schema
    arguments have defaults, so they are read with bounds-checked lookups.
    """
    trailing = (
        ("weight", 2, None),
        ("bias", 3, None),
        ("eps", 4, 1e-05),
        ("cudnn_enabled", 5, True),
    )
    keyword_args = {
        key: args_bounds_check(args, index, default)
        for key, index, default in trailing
    }
    return impl.normalization.group_norm(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        input=args[0],
        num_groups=args[1],
        **keyword_args,
    )
96185

97186

0 commit comments

Comments
 (0)