Skip to content

Commit 3d6edb0

Browse files
authored
[ExecuTorch] support BF16 in op_copy
Differential Revision: D61981357 Pull Request resolved: #4979
1 parent c02546c commit 3d6edb0

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

kernels/portable/cpu/op_copy.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ Tensor& copy_out(
4545
ScalarType in_type = in.scalar_type();
4646
ScalarType src_type = src.scalar_type();
4747

48-
ET_SWITCH_REALHB_TYPES(in_type, ctx, "copy.out", CTYPE, [&]() {
49-
ET_SWITCH_REALHB_TYPES(src_type, ctx, "copy.out", CTYPE_SRC, [&]() {
48+
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "copy.out", CTYPE, [&]() {
49+
ET_SWITCH_REALHBBF16_TYPES(src_type, ctx, "copy.out", CTYPE_SRC, [&]() {
5050
apply_binary_elementwise_fn<CTYPE, CTYPE_SRC, CTYPE>(
5151
[](const CTYPE val_in, const CTYPE_SRC val_src) {
5252
return convert<CTYPE, CTYPE_SRC>(val_src);
@@ -75,8 +75,8 @@ copy_(RuntimeContext& ctx, Tensor& in, const Tensor& src, bool non_blocking) {
7575
ScalarType in_type = in.scalar_type();
7676
ScalarType src_type = src.scalar_type();
7777

78-
ET_SWITCH_REALHB_TYPES(in_type, ctx, "copy_", CTYPE, [&]() {
79-
ET_SWITCH_REALHB_TYPES(src_type, ctx, "copy_", CTYPE_SRC, [&]() {
78+
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "copy_", CTYPE, [&]() {
79+
ET_SWITCH_REALHBBF16_TYPES(src_type, ctx, "copy_", CTYPE_SRC, [&]() {
8080
apply_binary_elementwise_fn<CTYPE, CTYPE_SRC, CTYPE>(
8181
[](const CTYPE val_in, const CTYPE_SRC val_src) {
8282
return convert<CTYPE, CTYPE_SRC>(val_src);

kernels/test/op_copy_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,13 +125,13 @@ class OpCopyInplaceTest : public OperatorTest {
125125
// regular test for copy.out
126126
TEST_F(OpCopyTest, AllRealDtypesSupported) {
127127
#define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
128-
ET_FORALL_REAL_TYPES(TEST_ENTRY);
128+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
129129
#undef TEST_ENTRY
130130
}
131131

132132
TEST_F(OpCopyTest, EmptyInputSupported) {
133133
#define TEST_ENTRY(ctype, dtype) test_empty_input<ctype, ScalarType::dtype>();
134-
ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY);
134+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
135135
#undef TEST_ENTRY
136136
}
137137

0 commit comments

Comments
 (0)