Commit fd820e6

support group norm, and improve batch and layer norms

1 parent 9dc5e5d commit fd820e6

File tree: 4 files changed, +208 −30 lines

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 46 additions & 0 deletions
```diff
@@ -67,6 +67,52 @@ def aten_ops_layer_norm(
     )


+@dynamo_tensorrt_converter(torch.ops.aten.native_group_norm.default)  # type: ignore[misc]
+def aten_ops_native_group_norm(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.normalization.native_group_norm(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        input=args[0],
+        weight=args[1],
+        bias=args[2],
+        N=args[3],
+        C=args[4],
+        HxW=args[5],
+        group=args[6],
+        eps=args[7],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.group_norm.default)  # type: ignore[misc]
+def aten_ops_group_norm(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.normalization.group_norm(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        input=args[0],
+        num_groups=args[1],
+        weight=args_bounds_check(args, 2, None),
+        bias=args_bounds_check(args, 3, None),
+        eps=args_bounds_check(args, 4, 1e-05),
+        cudnn_enabled=args_bounds_check(args, 5, True),
+    )
+
+
 def embedding_param_validator(embedding_node: Node) -> bool:
     scale_grad_by_freq = args_bounds_check(embedding_node.args, 3)
     sparse = args_bounds_check(embedding_node.args, 4)
```
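For orientation, a minimal sketch (not part of the commit) of the positional layout these converters unpack: the ATen schema for `native_group_norm` is `(input, weight, bias, N, C, HxW, group, eps)`, which is why the converter reads `args[0]` through `args[7]`.

```python
# A minimal sketch, assuming an eager PyTorch environment; none of this code
# is in the commit. It calls the ATen op directly with the same positional
# layout the converter unpacks: (input, weight, bias, N, C, HxW, group, eps).
import torch

x = torch.randn(2, 6, 4, 4)  # N=2, C=6, HxW=4*4=16
gn = torch.nn.GroupNorm(num_groups=3, num_channels=6)

out, mean, rstd = torch.ops.aten.native_group_norm(
    x, gn.weight, gn.bias, 2, 6, 16, 3, gn.eps
)
assert torch.allclose(out, gn(x), atol=1e-5)  # matches the module output
```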

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 107 additions & 29 deletions
```diff
@@ -6,13 +6,12 @@
 import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
-from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
-    convert_binary_elementwise,
-)
-from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
+from torch_tensorrt.dynamo.conversion import impl
+from torch_tensorrt.dynamo.conversion.converter_utils import get_axes_for_reduce_op
 from torch_tensorrt.fx.converters.converter_utils import (
     get_positive_dim,
     get_trt_plugin,
+    get_trt_tensor,
     has_dynamic_shape,
     set_layer_name,
     to_numpy,
```
```diff
@@ -188,79 +187,77 @@ def layer_norm_no_plugin(

     shape = weight.shape
     broadcasted_shape = (1,) * (len(input.shape) - len(shape)) + shape
-    gamma = to_numpy(weight.reshape(*shape))
-    beta = to_numpy(bias.reshape(*shape))
+    gamma = to_numpy(weight).reshape(shape)
+    beta = to_numpy(bias).reshape(shape)

-    axes = 0
-    for d in range(len(shape)):
-        axes |= 1 << (len(input.shape) - d - 1)
+    dims = list(range(len(input.shape) - len(shape), len(input.shape)))
+    axes = get_axes_for_reduce_op(dims)

     # E[x]
     mean_expected_layer = network.add_reduce(
         input, trt.ReduceOperation.AVG, axes, keep_dims=True
     )
     set_layer_name(mean_expected_layer, target, f"{name}_mean_expected", source_ir)

-    # X-E[x]
-    sub_trt = convert_binary_elementwise(
+    # X - E[x]
+    sub_trt = impl.elementwise.sub(
         network,
         target,
         source_ir,
         f"{name}_sub",
-        trt.ElementWiseOperation.SUB,
         input,
         mean_expected_layer.get_output(0),
     )
-    # Variance = mean(pow(x_sub_mean,2))
+
+    # variance = mean(pow(x_sub_mean, 2))
     pow_tensor = network.add_constant(
         (1,) * len(input.shape),
         trt.Weights(np.ascontiguousarray([2.0], dtype=np.float32)),
     )
     pow_tensor.name = f"{name}_power"
-    pow_var = convert_binary_elementwise(
+    pow_var = impl.elementwise.pow(
         network,
         target,
         source_ir,
         f"{name}_pow_var",
-        trt.ElementWiseOperation.POW,
         sub_trt,
         pow_tensor.get_output(0),
     )
     mean_trt_layer = network.add_reduce(
         pow_var, trt.ReduceOperation.AVG, axes, keep_dims=True
     )
     set_layer_name(mean_trt_layer, target, f"{name}_mean", source_ir)
-    # Variance + eps
+
+    # var + eps
     eps_tensor = network.add_constant(
         (1,) * len(input.shape),
         trt.Weights(np.ascontiguousarray([eps], dtype=np.float32)),
     )
     eps_tensor.name = f"{name}_eps"
-    add_trt = convert_binary_elementwise(
+
+    # sqrt((var + eps))
+    add_trt = impl.elementwise.add(
         network,
         target,
         source_ir,
         f"{name}_add",
-        trt.ElementWiseOperation.SUM,
         mean_trt_layer.get_output(0),
         eps_tensor.get_output(0),
     )
-    # SQRT((Var + eps))
-    sqrt_trt = convert_unary(
+    sqrt_trt = impl.unary.sqrt(
         network,
         target,
         source_ir,
         f"{name}_sqrt",
-        trt.UnaryOperation.SQRT,
         add_trt,
     )
-    # (x - E[x]) / sqrt((var + eps))
-    div_trt = convert_binary_elementwise(
+
+    # (X - E[X]) / sqrt((var + eps))
+    div_trt = impl.elementwise.div(
         network,
         target,
         source_ir,
         f"{name}_div_trt",
-        trt.ElementWiseOperation.DIV,
         sub_trt,
         sqrt_trt,
     )
```
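The `axes` change above is a readability refactor: TensorRT's `add_reduce` takes the reduction axes as a bitmask, and `get_axes_for_reduce_op` builds the same mask the old bit-shifting loop did. A standalone sketch (the helper body is an assumption, mirroring the documented behavior):

```python
# Assumption: this mirrors what get_axes_for_reduce_op does — TensorRT encodes
# reduce axes as a bitmask with one bit per dimension index.
def axes_bitmask(dims):
    mask = 0
    for d in dims:
        mask |= 1 << d
    return mask

input_rank, weight_rank = 4, 3  # e.g. LayerNorm over the last 3 dims of a 4-D input

# Old loop: set one bit per normalized dimension, counting from the back.
old = 0
for d in range(weight_rank):
    old |= 1 << (input_rank - d - 1)

# New form: enumerate the trailing dims explicitly, then convert to a mask.
dims = list(range(input_rank - weight_rank, input_rank))  # [1, 2, 3]
assert axes_bitmask(dims) == old == 0b1110
```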
```diff
@@ -270,32 +267,113 @@ def layer_norm_no_plugin(
         gamma.shape, trt.Weights(np.ascontiguousarray(gamma))
     )
     gamma_tensor.name = f"{name}_gamma"
+
     assert beta is not None
     beta_tensor = network.add_constant(
         gamma.shape, trt.Weights(np.ascontiguousarray(beta))
     )
     beta_tensor.name = f"{name}_beta"
+
     # y * gamma + beta
-    scale_layer = convert_binary_elementwise(
+    scaled_y = impl.elementwise.mul(
         network,
         target,
         source_ir,
         f"{name}_scale",
-        trt.ElementWiseOperation.PROD,
         div_trt,
         gamma_tensor.get_output(0),
     )
-    return convert_binary_elementwise(
+    return impl.elementwise.add(
         network,
         target,
         source_ir,
         name,
-        trt.ElementWiseOperation.SUM,
-        scale_layer,
+        scaled_y,
         beta_tensor.get_output(0),
     )


+def native_group_norm(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    N: int,
+    C: int,
+    HxW: int,
+    group: int,
+    eps: float,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return group_norm(
+        network,
+        target,
+        source_ir,
+        name,
+        input,
+        group,
+        weight,
+        bias,
+        eps,
+        cudnn_enabled=True,
+    )
+
+
+def group_norm(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    num_groups: int,
+    weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    eps: float,
+    cudnn_enabled: bool,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    if not isinstance(input, trt.tensorrt.ITensor):
+        raise RuntimeError(
+            f"GroupNorm received input {input} that is not part "
+            "of the TensorRT region!"
+        )
+
+    if weight is None:
+        weight = to_numpy(1.0)
+
+    if bias is None:
+        bias = to_numpy(0.0)
+
+    scale = get_trt_tensor(network, weight, "scale")
+    bias = get_trt_tensor(network, bias, "bias")
+
+    eps_field = trt.PluginField(
+        "eps", np.array(eps, dtype=np.float32), trt.PluginFieldType.FLOAT32
+    )
+    num_groups_field = trt.PluginField(
+        "num_groups", np.array(num_groups), trt.PluginFieldType.INT32
+    )
+
+    field_collection = trt.PluginFieldCollection([eps_field, num_groups_field])
+
+    try:
+        # Schema of the plugin:
+        # https://github.com/NVIDIA/TensorRT/blob/release/8.6/plugin/groupNormalizationPlugin/GroupNormalizationPlugin_PluginConfig.yaml
+        plugin = get_trt_plugin("GroupNormalizationPlugin", field_collection, "1")
+    except AssertionError:
+        _LOGGER.error(
+            "Unable to find group norm plugin, falling back to the TensorRT implementation."
+        )
+
+    layer = network.add_plugin_v2([input, scale, bias], plugin)
+    set_layer_name(layer, target, f"{name}_GroupNormalizationPlugin", source_ir)
+
+    # PyTorch requires three return values: (out, mean, rstd)
+    dummy_tensor = torch.tensor(0)
+    return layer.get_output(0), dummy_tensor, dummy_tensor
+
+
 def softmax(
     network: TRTNetwork,
     target: Target,

py/torch_tensorrt/dynamo/conversion/op_evaluators.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -2,6 +2,7 @@
 import operator
 from typing import Dict, Sequence, Tuple, Union

+import torch
 from torch.fx.node import Argument, Node, Target
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

@@ -18,7 +19,8 @@ def getitem_validator(getitem_node: Node) -> bool:


 # TODO: Subsequent evaluators should be registered here with their own validators
-@dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator)
+@dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.detach.default)  # type: ignore[misc]
 def generic_evaluator(
     network: TRTNetwork,
     target: Target,
```
Lines changed: 52 additions & 0 deletions
New file (@@ -0,0 +1,52 @@):

```python
import torch
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase


class TestGroupNormConverter(DispatchTestCase):
    def test_groupnorm(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gn = torch.nn.GroupNorm(2, 6)

            def forward(self, x):
                return self.gn(x)

        inputs = [torch.randn(1, 6, 224, 224)]
        self.run_test(
            TestModule(),
            inputs,
            expected_ops={torch.ops.aten.native_group_norm.default},
            disable_passes=True,
        )

    def test_groupnorm_with_dynamic_shape(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gn = torch.nn.GroupNorm(2, 6)

            def forward(self, x):
                return self.gn(x)

        input_specs = [
            Input(
                shape=(-1, 6, 5),
                dtype=torch.float32,
                shape_ranges=[((2, 6, 5), (6, 6, 5), (10, 6, 5))],
            ),
        ]

        self.run_test_with_dynamic_shape(
            TestModule(),
            input_specs,
            expected_ops={torch.ops.aten.native_group_norm.default},
            disable_passes=True,
        )


if __name__ == "__main__":
    run_tests()
```