@@ -24,6 +24,31 @@ def forward(self, x):
             inputs,
         )
 
+    def test_layernorm_with_dynamic_shape(self):
+        class LayerNorm(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.layer_norm.default(
+                    x,
+                    torch.tensor([3, 224, 224]),
+                    torch.ones((3, 224, 224)),
+                    torch.zeros((3, 224, 224)),
+                    1e-05,
+                    True,
+                )
+
+        input_specs = [
+            Input(
+                shape=(-1, 3, 224, 224),
+                dtype=torch.float32,
+                shape_ranges=[((1, 3, 224, 224), (1, 3, 224, 224), (2, 3, 224, 224))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            LayerNorm(),
+            input_specs,
+        )
+
 
 class TestNativeLayerNormConverter(DispatchTestCase):
     def test_layer_norm(self):
@@ -43,6 +68,30 @@ def forward(self, x):
             inputs,
         )
 
+    def test_layernorm_with_dynamic_shape(self):
+        class LayerNorm(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.native_layer_norm.default(
+                    x,
+                    torch.tensor([3, 224, 224]),
+                    torch.ones((3, 224, 224)),
+                    torch.zeros((3, 224, 224)),
+                    1e-05,
+                )[0]
+
+        input_specs = [
+            Input(
+                shape=(-1, 3, 224, 224),
+                dtype=torch.float32,
+                shape_ranges=[((1, 3, 224, 224), (1, 3, 224, 224), (2, 3, 224, 224))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            LayerNorm(),
+            input_specs,
+        )
+
 
 if __name__ == "__main__":
     run_tests()
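
For reference, the aten ops exercised by the two new dynamic-shape tests can be checked in eager PyTorch. The sketch below is not part of the diff; it simply walks the min and max batch sizes from `shape_ranges` and passes `normalized_shape` as a plain list of ints.

```python
# Standalone sketch, not part of this change: run the same aten ops in eager
# mode over the batch sizes covered by shape_ranges ((1..2), 3, 224, 224).
import torch

weight = torch.ones((3, 224, 224))
bias = torch.zeros((3, 224, 224))

for batch in (1, 2):  # min and max batch size from the dynamic shape range
    x = torch.randn(batch, 3, 224, 224)

    out = torch.ops.aten.layer_norm.default(x, [3, 224, 224], weight, bias, 1e-05, True)

    # native_layer_norm also returns the per-sample mean and rstd; the test
    # above keeps only element [0], the normalized output.
    out2, mean, rstd = torch.ops.aten.native_layer_norm.default(
        x, [3, 224, 224], weight, bias, 1e-05
    )

    assert out.shape == x.shape
    assert torch.allclose(out, out2, atol=1e-6)
```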