@@ -126,6 +126,69 @@ def forward(self, x):
126
126
output_dtypes = [torch .bool ],
127
127
)
128
128
129
+ @parameterized .expand (
130
+ [
131
+ ((2 , 3 , 4 ), torch .int , - 5 , 0 ),
132
+ ((6 , 7 , 5 , 4 , 5 ), torch .int , - 5 , 5 ),
133
+ ((1 , 5 , 2 , 1 ), torch .int , - 5 , 5 ),
134
+ ]
135
+ )
136
+ def test_any_default_bool_dtype (self , input_shape , dtype , low , high ):
137
+ class Any (nn .Module ):
138
+ def forward (self , x ):
139
+ return torch .ops .aten .any .default (x )
140
+
141
+ inputs = [torch .randint (low , high , input_shape , dtype = dtype ).bool ()]
142
+ self .run_test (
143
+ Any (),
144
+ inputs ,
145
+ output_dtypes = [torch .bool ],
146
+ )
147
+
148
+ @parameterized .expand (
149
+ [
150
+ ((3 , 2 , 4 ), 1 , True , torch .int , 0 , 5 ),
151
+ ((2 , 3 , 4 , 5 ), 3 , True , torch .int , - 10 , 10 ),
152
+ ((2 , 3 , 4 , 5 ), 2 , False , torch .int32 , - 5 , 0 ),
153
+ ((6 , 7 , 5 , 4 , 5 ), 4 , False , torch .int32 , - 5 , 5 ),
154
+ ((1 , 5 , 2 , 1 ), - 4 , False , torch .int32 , - 5 , 5 ),
155
+ ]
156
+ )
157
+ def test_any_dim_bool_dtype (self , input_shape , dim , keep_dims , dtype , low , high ):
158
+ class AnyDim (nn .Module ):
159
+ def forward (self , x ):
160
+ return torch .ops .aten .any .dim (x , dim , keep_dims )
161
+
162
+ inputs = [torch .randint (low , high , input_shape , dtype = dtype ).bool ()]
163
+ self .run_test (
164
+ AnyDim (),
165
+ inputs ,
166
+ output_dtypes = [torch .bool ],
167
+ )
168
+
169
+ @parameterized .expand (
170
+ [
171
+ ((3 , 2 , 4 ), [1 ], True , torch .int , 0 , 5 ),
172
+ ((2 , 1 , 4 , 5 ), [0 , 3 ], True , torch .int , - 10 , 10 ),
173
+ ((2 , 3 , 4 , 5 ), [0 , 1 , 2 , 3 ], False , torch .int32 , - 5 , 0 ),
174
+ ((6 , 7 , 5 , 4 , 5 ), [1 , 3 , 4 ], False , torch .int32 , - 5 , 5 ),
175
+ ((1 , 5 , 2 , 1 ), [- 3 , - 1 ], False , torch .int32 , - 5 , 5 ),
176
+ ]
177
+ )
178
+ def test_any_dims_tuple_bool_dtype (
179
+ self , input_shape , dims , keep_dims , dtype , low , high
180
+ ):
181
+ class AnyDims (nn .Module ):
182
+ def forward (self , x ):
183
+ return torch .ops .aten .any .dims (x , dims , keep_dims )
184
+
185
+ inputs = [torch .randint (low , high , input_shape , dtype = dtype ).bool ()]
186
+ self .run_test (
187
+ AnyDims (),
188
+ inputs ,
189
+ output_dtypes = [torch .bool ],
190
+ )
191
+
129
192
130
193
if __name__ == "__main__" :
131
194
run_tests ()
0 commit comments