1 parent ade7ed6 commit fbed329
tests/py/dynamo/conversion/test_layer_norm_aten.py
```diff
@@ -39,18 +39,31 @@ def forward(self, x):
 
 
 class TestNativeLayerNormConverter(DispatchTestCase):
-    def test_layer_norm(self):
+    @parameterized.expand(
+        [
+            (
+                (5, 3, 2, 4),
+                [
+                    4,
+                ],
+            ),
+            ((5, 3, 2, 4), [2, 4]),
+            ((5, 3, 2, 4), [3, 2, 4]),
+            ((5, 3, 2, 4), [5, 3, 2, 4]),
+        ]
+    )
+    def test_layer_norm(self, input_shape, normalized_shape, eps=1e-05):
         class LayerNorm(torch.nn.Module):
             def forward(self, x):
                 return torch.ops.aten.native_layer_norm.default(
                     x,
-                    torch.tensor([3, 224, 224]),
-                    torch.ones((3, 224, 224)),
-                    torch.zeros((3, 224, 224)),
-                    1e-05,
+                    normalized_shape,
+                    torch.randn(normalized_shape),
+                    torch.randn(normalized_shape),
+                    eps,
                 )[0]
 
-        inputs = [torch.randn(1, 3, 224, 224)]
+        inputs = [torch.randn(input_shape)]
         self.run_test(
             LayerNorm(),
             inputs,
```
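For context, `parameterized.expand` (from the `parameterized` package) generates one test method per tuple in the list, passing the tuple elements as positional arguments. A minimal sketch of how the cases above map onto test arguments; the class and method names here are illustrative, not from the commit:

```python
# Sketch only: each tuple becomes a separately named, separately reported
# test case; eps keeps its default unless a third tuple element is given.
import unittest

from parameterized import parameterized


class Demo(unittest.TestCase):
    @parameterized.expand(
        [
            ((5, 3, 2, 4), [4]),
            ((5, 3, 2, 4), [2, 4]),
        ]
    )
    def test_shapes(self, input_shape, normalized_shape, eps=1e-05):
        # normalized_shape must be a suffix of the input shape for layer norm.
        self.assertLessEqual(len(normalized_shape), len(input_shape))


if __name__ == "__main__":
    unittest.main()
```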
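As for why the test indexes `[0]`: `aten.native_layer_norm` returns a three-tuple `(output, mean, rstd)`, and only the normalized output is compared by the converter test. A hedged sketch checking that first output against `torch.nn.functional.layer_norm` in eager mode, using one shape pair from the diff (nothing here is part of the commit):

```python
# Sketch only: confirm aten.native_layer_norm's first output matches
# torch.nn.functional.layer_norm for one (input_shape, normalized_shape) case.
import torch

input_shape, normalized_shape, eps = (5, 3, 2, 4), [2, 4], 1e-05
x = torch.randn(input_shape)
weight = torch.randn(normalized_shape)
bias = torch.randn(normalized_shape)

# native_layer_norm returns (output, mean, rstd); the test keeps only [0].
out, mean, rstd = torch.ops.aten.native_layer_norm.default(
    x, normalized_shape, weight, bias, eps
)
ref = torch.nn.functional.layer_norm(x, normalized_shape, weight, bias, eps)
torch.testing.assert_close(out, ref)
```

Randomizing the weight and bias with `torch.randn(normalized_shape)` (instead of the fixed `torch.ones`/`torch.zeros` of the old test) exercises the affine part of the op rather than the identity case.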