We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 12204ab commit 4e47a38Copy full SHA for 4e47a38
tests/py/dynamo/conversion/test_layer_norm_aten.py
@@ -64,18 +64,31 @@ def forward(self, x):
64
65
66
class TestNativeLayerNormConverter(DispatchTestCase):
67
- def test_layer_norm(self):
+    @parameterized.expand(
+        [
+            (
+                (5, 3, 2, 4),
+                [
+                    4,
+                ],
+            ),
+            ((5, 3, 2, 4), [2, 4]),
+            ((5, 3, 2, 4), [3, 2, 4]),
+            ((5, 3, 2, 4), [5, 3, 2, 4]),
+        ]
+    )
+    def test_layer_norm(self, input_shape, normalized_shape, eps=1e-05):
81
class LayerNorm(torch.nn.Module):
82
def forward(self, x):
83
             return torch.ops.aten.native_layer_norm.default(
                 x,
-                torch.tensor([3, 224, 224]),
-                torch.ones((3, 224, 224)),
-                torch.zeros((3, 224, 224)),
-                1e-05,
+                normalized_shape,
+                torch.randn(normalized_shape),
+                torch.randn(normalized_shape),
+                eps,
             )[0]
90
-        inputs = [torch.randn(1, 3, 224, 224)]
+        inputs = [torch.randn(input_shape)]
92
self.run_test(
93
LayerNorm(),
94
inputs,
0 commit comments