Skip to content

Commit 69bf18b

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Enable operator<< for Half (#1733)
Summary: Pull Request resolved: #1733 Reviewed By: mikekgfb Differential Revision: D53139440 fbshipit-source-id: e371ea15516677dfc726757047f4c5fb2b793d21
1 parent 862f755 commit 69bf18b

File tree

4 files changed

+33
-1
lines changed

4 files changed

+33
-1
lines changed

runtime/core/exec_aten/testing_util/tensor_util.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ std::ostream& operator<<(std::ostream& os, const Tensor& t) {
251251
break;
252252

253253
switch (t.scalar_type()) {
254-
ET_FORALL_REAL_TYPES_AND(Bool, PRINT_CASE)
254+
ET_FORALL_REAL_TYPES_AND2(Half, Bool, PRINT_CASE)
255255
default:
256256
ET_CHECK_MSG(
257257
false,

runtime/core/portable_type/half.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/runtime/core/portable_type/half.h>
10+
#include <ostream>
11+
#include <type_traits>
12+
13+
namespace torch {
14+
namespace executor {
15+
16+
static_assert(
17+
std::is_standard_layout_v<torch::executor::Half>,
18+
"Half must be standard layout.");
19+
20+
std::ostream& operator<<(
21+
std::ostream& out,
22+
const torch::executor::Half& value) {
23+
out << (float)value;
24+
return out;
25+
}
26+
27+
} // namespace executor
28+
} // namespace torch

runtime/core/portable_type/half.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <cstdint>
1313
#include <cstring>
1414
#include <limits>
15+
#include <ostream>
1516

1617
#if defined(__GNUC__) || defined(__clang__)
1718
#if defined(__aarch64__)
@@ -673,6 +674,8 @@ inline Half operator/(int64_t a, Half b) {
673674
/// NOTE: we do not define comparisons directly and instead rely on the implicit
674675
/// conversion Half to float.
675676

677+
std::ostream& operator<<(std::ostream& out, const Half& value);
678+
676679
} // namespace executor
677680
} // namespace torch
678681

runtime/core/portable_type/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def define_common_targets():
4040
# Set up a specific exported library for scalar_type to avoid circle dependency in ScalarTypeUtil.h
4141
runtime.cxx_library(
4242
name = "scalar_type",
43+
srcs = ["half.cpp"],
4344
exported_headers = [
4445
"bfloat16.h",
4546
"complex.h",

0 commit comments

Comments
 (0)