@@ -328,6 +328,9 @@ TEST_F(OpVarOutTest, InvalidDTypeDies) {
328
328
}
329
329
330
330
TEST_F (OpVarOutTest, AllFloatInputFloatOutputPasses) {
331
+ if (torch::executor::testing::SupportedFeatures::get ()->is_aten ) {
332
+ GTEST_SKIP () << " ATen supports fewer dtypes" ;
333
+ }
331
334
// Use a two layer switch to hanldle each possible data pair
332
335
#define TEST_KERNEL (INPUT_CTYPE, INPUT_DTYPE, OUTPUT_CTYPE, OUTPUT_DTYPE ) \
333
336
test_var_out_dtype<ScalarType::INPUT_DTYPE, ScalarType::OUTPUT_DTYPE>();
@@ -340,6 +343,22 @@ TEST_F(OpVarOutTest, AllFloatInputFloatOutputPasses) {
340
343
#undef TEST_KERNEL
341
344
}
342
345
346
+ TEST_F (OpVarOutTest, AllFloatInputFloatOutputPasses_Aten) {
347
+ if (!torch::executor::testing::SupportedFeatures::get ()->is_aten ) {
348
+ GTEST_SKIP () << " ATen-specific variant of test case" ;
349
+ }
350
+ // Use a two layer switch to hanldle each possible data pair
351
+ #define TEST_KERNEL (INPUT_CTYPE, INPUT_DTYPE, OUTPUT_CTYPE, OUTPUT_DTYPE ) \
352
+ test_var_out_dtype<ScalarType::INPUT_DTYPE, ScalarType::OUTPUT_DTYPE>();
353
+
354
+ #define TEST_ENTRY (INPUT_CTYPE, INPUT_DTYPE ) \
355
+ ET_FORALL_FLOAT_TYPES_WITH2 (INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
356
+
357
+ ET_FORALL_FLOAT_TYPES (TEST_ENTRY);
358
+ #undef TEST_ENTRY
359
+ #undef TEST_KERNEL
360
+ }
361
+
343
362
TEST_F (OpVarOutTest, InfinityAndNANTest) {
344
363
TensorFactory<ScalarType::Float> tf_float;
345
364
// clang-format off
0 commit comments