Skip to content

Commit fbed329

Browse files
committed
add more test cases
1 parent ade7ed6 commit fbed329

File tree

1 file changed

+19
-6
lines changed

1 file changed

+19
-6
lines changed

tests/py/dynamo/conversion/test_layer_norm_aten.py

Lines changed: 19 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -39,18 +39,31 @@ def forward(self, x):
3939

4040

4141
class TestNativeLayerNormConverter(DispatchTestCase):
42-
def test_layer_norm(self):
42+
@parameterized.expand(
43+
[
44+
(
45+
(5, 3, 2, 4),
46+
[
47+
4,
48+
],
49+
),
50+
((5, 3, 2, 4), [2, 4]),
51+
((5, 3, 2, 4), [3, 2, 4]),
52+
((5, 3, 2, 4), [5, 3, 2, 4]),
53+
]
54+
)
55+
def test_layer_norm(self, input_shape, normalized_shape, eps=1e-05):
4356
class LayerNorm(torch.nn.Module):
4457
def forward(self, x):
4558
return torch.ops.aten.native_layer_norm.default(
4659
x,
47-
torch.tensor([3, 224, 224]),
48-
torch.ones((3, 224, 224)),
49-
torch.zeros((3, 224, 224)),
50-
1e-05,
60+
normalized_shape,
61+
torch.randn(normalized_shape),
62+
torch.randn(normalized_shape),
63+
eps,
5164
)[0]
5265

53-
inputs = [torch.randn(1, 3, 224, 224)]
66+
inputs = [torch.randn(input_shape)]
5467
self.run_test(
5568
LayerNorm(),
5669
inputs,

0 commit comments

Comments (0)