Skip to content

Commit 14a3260

Browse files
pytorchbotswolchok
andauthored
Fix ATen mode op_var_test
Was broken, works now. Differential Revision: [D68927724](https://our.internmc.facebook.com/intern/diff/D68927724/) ghstack-source-id: 263956677 Pull Request resolved: #8080 Co-authored-by: Scott Wolchok <[email protected]>
1 parent c72b62e commit 14a3260

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

kernels/test/op_var_test.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,9 @@ TEST_F(OpVarOutTest, InvalidDTypeDies) {
328328
}
329329

330330
TEST_F(OpVarOutTest, AllFloatInputFloatOutputPasses) {
331+
if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
332+
GTEST_SKIP() << "ATen supports fewer dtypes";
333+
}
331334
// Use a two layer switch to hanldle each possible data pair
332335
#define TEST_KERNEL(INPUT_CTYPE, INPUT_DTYPE, OUTPUT_CTYPE, OUTPUT_DTYPE) \
333336
test_var_out_dtype<ScalarType::INPUT_DTYPE, ScalarType::OUTPUT_DTYPE>();
@@ -340,6 +343,22 @@ TEST_F(OpVarOutTest, AllFloatInputFloatOutputPasses) {
340343
#undef TEST_KERNEL
341344
}
342345

346+
TEST_F(OpVarOutTest, AllFloatInputFloatOutputPasses_Aten) {
347+
if (!torch::executor::testing::SupportedFeatures::get()->is_aten) {
348+
GTEST_SKIP() << "ATen-specific variant of test case";
349+
}
350+
// Use a two layer switch to hanldle each possible data pair
351+
#define TEST_KERNEL(INPUT_CTYPE, INPUT_DTYPE, OUTPUT_CTYPE, OUTPUT_DTYPE) \
352+
test_var_out_dtype<ScalarType::INPUT_DTYPE, ScalarType::OUTPUT_DTYPE>();
353+
354+
#define TEST_ENTRY(INPUT_CTYPE, INPUT_DTYPE) \
355+
ET_FORALL_FLOAT_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
356+
357+
ET_FORALL_FLOAT_TYPES(TEST_ENTRY);
358+
#undef TEST_ENTRY
359+
#undef TEST_KERNEL
360+
}
361+
343362
TEST_F(OpVarOutTest, InfinityAndNANTest) {
344363
TensorFactory<ScalarType::Float> tf_float;
345364
// clang-format off

0 commit comments

Comments
 (0)