Skip to content

Commit 70f9b5d

Browse files
authored
Support Half/BFloat16 in split_copy (#7901)
Partial fix for #7748.
1 parent 5ee5f2f commit 70f9b5d

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

kernels/portable/cpu/op_split_copy.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ void split_copy_Tensor_out(
5858
ScalarType in_type = input.scalar_type();
5959
ScalarType out_type = out[0].scalar_type();
6060

61-
ET_SWITCH_REAL_TYPES_AND(
62-
Bool, in_type, ctx, "split_copy.Tensor_out", CTYPE_IN, [&]() {
63-
ET_SWITCH_REAL_TYPES_AND(
64-
Bool, out_type, ctx, "split_copy.Tensor_out", CTYPE_OUT, [&]() {
61+
ET_SWITCH_REALHBBF16_TYPES(
62+
in_type, ctx, "split_copy.Tensor_out", CTYPE_IN, [&]() {
63+
ET_SWITCH_REALHBBF16_TYPES(
64+
out_type, ctx, "split_copy.Tensor_out", CTYPE_OUT, [&]() {
6565
const CTYPE_IN* input_data = input.const_data_ptr<CTYPE_IN>();
6666
for (size_t i = 0, e = out.size(); i < e; ++i) {
6767
size_t out_step = out[i].size(dim) * trailing_dims;

kernels/test/op_split_copy_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ TEST_F(OpSplitCopyTensorOutTest, LargerSplitSizeDoesNothing) {
298298

299299
TEST_F(OpSplitCopyTensorOutTest, AllDtypesSupported) {
300300
#define TEST_ENTRY(ctype, dtype) test_dtype<ScalarType::dtype>();
301-
ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY);
301+
ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY);
302302
#undef TEST_ENTRY
303303
// TODO: Also add tests for half, complex, quantized, and other types. Easiest
304304
// way to do that would be to make TensorFactory support zeros() and ones()

0 commit comments

Comments
 (0)