Skip to content

Commit 4e47a38

Browse files
committed
add more test cases
1 parent 12204ab commit 4e47a38

File tree

1 file changed

+19
-6
lines changed

1 file changed

+19
-6
lines changed

tests/py/dynamo/conversion/test_layer_norm_aten.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,18 +64,31 @@ def forward(self, x):
6464

6565

6666
class TestNativeLayerNormConverter(DispatchTestCase):
67-
def test_layer_norm(self):
67+
@parameterized.expand(
68+
[
69+
(
70+
(5, 3, 2, 4),
71+
[
72+
4,
73+
],
74+
),
75+
((5, 3, 2, 4), [2, 4]),
76+
((5, 3, 2, 4), [3, 2, 4]),
77+
((5, 3, 2, 4), [5, 3, 2, 4]),
78+
]
79+
)
80+
def test_layer_norm(self, input_shape, normalized_shape, eps=1e-05):
6881
class LayerNorm(torch.nn.Module):
6982
def forward(self, x):
7083
return torch.ops.aten.native_layer_norm.default(
7184
x,
72-
torch.tensor([3, 224, 224]),
73-
torch.ones((3, 224, 224)),
74-
torch.zeros((3, 224, 224)),
75-
1e-05,
85+
normalized_shape,
86+
torch.randn(normalized_shape),
87+
torch.randn(normalized_shape),
88+
eps,
7689
)[0]
7790

78-
inputs = [torch.randn(1, 3, 224, 224)]
91+
inputs = [torch.randn(input_shape)]
7992
self.run_test(
8093
LayerNorm(),
8194
inputs,

0 commit comments

Comments
 (0)