
Commit 1ca262f

Add dynamic shape support for layer_norm/native_group_norm/group_norm (#2908)
1 parent: d9b2840 · commit: 1ca262f

File tree

4 files changed (+160, -139 lines)

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 14 additions & 4 deletions
@@ -165,7 +165,9 @@ def aten_ops_layer_norm(


 @dynamo_tensorrt_converter(
-    torch.ops.aten.native_group_norm.default, capability_validator=one_user_validator
+    torch.ops.aten.native_group_norm.default,
+    capability_validator=one_user_validator,
+    supports_dynamic_shapes=True,
 )
 @enforce_tensor_types(
     {
@@ -195,8 +197,16 @@ def aten_ops_native_group_norm(
 )


-@dynamo_tensorrt_converter(torch.ops.aten.group_norm.default)
-@dynamo_tensorrt_converter(torch.ops.aten.group_norm)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.group_norm.default,
+    capability_validator=one_user_validator,
+    supports_dynamic_shapes=True,
+)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.group_norm,
+    capability_validator=one_user_validator,
+    supports_dynamic_shapes=True,
+)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
@@ -581,7 +591,7 @@ def aten_ops_neg(


 try:
-    import modelopt.torch.quantization as mtq
+    import modelopt.torch.quantization as mtq  # noqa: F401

     assert torch.ops.trt.quantize_fp8.default
 except Exception as e:
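
Note: the converter bodies are untouched in this file; the new supports_dynamic_shapes=True flag tells the Dynamo converter registry that these converters may be selected when an input dimension is dynamic, so the ops no longer fall back outside TensorRT under a dynamic profile. The capability_validator=one_user_validator (already used for native_group_norm) restricts conversion to nodes whose result has a single consumer, consistent with the converter returning placeholder mean/rstd. A minimal sketch of the same registration pattern, assuming the import path this file uses; the converter body is a hypothetical stub, not the real implementation:

import torch
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    dynamo_tensorrt_converter,
)


@dynamo_tensorrt_converter(
    torch.ops.aten.group_norm.default,  # ATen overload this converter handles
    supports_dynamic_shapes=True,  # advertise dynamic-shape capability
)
def my_group_norm_converter(ctx, target, args, kwargs, name):
    # A real converter builds TensorRT layers through ctx.net and returns
    # the resulting ITensor(s); this stub only illustrates registration.
    ...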

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 55 additions & 133 deletions
@@ -111,7 +111,6 @@ def layer_norm(
 ) -> Union[TRTTensor, Tuple[TRTTensor, torch.Tensor, torch.Tensor]]:
     dims = list(range(len(input.shape) - len(normalized_shape), len(input.shape)))
     axes = get_axes_for_reduce_op(dims)
-
     weight = get_trt_tensor(ctx, weight, f"{name}_weight")
     bias = get_trt_tensor(ctx, bias, f"{name}_bias")
     if tuple(input.shape) != tuple(weight.shape):
@@ -153,157 +152,80 @@ def native_group_norm(
     assert (
         len(input.shape) >= 3
     ), f"The input dimension should not be less than 3, got {len(input.shape)}!"
-    B, C = input.shape[0], input.shape[1]

+    B = input.shape[0]
+    # if C is provided, it must be as same as the channel from the input shape,
+    # else if C is zero, we should get the channel from the input shape
+    if C == 0:
+        C = input.shape[1]
+    assert (
+        C == input.shape[1]
+    ), f"The number of Channel={C} must be equal to the number of channels in the input shape={input.shape[1]}"
     # Groups are a subdivision of the channel dimension.
     assert (
         C % group == 0
     ), f"The num of channels ({C}) should be divisible by num_groups ({group})!"
+    input = get_trt_tensor(ctx, input, f"{name}_input")

-    if weight is None:
-        weight = to_numpy(1.0)
+    shape = list(input.shape)

-    if bias is None:
-        bias = to_numpy(0.0)
+    for i, s in enumerate(shape):
+        if i == 0 and s > 0:
+            shape[i] = B * group
+        elif i == 1:
+            shape[i] = C // group
+        elif i > 1 and s == -1:
+            shape[i] = 0

     # Normalize every group.
     reshaped_input = impl.shuffle.reshape(
         ctx,
         target,
         source_ir,
-        name,
+        f"{name}_reshape_input",
         input,
-        (B * group, -1),
-    )
-
-    dim = 1
-
-    # E[X]
-    mean_trt = impl.reduce.mean(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_mean",
-        reshaped_input,
-        dim,
-        True,
-    )
-
-    # X - E[X]
-    sub_trt = impl.elementwise.sub(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_sub",
-        reshaped_input,
-        mean_trt,
-    )
-
-    # variance = mean(pow(sub_trt, 2))
-    pow_trt = get_trt_tensor(ctx, 2, f"{name}_power", np.float32)
-    pow_var = impl.elementwise.pow(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_pow",
-        sub_trt,
-        pow_trt,
-    )
-
-    var_trt = impl.reduce.mean(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_mean_var",
-        pow_var,
-        dim,
-        True,
-    )
-
-    # sqrt((var + eps))
-    eps_trt = get_trt_tensor(ctx, eps, f"{name}_eps", np.float32)
-    add_trt = impl.elementwise.add(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_add",
-        var_trt,
-        eps_trt,
-    )
-    sqrt_trt = impl.unary.sqrt(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_sqrt",
-        add_trt,
-    )
-
-    # y = (X - E[X]) / sqrt((var + eps))
-    div_trt = impl.elementwise.div(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_div",
-        sub_trt,
-        sqrt_trt,
-    )
-
-    # y * gamma + beta
-    gamma_trt = get_trt_tensor(ctx, weight, f"{name}_gamma")
-    beta_trt = get_trt_tensor(ctx, bias, f"{name}_beta")
-
-    output = impl.shuffle.reshape(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_reshape_div",
-        div_trt,
-        input.shape,
-    )
-
-    weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2)
-
-    reshaped_gamma = impl.shuffle.reshape(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_reshape_gamma",
-        gamma_trt,
-        weight_bias_shape,
-    )
-
-    output = impl.elementwise.mul(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_mul_gamma",
-        output,
-        reshaped_gamma,
-    )
-
-    reshaped_bias = impl.shuffle.reshape(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_reshape_beta",
-        beta_trt,
-        weight_bias_shape,
+        shape,
     )

-    output = impl.elementwise.add(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_add_beta",
-        output,
-        reshaped_bias,
+    weight = get_trt_tensor(ctx, weight, f"{name}_weight")
+    bias = get_trt_tensor(ctx, bias, f"{name}_bias")
+    if tuple(reshaped_input.shape) != tuple(weight.shape):
+        weight = impl.slice.expand(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_expand_weight",
+            weight,
+            reshaped_input.shape,
+        )
+    if tuple(reshaped_input.shape) != tuple(bias.shape):
+        bias = impl.slice.expand(
+            ctx, target, source_ir, f"{name}_expand_bias", bias, reshaped_input.shape
+        )
+    dims = list(range(1, len(input.shape)))
+    axes = get_axes_for_reduce_op(dims)
+    group_norm = ctx.net.add_normalization(reshaped_input, weight, bias, axes)
+    group_norm.epsilon = eps
+    group_norm.compute_precision = input.dtype
+    set_layer_name(group_norm, target, f"{name}_group_norm", source_ir)
+    output = group_norm.get_output(0)
+
+    shape = list(output.shape)
+    for i, s in enumerate(shape):
+        if i == 0 and s > 0:
+            shape[i] = B
+        elif i == 1:
+            shape[i] = C
+        elif i > 1 and s == -1:
+            shape[i] = 0
+
+    reshaped_output = impl.shuffle.reshape(
+        ctx, target, source_ir, f"{name}_reshape_output", output, shape
     )
-
     if return_mean_rstd:
         # return fake mean and rstd for now
-        return output, None, None
-
-    return output
+        return reshaped_output, None, None
+    return reshaped_output


 def group_norm(
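
The rewritten converter drops the long chain of explicit mean/sub/pow/mean/add/sqrt/div layers and instead leans on TensorRT's INormalizationLayer via ctx.net.add_normalization: the input is reshaped from (B, C, ...) to (B * group, C // group, ...), normalized over all non-batch axes, and reshaped back, with dynamic dims handled by the s == -1 to 0 placeholder convention of the reshape layer. The identity this relies on can be checked in eager PyTorch; a minimal sketch with illustrative sizes and unit scale/zero shift (the affine weight and bias are folded in separately by the converter):

import torch
import torch.nn.functional as F

N, G, C, H, W = 2, 3, 6, 4, 4
x = torch.randn(N, C, H, W)

# Reference group norm (no affine parameters).
ref = F.group_norm(x, G, eps=1e-5)

# The converter's trick: fold each group into the batch axis, then
# normalize over every non-batch axis, i.e. the axes produced by
# get_axes_for_reduce_op(range(1, rank)).
xr = x.reshape(N * G, C // G, H, W)
mean = xr.mean(dim=(1, 2, 3), keepdim=True)
var = xr.var(dim=(1, 2, 3), unbiased=False, keepdim=True)
out = ((xr - mean) / torch.sqrt(var + 1e-5)).reshape(N, C, H, W)

assert torch.allclose(ref, out, atol=1e-5)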

tests/py/dynamo/conversion/test_group_norm_aten.py

Lines changed: 63 additions & 0 deletions
@@ -1,4 +1,5 @@
 import torch
+from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt import Input

@@ -43,6 +44,31 @@ def forward(self, x):
             inputs,
         )

+    def test_groupnorm_with_dynamic_shape(self):
+        class GroupNorm(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.group_norm.default(
+                    x,
+                    2,
+                    torch.ones((6,)),
+                    torch.zeros((6,)),
+                    1e-05,
+                    True,
+                )
+
+        input_specs = [
+            Input(
+                dtype=torch.float32,
+                min_shape=(3, 6, 24, 24),
+                opt_shape=(5, 6, 24, 24),
+                max_shape=(8, 6, 48, 24),
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            GroupNorm(),
+            input_specs,
+        )
+

 class TestNativeGroupNormConverter(DispatchTestCase):
     def test_groupnorm1d(self):
@@ -86,6 +112,43 @@ def forward(self, x):
             inputs,
         )

+    @parameterized.expand(
+        [
+            (5, 4, 4, 2, (2, 4, 2), (3, 4, 2), (5, 4, 4)),
+            (5, 4, 2 * 2, 2, (2, 4, 2, 2), (3, 4, 2, 2), (5, 4, 2, 2)),
+            (5, 9, 6 * 3, 3, (3, 9, 3, 3), (4, 9, 3, 3), (5, 9, 6, 3)),
+            (8, 9, 6 * 6, 3, (3, 9, 2, 3, 2), (5, 9, 3, 3, 2), (8, 9, 6, 3, 2)),
+        ]
+    )
+    def test_groupnorm_with_dynamic_shape(
+        self, N, C, HxW, groups, min_shape, opt_shape, max_shape
+    ):
+        class GroupNorm(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.native_group_norm.default(
+                    x,
+                    torch.ones((C,)),
+                    torch.zeros((C,)),
+                    N,
+                    C,
+                    HxW,
+                    groups,
+                    1e-5,
+                )[0]
+
+        input_specs = [
+            Input(
+                dtype=torch.float32,
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            GroupNorm(),
+            input_specs,
+        )
+

 if __name__ == "__main__":
     run_tests()
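
For reference, the argument layout the parameterized cases exercise can be checked in eager mode; a hedged sketch with illustrative sizes (native_group_norm carries explicit N/C/HxW metadata alongside the tensor, and the converter above re-derives C from the input when the traced value is 0):

import torch
import torch.nn.functional as F

N, C, H, W, groups = 5, 4, 2, 2, 2
x = torch.randn(N, C, H, W)
w, b = torch.ones(C), torch.zeros(C)

# native_group_norm returns (output, mean, rstd); the converter only
# materializes the first and returns None placeholders for the rest.
out, mean, rstd = torch.ops.aten.native_group_norm.default(
    x, w, b, N, C, H * W, groups, 1e-5
)
assert torch.allclose(out, F.group_norm(x, groups, w, b, eps=1e-5), atol=1e-5)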

tests/py/dynamo/conversion/test_layer_norm_aten.py

Lines changed: 28 additions & 2 deletions
@@ -82,9 +82,35 @@ def forward(self, x):

         input_specs = [
             Input(
-                shape=(-1, 3, 224, 224),
                 dtype=torch.float32,
-                shape_ranges=[((1, 3, 224, 224), (5, 3, 224, 224), (10, 3, 224, 224))],
+                min_shape=(1, 3, 224, 224),
+                opt_shape=(5, 3, 224, 224),
+                max_shape=(10, 3, 224, 224),
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            LayerNorm(),
+            input_specs,
+        )
+
+    def test_layernorm_with_dynamic_shape_1(self):
+        class LayerNorm(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.native_layer_norm.default(
+                    x,
+                    torch.tensor([3]),
+                    torch.ones((3)),
+                    torch.zeros((3)),
+                    1e-05,
+                )[0]
+
+        input_specs = [
+            Input(
+                dtype=torch.float32,
+                min_shape=(1, 2, 3),
+                opt_shape=(3, 3, 3),
+                max_shape=(4, 5, 3),
             ),
         ]

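
The layer-norm tests also migrate from the legacy shape / shape_ranges arguments to the min_shape / opt_shape / max_shape form of torch_tensorrt.Input, the same profile users pass at compile time. A hedged end-to-end sketch of what this commit enables (model, shapes, and device placement are illustrative; assumes a CUDA-enabled torch_tensorrt install):

import torch
import torch_tensorrt

model = torch.nn.Sequential(torch.nn.LayerNorm((3, 224, 224))).cuda().eval()

# One optimization profile: the batch dimension may vary from 1 to 10.
trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=[
        torch_tensorrt.Input(
            min_shape=(1, 3, 224, 224),
            opt_shape=(5, 3, 224, 224),
            max_shape=(10, 3, 224, 224),
            dtype=torch.float32,
        )
    ],
)

# Any batch size within the profile now runs through the TensorRT engine.
out = trt_model(torch.randn(8, 3, 224, 224, device="cuda"))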
