
Commit d78a846

chore: revert layer_norm test
1 parent b0e92d8 commit d78a846

1 file changed: +49 −0 lines

tests/py/dynamo/conversion/test_layer_norm_aten.py

Lines changed: 49 additions & 0 deletions
@@ -24,6 +24,31 @@ def forward(self, x):
             inputs,
         )
 
+    def test_layernorm_with_dynamic_shape(self):
+        class LayerNorm(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.layer_norm.default(
+                    x,
+                    torch.tensor([3, 224, 224]),
+                    torch.ones((3, 224, 224)),
+                    torch.zeros((3, 224, 224)),
+                    1e-05,
+                    True,
+                )
+
+        input_specs = [
+            Input(
+                shape=(-1, 3, 224, 224),
+                dtype=torch.float32,
+                shape_ranges=[((1, 3, 224, 224), (1, 3, 224, 224), (2, 3, 224, 224))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            LayerNorm(),
+            input_specs,
+        )
+
 
 class TestNativeLayerNormConverter(DispatchTestCase):
     def test_layer_norm(self):
@@ -43,6 +68,30 @@ def forward(self, x):
             inputs,
         )
 
+    def test_layernorm_with_dynamic_shape(self):
+        class LayerNorm(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.native_layer_norm.default(
+                    x,
+                    torch.tensor([3, 224, 224]),
+                    torch.ones((3, 224, 224)),
+                    torch.zeros((3, 224, 224)),
+                    1e-05,
+                )[0]
+
+        input_specs = [
+            Input(
+                shape=(-1, 3, 224, 224),
+                dtype=torch.float32,
+                shape_ranges=[((1, 3, 224, 224), (1, 3, 224, 224), (2, 3, 224, 224))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            LayerNorm(),
+            input_specs,
+        )
+
 
 if __name__ == "__main__":
     run_tests()
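For context, the restored tests exercise the layer_norm converters with a dynamic batch dimension (1 to 2) expressed through shape_ranges. Below is a minimal standalone sketch of the same scenario outside the test harness, assuming a CUDA device with TensorRT installed and the current torch_tensorrt.compile / torch_tensorrt.Input dynamic-shape API; it is illustrative only and not part of this commit.

import torch
import torch_tensorrt


class LayerNorm(torch.nn.Module):
    def forward(self, x):
        # Normalize over the same (3, 224, 224) trailing dimensions as the tests above.
        return torch.nn.functional.layer_norm(x, [3, 224, 224])


model = LayerNorm().eval().cuda()

# Dynamic batch dimension: min/opt of 1, max of 2, mirroring the
# shape_ranges tuples used in the tests.
inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 3, 224, 224),
        opt_shape=(1, 3, 224, 224),
        max_shape=(2, 3, 224, 224),
        dtype=torch.float32,
    )
]

trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(torch.randn(2, 3, 224, 224).cuda()).shape)  # expected: torch.Size([2, 3, 224, 224])

The min/opt/max shapes above correspond one-to-one to the three tuples in each shape_ranges entry of the tests.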
