Skip to content

Commit 0d6a098

Browse files
swolchokfacebook-github-bot
authored andcommitted
Support bf16 for binary logical ops (#5706)
Summary: Pull Request resolved: #5706 ghstack-source-id: 245701975 exported-using-ghexport Reviewed By: manuelcandales Differential Revision: D63486223 fbshipit-source-id: de95de504b0434a3a034aa58ba685e6b783b3999
1 parent c48d867 commit 0d6a098

File tree

8 files changed

+144
-33
lines changed

8 files changed

+144
-33
lines changed

kernels/portable/cpu/pattern/binary_ufunc_realb_realb_to_realb_logical.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ Tensor& binary_ufunc_realb_realb_to_realb_logical(
3434
ScalarType b_type = b.scalar_type();
3535
ScalarType out_type = out.scalar_type();
3636

37-
ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, __func__, CTYPE_A, [&]() {
38-
ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, __func__, CTYPE_B, [&]() {
39-
ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, __func__, CTYPE_OUT, [&]() {
37+
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, __func__, CTYPE_A, [&]() {
38+
ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, __func__, CTYPE_B, [&]() {
39+
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, __func__, CTYPE_OUT, [&]() {
4040
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
4141
[fn](const CTYPE_A val_a, const CTYPE_B val_b) {
4242
bool a_casted = static_cast<bool>(val_a);

kernels/test/BinaryLogicalOpTest.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/test/BinaryLogicalOpTest.h>
10+
11+
namespace torch::executor::testing {
12+
13+
void BinaryLogicalOpTest::test_all_dtypes() {
14+
#define TEST_ENTRY(ctype, dtype) \
15+
test_op_out<ScalarType::dtype, ScalarType::Double, ScalarType::Double>();
16+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
17+
#undef TEST_ENTRY
18+
#define TEST_ENTRY(ctype, dtype) \
19+
test_op_out<ScalarType::Double, ScalarType::dtype, ScalarType::Double>();
20+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
21+
#undef TEST_ENTRY
22+
#define TEST_ENTRY(ctype, dtype) \
23+
test_op_out<ScalarType::Double, ScalarType::Double, ScalarType::dtype>();
24+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
25+
#undef TEST_ENTRY
26+
}
27+
} // namespace torch::executor::testing

kernels/test/BinaryLogicalOpTest.h

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/kernels/test/TestUtil.h>
12+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
13+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15+
16+
namespace torch::executor::testing {
17+
class BinaryLogicalOpTest : public OperatorTest {
18+
protected:
19+
// Implement this to call the torch::executor::aten::op_outf function for the
20+
// op.
21+
virtual exec_aten::Tensor& op_out(
22+
const exec_aten::Tensor& lhs,
23+
const exec_aten::Tensor& rhs,
24+
exec_aten::Tensor& out) = 0;
25+
26+
// Scalar reference implementation of the function in question for testing.
27+
virtual double op_reference(double x, double y) const = 0;
28+
29+
template <
30+
exec_aten::ScalarType IN_DTYPE,
31+
exec_aten::ScalarType IN_DTYPE2,
32+
exec_aten::ScalarType OUT_DTYPE>
33+
void test_op_out() {
34+
TensorFactory<IN_DTYPE> tf_in;
35+
TensorFactory<IN_DTYPE2> tf_in2;
36+
TensorFactory<OUT_DTYPE> tf_out;
37+
38+
exec_aten::Tensor out = tf_out.zeros({1, 4});
39+
40+
using CTYPE1 = typename decltype(tf_in)::ctype;
41+
std::vector<CTYPE1> test_vector1 = {0, CTYPE1(-1), CTYPE1(0), CTYPE1(31)};
42+
43+
using CTYPE2 = typename decltype(tf_in2)::ctype;
44+
std::vector<CTYPE2> test_vector2 = {
45+
CTYPE2(0),
46+
CTYPE2(0),
47+
CTYPE2(15),
48+
CTYPE2(12),
49+
};
50+
51+
std::vector<typename decltype(tf_out)::ctype> expected_vector;
52+
for (int ii = 0; ii < test_vector1.size(); ++ii) {
53+
expected_vector.push_back(
54+
op_reference(test_vector1[ii], test_vector2[ii]));
55+
}
56+
57+
op_out(
58+
tf_in.make({1, 4}, test_vector1),
59+
tf_in2.make({1, 4}, test_vector2),
60+
out);
61+
62+
EXPECT_TENSOR_CLOSE(out, tf_out.make({1, 4}, expected_vector));
63+
}
64+
65+
void test_all_dtypes();
66+
};
67+
68+
#define IMPLEMENT_BINARY_LOGICAL_OP_TEST(TestName) \
69+
TEST_F(TestName, SimpleTestAllTypes) { \
70+
test_all_dtypes(); \
71+
}
72+
} // namespace torch::executor::testing

kernels/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ add_custom_target(
6868
)
6969

7070
set(all_test_sources
71+
"BinaryLogicalOpTest.cpp"
7172
"op__to_dim_order_copy_test.cpp"
7273
"op_abs_test.cpp"
7374
"op_acos_test.cpp"

kernels/test/op_logical_and_test.cpp

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,26 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
#include <executorch/kernels/test/BinaryLogicalOpTest.h>
910
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10-
#include <executorch/kernels/test/TestUtil.h>
11-
#include <executorch/runtime/core/exec_aten/exec_aten.h>
12-
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
13-
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
1411

1512
#include <gtest/gtest.h>
1613

17-
using namespace ::testing;
18-
using exec_aten::ScalarType;
1914
using exec_aten::Tensor;
20-
using torch::executor::testing::TensorFactory;
2115

22-
class OpLogicalAndTest : public OperatorTest {
16+
class OpLogicalAndTest : public torch::executor::testing::BinaryLogicalOpTest {
2317
protected:
24-
Tensor&
25-
op_logical_and_out(const Tensor& self, const Tensor& other, Tensor& out) {
18+
Tensor& op_out(const Tensor& self, const Tensor& other, Tensor& out)
19+
override {
2620
return torch::executor::aten::logical_and_outf(context_, self, other, out);
2721
}
22+
23+
double op_reference(double x, double y) const override {
24+
uint64_t lhs, rhs;
25+
std::memcpy(&lhs, &x, sizeof(lhs));
26+
std::memcpy(&rhs, &y, sizeof(rhs));
27+
return lhs && rhs;
28+
}
2829
};
30+
31+
IMPLEMENT_BINARY_LOGICAL_OP_TEST(OpLogicalAndTest)

kernels/test/op_logical_or_test.cpp

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,26 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
#include <executorch/kernels/test/BinaryLogicalOpTest.h>
910
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10-
#include <executorch/kernels/test/TestUtil.h>
11-
#include <executorch/runtime/core/exec_aten/exec_aten.h>
12-
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
13-
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
1411

1512
#include <gtest/gtest.h>
1613

17-
using namespace ::testing;
18-
using exec_aten::ScalarType;
1914
using exec_aten::Tensor;
20-
using torch::executor::testing::TensorFactory;
2115

22-
class OpLogicalOrTest : public OperatorTest {
16+
class OpLogicalOrTest : public torch::executor::testing::BinaryLogicalOpTest {
2317
protected:
24-
Tensor&
25-
op_logical_or_out(const Tensor& self, const Tensor& other, Tensor& out) {
18+
Tensor& op_out(const Tensor& self, const Tensor& other, Tensor& out)
19+
override {
2620
return torch::executor::aten::logical_or_outf(context_, self, other, out);
2721
}
22+
23+
double op_reference(double x, double y) const override {
24+
uint64_t lhs, rhs;
25+
std::memcpy(&lhs, &x, sizeof(lhs));
26+
std::memcpy(&rhs, &y, sizeof(rhs));
27+
return lhs || rhs;
28+
}
2829
};
30+
31+
IMPLEMENT_BINARY_LOGICAL_OP_TEST(OpLogicalOrTest)

kernels/test/op_logical_xor_test.cpp

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,26 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
#include <executorch/kernels/test/BinaryLogicalOpTest.h>
910
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10-
#include <executorch/kernels/test/TestUtil.h>
11-
#include <executorch/runtime/core/exec_aten/exec_aten.h>
12-
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
13-
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
1411

1512
#include <gtest/gtest.h>
1613

17-
using namespace ::testing;
18-
using exec_aten::ScalarType;
1914
using exec_aten::Tensor;
20-
using torch::executor::testing::TensorFactory;
2115

22-
class OpLogicalXorTest : public OperatorTest {
16+
class OpLogicalXorTest : public torch::executor::testing::BinaryLogicalOpTest {
2317
protected:
24-
Tensor&
25-
op_logical_xor_out(const Tensor& self, const Tensor& other, Tensor& out) {
18+
Tensor& op_out(const Tensor& self, const Tensor& other, Tensor& out)
19+
override {
2620
return torch::executor::aten::logical_xor_outf(context_, self, other, out);
2721
}
22+
23+
double op_reference(double x, double y) const override {
24+
uint64_t lhs, rhs;
25+
std::memcpy(&lhs, &x, sizeof(lhs));
26+
std::memcpy(&rhs, &y, sizeof(rhs));
27+
return bool(lhs) != bool(rhs);
28+
}
2829
};
30+
31+
IMPLEMENT_BINARY_LOGICAL_OP_TEST(OpLogicalXorTest)

kernels/test/targets.bzl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,11 @@ def define_common_targets():
4444
runtime.cxx_library(
4545
name = "test_util" + aten_suffix,
4646
srcs = [
47+
"BinaryLogicalOpTest.cpp",
4748
"UnaryUfuncRealHBBF16ToFloatHBF16Test.cpp",
4849
],
4950
exported_headers = [
51+
"BinaryLogicalOpTest.h",
5052
"TestUtil.h",
5153
"UnaryUfuncRealHBBF16ToFloatHBF16Test.h",
5254
],

0 commit comments

Comments
 (0)