Skip to content

Commit 501550a

Browse files
swolchokYIWENX14
authored andcommitted
don't test Half/BFloat16 for cdist_forward in ATen mode (#7980)
Unbreaks internal tests.
1 parent 09c5ea5 commit 501550a

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

kernels/test/op_cdist_forward_test.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
1010
#include <executorch/kernels/test/TestUtil.h>
11+
#include <executorch/kernels/test/supported_features.h>
1112
#include <executorch/runtime/core/exec_aten/exec_aten.h>
1213
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
1314
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
@@ -45,6 +46,12 @@ class OpCdistForwardOutTest : public ::testing::Test {
4546
void test_dtype() {
4647
TensorFactory<DTYPE> tf;
4748

49+
if ((DTYPE == ScalarType::Half || DTYPE == ScalarType::BFloat16) &&
50+
torch::executor::testing::SupportedFeatures::get()->is_aten) {
51+
// ATen doesn't support Half/BFloat for this op.
52+
return;
53+
}
54+
4855
Tensor x1 = tf.make({2, 1, 4, 3}, {0, 1, 2, 3, 5, 4, 3, -3, 7, 1, 6, 2,
4956
-1, 5, 1, 1, -2, 1, 5, 4, 3, 2, -1, 5});
5057
Tensor x2 = tf.make(

0 commit comments

Comments
 (0)