Skip to content

Commit 25d0475

Browse files
authored
[SYCL][ESIMD] BFN function implementation (#8708)
API follows a similar one offered by CM. Example: d = esimd::bfn<~bfn_t::x & ~bfn_t::y & ~bfn_t::z>(s0, s1, s2);
1 parent 4167545 commit 25d0475

File tree

5 files changed

+134
-1
lines changed

5 files changed

+134
-1
lines changed

llvm/lib/SYCLLowerIR/ESIMD/LowerESIMD.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -671,7 +671,8 @@ class ESIMDIntrinDescTable {
671671
{"test.src.tmpl.arg", {t(0), t1(1), t8(2), t16(3), t32(4), c8(17)}}},
672672
{"slm_init", {"slm.init", {a(0)}}},
673673
{"bf_cvt", {"bf.cvt", {a(0)}}},
674-
{"tf32_cvt", {"tf32.cvt", {a(0)}}}};
674+
{"tf32_cvt", {"tf32.cvt", {a(0)}}},
675+
{"bfn", {"bfn", {a(0), a(1), a(2), t(0)}}}};
675676
}
676677

677678
const IntrinTable &getTable() { return Table; }

sycl/include/sycl/ext/intel/experimental/esimd/detail/math_intrin.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -714,6 +714,19 @@ __esimd_dpasw_nosrc0(__ESIMD_DNS::vector_type_t<T1, N1> src1,
714714
}
715715
#endif // !__SYCL_DEVICE_ONLY__
716716

717+
template <uint8_t FuncControl, typename T, int N>
718+
__ESIMD_INTRIN __ESIMD_raw_vec_t(T, N)
719+
__esimd_bfn(__ESIMD_raw_vec_t(T, N) src0, __ESIMD_raw_vec_t(T, N) src1,
720+
__ESIMD_raw_vec_t(T, N) src2)
721+
#ifdef __SYCL_DEVICE_ONLY__
722+
;
723+
#else // !__SYCL_DEVICE_ONLY__
724+
{
725+
__ESIMD_UNSUPPORTED_ON_HOST;
726+
return __ESIMD_DNS::vector_type_t<T, N>();
727+
}
728+
#endif // !__SYCL_DEVICE_ONLY__
729+
717730
#undef __ESIMD_raw_vec_t
718731
#undef __ESIMD_cpp_vec_t
719732

sycl/include/sycl/ext/intel/experimental/esimd/math.hpp

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1727,6 +1727,106 @@ __ESIMD_API __ESIMD_NS::simd<T, N> dpasw2(
17271727
}
17281728
/// @} sycl_esimd_systolic_array_api
17291729

1730+
/// @addtogroup sycl_esimd_logical
1731+
/// @{
1732+
1733+
/// This enum is used to encode all possible logical operations performed
1734+
/// on the 3 input operands. It is used as a template argument of the bfn()
1735+
/// function.
1736+
/// Example: d = bfn<~bfn_t::x & ~bfn_t::y & ~bfn_t::z>(s0, s1, s2);
1737+
enum class bfn_t : uint8_t { x = 0xAA, y = 0xCC, z = 0xF0 };
1738+
1739+
static constexpr bfn_t operator~(bfn_t x) {
1740+
uint8_t val = static_cast<uint8_t>(x);
1741+
uint8_t res = ~val;
1742+
return static_cast<bfn_t>(res);
1743+
}
1744+
1745+
static constexpr bfn_t operator|(bfn_t x, bfn_t y) {
1746+
uint8_t arg0 = static_cast<uint8_t>(x);
1747+
uint8_t arg1 = static_cast<uint8_t>(y);
1748+
uint8_t res = arg0 | arg1;
1749+
return static_cast<bfn_t>(res);
1750+
}
1751+
1752+
static constexpr bfn_t operator&(bfn_t x, bfn_t y) {
1753+
uint8_t arg0 = static_cast<uint8_t>(x);
1754+
uint8_t arg1 = static_cast<uint8_t>(y);
1755+
uint8_t res = arg0 & arg1;
1756+
return static_cast<bfn_t>(res);
1757+
}
1758+
1759+
static constexpr bfn_t operator^(bfn_t x, bfn_t y) {
1760+
uint8_t arg0 = static_cast<uint8_t>(x);
1761+
uint8_t arg1 = static_cast<uint8_t>(y);
1762+
uint8_t res = arg0 ^ arg1;
1763+
return static_cast<bfn_t>(res);
1764+
}
1765+
1766+
/// Performs a bitwise boolean function (BFN) computation on three vector operands.
1767+
/// @tparam FuncControl boolean function control expressed with bfn_t
1768+
/// enum values.
1769+
/// @tparam T type of the input vector element.
1770+
/// @tparam N size of the input vector.
1771+
/// @param src0 First boolean function argument.
1772+
/// @param src1 Second boolean function argument.
1773+
/// @param src2 Third boolean function argument.
1774+
template <bfn_t FuncControl, typename T, int N>
1775+
__ESIMD_API std::enable_if_t<std::is_integral_v<T>, __ESIMD_NS::simd<T, N>>
1776+
bfn(__ESIMD_NS::simd<T, N> src0, __ESIMD_NS::simd<T, N> src1,
1777+
__ESIMD_NS::simd<T, N> src2) {
1778+
if constexpr ((sizeof(T) == 8) || ((sizeof(T) == 1) && (N % 4 == 0)) ||
1779+
((sizeof(T) == 2) && (N % 2 == 0))) {
1780+
// Bitcast vectors of N 8-byte elements to vectors of 2*N 4-byte integers.
1781+
// Bitcast vectors of N 1-byte elements to vectors of N/4 4-byte integers.
1782+
// Bitcast vectors of N 2-byte elements to vectors of N/2 4-byte integers.
1783+
auto Result = __ESIMD_ENS::bfn<FuncControl>(
1784+
src0.template bit_cast_view<int32_t>().read(),
1785+
src1.template bit_cast_view<int32_t>().read(),
1786+
src2.template bit_cast_view<int32_t>().read());
1787+
return Result.template bit_cast_view<T>();
1788+
} else if constexpr (sizeof(T) == 2 || sizeof(T) == 4) {
1789+
constexpr uint8_t FC = static_cast<uint8_t>(FuncControl);
1790+
return __esimd_bfn<FC, T, N>(src0.data(), src1.data(), src2.data());
1791+
} else if constexpr (N % 2 == 0) {
1792+
// Bitcast vectors of N 1-byte elements (N is even) to vectors of N/2 2-byte integers.
1793+
auto Result = __ESIMD_ENS::bfn<FuncControl>(
1794+
src0.template bit_cast_view<int16_t>().read(),
1795+
src1.template bit_cast_view<int16_t>().read(),
1796+
src2.template bit_cast_view<int16_t>().read());
1797+
return Result.template bit_cast_view<T>();
1798+
} else {
1799+
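// Odd number of 1-byte elements: copy the operands into vectors of N + 1 elements, compute on those, and return the first N results.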
1800+
__ESIMD_NS::simd<T, N + 1> Src0, Src1, Src2;
1801+
Src0.template select<N, 1>() = src0;
1802+
Src1.template select<N, 1>() = src1;
1803+
Src2.template select<N, 1>() = src2;
1804+
auto Result = __ESIMD_ENS::bfn<FuncControl>(Src0, Src1, Src2);
1805+
return Result.template select<N, 1>();
1806+
}
1807+
}
1808+
1809+
/// Performs a bitwise boolean function (BFN) computation on three scalar operands.
1810+
/// @tparam FuncControl boolean function control expressed with bfn_t enum
1811+
/// values.
1812+
/// @tparam T type of the scalar operands.
1813+
/// @param src0 First boolean function argument.
1814+
/// @param src1 Second boolean function argument.
1815+
/// @param src2 Third boolean function argument.
1816+
template <bfn_t FuncControl, typename T>
1817+
ESIMD_NODEBUG ESIMD_INLINE std::enable_if_t<
1818+
__ESIMD_DNS::is_esimd_scalar<T>::value && std::is_integral_v<T>, T>
1819+
bfn(T src0, T src1, T src2) {
1820+
__ESIMD_NS::simd<T, 1> Src0 = src0;
1821+
__ESIMD_NS::simd<T, 1> Src1 = src1;
1822+
__ESIMD_NS::simd<T, 1> Src2 = src2;
1823+
__ESIMD_NS::simd<T, 1> Result =
1824+
esimd::bfn<FuncControl, T, 1>(Src0, Src1, Src2);
1825+
return Result[0];
1826+
}
1827+
1828+
/// @} sycl_esimd_logical
1829+
17301830
} // namespace ext::intel::experimental::esimd
17311831
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
17321832
} // namespace sycl

sycl/test/esimd/intrins_trans.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,4 +302,12 @@ SYCL_EXTERNAL void test_math_intrins() SYCL_ESIMD_FUNCTION {
302302
// CHECK-LABEL: %{{[a-zA-Z0-9.]+}} = call <8 x float> @llvm.genx.ieee.sqrt.v8f32(<8 x float> %{{[a-zA-Z0-9.]+}})
303303
use(y);
304304
}
305+
{
306+
vec<int, 8> x0 = get8i();
307+
vec<int, 8> x1 = get8i();
308+
vec<int, 8> x2 = get8i();
309+
auto res = __esimd_bfn<0xff, int, 8>(x0, x1, x2);
310+
// CHECK-LABEL: %{{[a-zA-Z0-9.]+}} = call <8 x i32> @llvm.genx.bfn.v8i32.v8i32(<8 x i32> %{{[a-zA-Z0-9.]+}}, <8 x i32> %{{[a-zA-Z0-9.]+}}, <8 x i32> %{{[a-zA-Z0-9.]+}}, i8 -1)
311+
use(res);
312+
}
305313
}

sycl/test/esimd/math_impl.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
using namespace sycl;
1111
using namespace sycl::ext::intel;
1212
using namespace sycl::ext::intel::esimd;
13+
using namespace sycl::ext::intel::experimental::esimd;
1314

1415
// Math sin,cos,log,exp functions are translated into scalar __spirv_ocl_ calls
1516
SYCL_ESIMD_FUNCTION SYCL_EXTERNAL simd<float, 16> sycl_math(simd<float, 16> x) {
@@ -52,3 +53,13 @@ esimd_math_emu(simd<float, 16> x) {
5253
v = esimd::exp(v);
5354
return v;
5455
}
56+
57+
// Logical BFN function from esimd namespace is translated into __esimd_ calls,
58+
// which later translate into GenX intrinsics.
59+
SYCL_ESIMD_FUNCTION SYCL_EXTERNAL simd<int, 16>
60+
esimd_bfn(simd<int, 16> x, simd<int, 16> y, simd<int, 16> z) {
61+
simd<int, 16> v =
62+
experimental::esimd::bfn<~bfn_t::x & ~bfn_t::y & ~bfn_t::z>(x, y, z);
63+
//CHECK: call spir_func noundef <16 x i32> @_Z11__esimd_bfn
64+
return v;
65+
}

0 commit comments

Comments
 (0)