Skip to content

[SYCL] Remove direct initialization constructor from bfloat16 class #4989

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions sycl/include/sycl/ext/intel/experimental/bfloat16.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#pragma once

#include <CL/__spirv/spirv_ops.hpp>
#include <CL/sycl/half_type.hpp>

__SYCL_INLINE_NAMESPACE(cl) {
namespace sycl {
Expand Down Expand Up @@ -43,8 +44,11 @@ class [[sycl_detail::uses_aspects(ext_intel_bf16_conversion)]] bfloat16 {
#endif
}

// Direct initialization
bfloat16(const storage_t &a) : value(a) {}
static bfloat16 from_bits(const storage_t &a) {
bfloat16 res;
res.value = a;
return res;
}

// Implicit conversion from float to bfloat16
bfloat16(const float &a) { value = from_float(a); }
Expand All @@ -56,9 +60,10 @@ class [[sycl_detail::uses_aspects(ext_intel_bf16_conversion)]] bfloat16 {

// Implicit conversion from bfloat16 to float
operator float() const { return to_float(value); }
operator sycl::half() const { return to_float(value); }

// Get raw bits representation of bfloat16
operator storage_t() const { return value; }
storage_t raw() const { return value; }

// Logical operators (!,||,&&) are covered if we can cast to bool
explicit operator bool() { return to_float(value) != 0.0f; }
Expand Down
14 changes: 13 additions & 1 deletion sycl/test/extensions/bfloat16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using sycl::ext::intel::experimental::bfloat16;

SYCL_EXTERNAL uint16_t some_bf16_intrinsic(uint16_t x, uint16_t y);
SYCL_EXTERNAL void foo(long x, sycl::half y);

__attribute__((noinline)) float op(float a, float b) {
// CHECK: define {{.*}} spir_func float @_Z2opff(float [[a:%.*]], float [[b:%.*]])
Expand All @@ -27,11 +28,22 @@ __attribute__((noinline)) float op(float a, float b) {
// CHECK-NOT: uitofp
// CHECK-NOT: fptoui

bfloat16 D = some_bf16_intrinsic(A, C);
bfloat16 D = bfloat16::from_bits(some_bf16_intrinsic(A.raw(), C.raw()));
// CHECK: [[D:%.*]] = tail call spir_func zeroext i16 @_Z19some_bf16_intrinsictt(i16 zeroext [[A]], i16 zeroext [[C]])
// CHECK-NOT: uitofp
// CHECK-NOT: fptoui

long L = bfloat16(3.14f);
// CHECK: [[L_bfloat16:%.*]] = tail call spir_func zeroext i16 @_Z27__spirv_ConvertFToBF16INTELf(float 0x40091EB860000000)
// CHECK: [[L_float:%.*]] = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELt(i16 zeroext [[L_bfloat16]])
// CHECK: [[L:%.*]] = fptosi float [[L_float]] to i{{32|64}}

sycl::half H = bfloat16(2.71f);
// CHECK: [[H_bfloat16:%.*]] = tail call spir_func zeroext i16 @_Z27__spirv_ConvertFToBF16INTELf(float 0x4005AE1480000000)
// CHECK: [[H_float:%.*]] = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELt(i16 zeroext [[H_bfloat16]])
// CHECK: [[H:%.*]] = fptrunc float [[H_float]] to half
foo(L, H);

return D;
// CHECK: [[RetVal:%.*]] = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELt(i16 zeroext [[D]])
// CHECK: ret float [[RetVal]]
Expand Down