Skip to content

Commit 78eb088

Browse files
authored
[SYCL][ESIMD] Add support for tf32 (#6828)
1 parent 067d3b3 commit 78eb088

File tree

7 files changed

+185
-8
lines changed

7 files changed

+185
-8
lines changed

llvm/lib/SYCLLowerIR/ESIMD/ESIMDVerifier.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ static const char *LegalSYCLFunctions[] = {
5151
"^sycl::_V1::ext::oneapi::sub_group::.+",
5252
"^sycl::_V1::ext::oneapi::experimental::spec_constant<.+>::.+",
5353
"^sycl::_V1::ext::oneapi::experimental::this_sub_group",
54-
"^sycl::_V1::ext::oneapi::experimental::bfloat16::.+"};
54+
"^sycl::_V1::ext::oneapi::experimental::bfloat16::.+",
55+
"^sycl::_V1::ext::oneapi::experimental::tfloat32::.+"};
5556

5657
static const char *LegalSYCLFunctionsInStatelessMode[] = {
5758
"^sycl::_V1::multi_ptr<.+>::get",

llvm/lib/SYCLLowerIR/ESIMD/LowerESIMD.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,8 @@ class ESIMDIntrinDescTable {
654654
{"test_src_tmpl_arg",
655655
{"test.src.tmpl.arg", {t(0), t1(1), t8(2), t16(3), t32(4), c8(17)}}},
656656
{"slm_init", {"slm.init", {a(0)}}},
657-
{"bf_cvt", {"bf.cvt", {a(0)}}}};
657+
{"bf_cvt", {"bf.cvt", {a(0)}}},
658+
{"tf32_cvt", {"tf32.cvt", {a(0)}}}};
658659
}
659660

660661
const IntrinTable &getTable() { return Table; }

sycl/include/sycl/ext/intel/esimd.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
#include <sycl/ext/intel/esimd/common.hpp>
8484
#include <sycl/ext/intel/esimd/detail/bfloat16_type_traits.hpp>
8585
#include <sycl/ext/intel/esimd/detail/half_type_traits.hpp>
86+
#include <sycl/ext/intel/esimd/detail/tfloat32_type_traits.hpp>
8687
#include <sycl/ext/intel/esimd/simd.hpp>
8788
#include <sycl/ext/intel/esimd/simd_view.hpp>
8889
#include <sycl/ext/intel/esimd/xmx/dpas.hpp>

sycl/include/sycl/ext/intel/esimd/detail/intrin.hpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,4 +345,9 @@ __ESIMD_INTRIN __ESIMD_DNS::vector_type_t<To, N>
345345
__esimd_bf_cvt(__ESIMD_DNS::vector_type_t<From, N> Val);
346346
#endif // __SYCL_DEVICE_ONLY__
347347

348-
/// @endcond ESIMD_DETAIL
348+
// Declaration of the tf32 conversion intrinsic. It is only declared when
// compiling for the device (__SYCL_DEVICE_ONLY__); no host definition is
// provided here.
//   To   - element type of the returned vector.
//   From - element type of the input vector.
//   N    - number of elements in both vectors.
//   Val  - input vector of N 'From' elements.
// Returns a vector of N 'To' elements.
// NOTE(review): presumably lowered through the "tf32_cvt" -> "tf32.cvt"
// intrinsic table entry added in LowerESIMD.cpp - confirm.
#ifdef __SYCL_DEVICE_ONLY__
349+
template <class To, class From, int N>
350+
__ESIMD_INTRIN __ESIMD_DNS::vector_type_t<To, N>
351+
__esimd_tf32_cvt(__ESIMD_DNS::vector_type_t<From, N> Val);
352+
#endif // __SYCL_DEVICE_ONLY__
353+
/// @endcond ESIMD_DETAIL
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
//==------ tfloat32_type_traits.hpp - DPC++ Explicit SIMD API ----------==//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
// Implementation of SIMD element type traits for the tfloat32 type.
10+
//===----------------------------------------------------------------------===//
11+
12+
#pragma once
13+
14+
#include <sycl/ext/intel/esimd/detail/elem_type_traits.hpp>
15+
#include <sycl/ext/intel/esimd/detail/intrin.hpp>
16+
#include <sycl/ext/intel/experimental/esimd/tfloat32.hpp>
17+
18+
/// @cond ESIMD_DETAIL
19+
20+
namespace sycl {
21+
__SYCL_INLINE_VER_NAMESPACE(_V1) {
22+
namespace ext::intel::esimd::detail {
23+
24+
// Standalone definitions to use w/o instantiating element_type_traits.
25+
using tfloat32 = sycl::ext::intel::experimental::esimd::tfloat32;
26+
27+
template <> struct element_type_traits<tfloat32> {
28+
using RawT = unsigned int;
29+
using EnclosingCppT = float;
30+
31+
static inline constexpr bool use_native_cpp_ops = false;
32+
static inline constexpr bool is_floating_point = true;
33+
};
34+
35+
// ------------------- Type conversion traits
36+
37+
template <int N> struct vector_conversion_traits<tfloat32, N> {
38+
using StdT = __cpp_t<tfloat32>;
39+
using RawT = __raw_t<tfloat32>;
40+
41+
static ESIMD_INLINE vector_type_t<RawT, N>
42+
convert_to_raw(vector_type_t<StdT, N> Val) {
43+
#ifdef __SYCL_DEVICE_ONLY__
44+
vector_type_t<RawT, N> Result = __esimd_tf32_cvt<RawT, StdT, N>(Val);
45+
return Result;
46+
#else
47+
vector_type_t<RawT, N> Output = 0;
48+
49+
for (int i = 0; i < N; i++) {
50+
Output[i] = sycl::bit_cast<RawT>(static_cast<tfloat32>(Val[i]));
51+
}
52+
return Output;
53+
#endif
54+
}
55+
56+
static ESIMD_INLINE vector_type_t<StdT, N>
57+
convert_to_cpp(vector_type_t<RawT, N> Val) {
58+
vector_type_t<StdT, N> Result = sycl::bit_cast<vector_type_t<StdT, N>>(Val);
59+
return Result;
60+
}
61+
};
62+
63+
template <> struct scalar_conversion_traits<tfloat32> {
64+
using RawT = __raw_t<tfloat32>;
65+
66+
static ESIMD_INLINE RawT bitcast_to_raw(tfloat32 Val) {
67+
return sycl::bit_cast<RawT>(Val);
68+
}
69+
70+
static ESIMD_INLINE tfloat32 bitcast_to_wrapper(RawT Val) {
71+
return sycl::bit_cast<tfloat32>(Val);
72+
}
73+
};
74+
75+
// Misc
76+
// Streams a tfloat32 value by printing its float conversion; returns the
// stream to allow chaining.
inline std::ostream &operator<<(std::ostream &O, tfloat32 const &rhs) {
  const float AsFloat = static_cast<float>(rhs);
  return O << AsFloat;
}
80+
81+
// Marks tfloat32 as an ESIMD arithmetic type by specializing the trait to
// std::true_type.
template <> struct is_esimd_arithmetic_type<tfloat32, void> : std::true_type {};
82+
83+
} // namespace ext::intel::esimd::detail
84+
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
85+
} // namespace sycl
86+
87+
/// @endcond ESIMD_DETAIL

sycl/include/sycl/ext/intel/esimd/memory.hpp

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -357,8 +357,15 @@ ESIMD_INLINE
357357
__esimd_scatter_scaled<PromoT, N, decltype(si), TypeSizeLog2, scale>(
358358
mask.data(), si, glob_offset, offsets.data(), promo_vals.data());
359359
} else {
360-
__esimd_scatter_scaled<T, N, decltype(si), TypeSizeLog2, scale>(
361-
mask.data(), si, glob_offset, offsets.data(), vals.data());
360+
using Treal = __raw_t<T>;
361+
if constexpr (!std::is_same_v<Treal, T>) {
362+
simd<Treal, N> Values = vals.template bit_cast_view<Treal>();
363+
__esimd_scatter_scaled<Treal, N, decltype(si), TypeSizeLog2, scale>(
364+
mask.data(), si, glob_offset, offsets.data(), Values.data());
365+
} else {
366+
__esimd_scatter_scaled<T, N, decltype(si), TypeSizeLog2, scale>(
367+
mask.data(), si, glob_offset, offsets.data(), vals.data());
368+
}
362369
}
363370
}
364371

@@ -396,9 +403,15 @@ gather_impl(AccessorTy acc, simd<uint32_t, N> offsets, uint32_t glob_offset,
396403
return Res;
397404
}
398405
} else {
399-
return __esimd_gather_masked_scaled2<T, N, decltype(si), TypeSizeLog2,
400-
scale>(si, glob_offset, offsets.data(),
401-
mask.data());
406+
using Treal = __raw_t<T>;
407+
simd<Treal, N> Res = __esimd_gather_masked_scaled2<Treal, N, decltype(si),
408+
TypeSizeLog2, scale>(
409+
si, glob_offset, offsets.data(), mask.data());
410+
if constexpr (!std::is_same_v<Treal, T>) {
411+
return Res.template bit_cast_view<T>();
412+
} else {
413+
return Res;
414+
}
402415
}
403416
}
404417

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
//==--------- tfloat32.hpp ------- SYCL tensorfloat32 conversion ------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// Implementation of SIMD tfloat32 type.
9+
//===----------------------------------------------------------------------===//
10+
11+
#pragma once
12+
13+
#include <CL/__spirv/spirv_ops.hpp>
14+
#include <sycl/bit_cast.hpp>
15+
16+
namespace sycl {
17+
__SYCL_INLINE_VER_NAMESPACE(_V1) {
18+
namespace ext {
19+
namespace intel {
20+
namespace experimental {
21+
namespace esimd {
22+
23+
class tfloat32 {
24+
using storage_t = uint32_t;
25+
storage_t value;
26+
27+
public:
28+
tfloat32() = default;
29+
tfloat32(const tfloat32 &) = default;
30+
~tfloat32() = default;
31+
32+
// Explicit conversion functions
33+
static storage_t from_float(const float &a) {
34+
storage_t tmp_uint = sycl::bit_cast<storage_t>(a);
35+
tmp_uint &= 0xFFFFE000u;
36+
return tmp_uint;
37+
}
38+
static float to_float(const storage_t &a) {
39+
return sycl::bit_cast<float>(a & 0xFFFFE000u);
40+
}
41+
42+
// Implicit conversion from float to tfloat32
43+
tfloat32(const float &a) { value = from_float(a); }
44+
45+
tfloat32 &operator=(const float &rhs) {
46+
value = from_float(rhs);
47+
return *this;
48+
}
49+
50+
// Implicit conversion from tfloat32 to float
51+
operator float() const { return to_float(value); }
52+
53+
// Get raw bits representation of tfloat32
54+
storage_t raw() const { return value; }
55+
56+
// Logical operators (!,||,&&) are covered if we can cast to bool
57+
explicit operator bool() { return to_float(value) != 0.0f; }
58+
59+
// Unary minus operator overloading
60+
friend tfloat32 operator-(tfloat32 &lhs) { return tfloat32(-to_float(lhs)); }
61+
};
62+
63+
} // namespace esimd
64+
} // namespace experimental
65+
} // namespace intel
66+
} // namespace ext
67+
68+
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
69+
} // namespace sycl

0 commit comments

Comments
 (0)