Skip to content

[SYCL][ESIMD] BFN function implementation #8708

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Mar 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion llvm/lib/SYCLLowerIR/ESIMD/LowerESIMD.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,8 @@ class ESIMDIntrinDescTable {
{"test.src.tmpl.arg", {t(0), t1(1), t8(2), t16(3), t32(4), c8(17)}}},
{"slm_init", {"slm.init", {a(0)}}},
{"bf_cvt", {"bf.cvt", {a(0)}}},
{"tf32_cvt", {"tf32.cvt", {a(0)}}}};
{"tf32_cvt", {"tf32.cvt", {a(0)}}},
{"bfn", {"bfn", {a(0), a(1), a(2), t(0)}}}};
}

// Read-only accessor: returns a const reference to Table.
const IntrinTable &getTable() { return Table; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,19 @@ __esimd_dpasw_nosrc0(__ESIMD_DNS::vector_type_t<T1, N1> src1,
}
#endif // !__SYCL_DEVICE_ONLY__

// Low-level entry point for the three-operand boolean function (BFN).
// FuncControl is the 8-bit function control value; src0, src1 and src2 are
// the raw vector operands, and the result is a raw vector of the same element
// type T and size N.
// When __SYCL_DEVICE_ONLY__ is defined this is a declaration only (no body is
// provided here). Otherwise the host body below reports the call as
// unsupported via __ESIMD_UNSUPPORTED_ON_HOST and returns a value-initialized
// vector.
template <uint8_t FuncControl, typename T, int N>
__ESIMD_INTRIN __ESIMD_raw_vec_t(T, N)
__esimd_bfn(__ESIMD_raw_vec_t(T, N) src0, __ESIMD_raw_vec_t(T, N) src1,
__ESIMD_raw_vec_t(T, N) src2)
#ifdef __SYCL_DEVICE_ONLY__
;
#else // !__SYCL_DEVICE_ONLY__
{
__ESIMD_UNSUPPORTED_ON_HOST;
// The value returned on the host is a placeholder, not a computed result.
return __ESIMD_DNS::vector_type_t<T, N>();
}
#endif // !__SYCL_DEVICE_ONLY__

#undef __ESIMD_raw_vec_t
#undef __ESIMD_cpp_vec_t

Expand Down
100 changes: 100 additions & 0 deletions sycl/include/sycl/ext/intel/experimental/esimd/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1727,6 +1727,106 @@ __ESIMD_API __ESIMD_NS::simd<T, N> dpasw2(
}
/// @} sycl_esimd_systolic_array_api

/// @addtogroup sycl_esimd_logical
/// @{

/// This enum is used to encode all possible logical operations performed
/// on the 3 input operands. It is used as a template argument of the bfn()
/// function.
///
/// Each enumerator is the 8-entry truth-table column of one operand: bit i of
/// x is set when (i & 1) != 0, bit i of y when (i & 2) != 0 and bit i of z
/// when (i & 4) != 0. Combining the enumerators with the operators below
/// therefore yields the 8-bit truth table of the combined expression.
///
/// Example: d = bfn<~bfn_t::x & ~bfn_t::y & ~bfn_t::z>(s0, s1, s2);
enum class bfn_t : uint8_t { x = 0xAA, y = 0xCC, z = 0xF0 };

// NOTE: the operators below are intentionally not 'static'. A constexpr
// function is implicitly inline, so a single external-linkage definition is
// shared by all translation units including this header; 'static' would give
// every translation unit its own internal-linkage copy.

/// Returns the complement of the truth table \p x (logical NOT).
constexpr bfn_t operator~(bfn_t x) {
  // ~ promotes to int; cast back explicitly to keep only the low 8 bits.
  return static_cast<bfn_t>(static_cast<uint8_t>(~static_cast<uint8_t>(x)));
}

/// Returns the truth table of \p x OR \p y.
constexpr bfn_t operator|(bfn_t x, bfn_t y) {
  return static_cast<bfn_t>(
      static_cast<uint8_t>(static_cast<uint8_t>(x) | static_cast<uint8_t>(y)));
}

/// Returns the truth table of \p x AND \p y.
constexpr bfn_t operator&(bfn_t x, bfn_t y) {
  return static_cast<bfn_t>(
      static_cast<uint8_t>(static_cast<uint8_t>(x) & static_cast<uint8_t>(y)));
}

/// Returns the truth table of \p x XOR \p y.
constexpr bfn_t operator^(bfn_t x, bfn_t y) {
  return static_cast<bfn_t>(
      static_cast<uint8_t>(static_cast<uint8_t>(x) ^ static_cast<uint8_t>(y)));
}

/// Performs boolean function computation with three vector operands.
/// @tparam FuncControl boolean function control expressed with bfn_t
/// enum values.
/// @tparam T integral type of the input vector element.
/// @tparam N size of the input vector.
/// @param src0 First boolean function argument.
/// @param src1 Second boolean function argument.
/// @param src2 Third boolean function argument.
/// @return Vector with the same element type and size as the inputs.
template <bfn_t FuncControl, typename T, int N>
__ESIMD_API std::enable_if_t<std::is_integral_v<T>, __ESIMD_NS::simd<T, N>>
bfn(__ESIMD_NS::simd<T, N> src0, __ESIMD_NS::simd<T, N> src1,
    __ESIMD_NS::simd<T, N> src2) {
  // Inputs whose total size is a whole number of 4-byte words are processed
  // as vectors of 4-byte integers:
  //   N x 8-byte elements           -> 2*N 4-byte integers,
  //   N x 1-byte elements, N % 4==0 -> N/4 4-byte integers,
  //   N x 2-byte elements, N % 2==0 -> N/2 4-byte integers.
  constexpr bool RepackAsDwords = (sizeof(T) == 8) ||
                                  ((sizeof(T) == 1) && (N % 4 == 0)) ||
                                  ((sizeof(T) == 2) && (N % 2 == 0));
  // Remaining 2-byte and all 4-byte inputs go to the intrinsic unchanged.
  constexpr bool CallIntrinsicDirectly = sizeof(T) == 2 || sizeof(T) == 4;

  if constexpr (RepackAsDwords) {
    auto Dwords0 = src0.template bit_cast_view<int32_t>().read();
    auto Dwords1 = src1.template bit_cast_view<int32_t>().read();
    auto Dwords2 = src2.template bit_cast_view<int32_t>().read();
    auto Packed = __ESIMD_ENS::bfn<FuncControl>(Dwords0, Dwords1, Dwords2);
    return Packed.template bit_cast_view<T>();
  } else if constexpr (CallIntrinsicDirectly) {
    return __esimd_bfn<static_cast<uint8_t>(FuncControl), T, N>(
        src0.data(), src1.data(), src2.data());
  } else if constexpr (N % 2 == 0) {
    // Even count of 1-byte elements: view them as N/2 2-byte integers.
    auto Words0 = src0.template bit_cast_view<int16_t>().read();
    auto Words1 = src1.template bit_cast_view<int16_t>().read();
    auto Words2 = src2.template bit_cast_view<int16_t>().read();
    auto Packed = __ESIMD_ENS::bfn<FuncControl>(Words0, Words1, Words2);
    return Packed.template bit_cast_view<T>();
  } else {
    // Odd count of 1-byte elements: copy the inputs into vectors one element
    // longer, compute on those, and keep only the first N result elements.
    __ESIMD_NS::simd<T, N + 1> Padded0, Padded1, Padded2;
    Padded0.template select<N, 1>() = src0;
    Padded1.template select<N, 1>() = src1;
    Padded2.template select<N, 1>() = src2;
    auto PaddedRes = __ESIMD_ENS::bfn<FuncControl>(Padded0, Padded1, Padded2);
    return PaddedRes.template select<N, 1>();
  }
}

/// Performs boolean function computation with three scalar operands.
/// @tparam FuncControl boolean function control expressed with bfn_t enum
/// values.
/// @tparam T integral scalar type of the operands and of the result.
/// @param src0 First boolean function argument.
/// @param src1 Second boolean function argument.
/// @param src2 Third boolean function argument.
/// @return Scalar result of applying the boolean function to the operands.
template <bfn_t FuncControl, typename T>
ESIMD_NODEBUG ESIMD_INLINE std::enable_if_t<
    __ESIMD_DNS::is_esimd_scalar<T>::value && std::is_integral_v<T>, T>
bfn(T src0, T src1, T src2) {
  // Wrap every scalar into a one-element vector, delegate to the vector
  // overload, then unwrap the single result element.
  __ESIMD_NS::simd<T, 1> Vec0 = src0;
  __ESIMD_NS::simd<T, 1> Vec1 = src1;
  __ESIMD_NS::simd<T, 1> Vec2 = src2;
  __ESIMD_NS::simd<T, 1> Out = esimd::bfn<FuncControl, T, 1>(Vec0, Vec1, Vec2);
  return Out[0];
}

/// @} sycl_esimd_logical

} // namespace ext::intel::experimental::esimd
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
} // namespace sycl
8 changes: 8 additions & 0 deletions sycl/test/esimd/intrins_trans.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -302,4 +302,12 @@ SYCL_EXTERNAL void test_math_intrins() SYCL_ESIMD_FUNCTION {
// CHECK-LABEL: %{{[a-zA-Z0-9.]+}} = call <8 x float> @llvm.genx.ieee.sqrt.v8f32(<8 x float> %{{[a-zA-Z0-9.]+}})
use(y);
}
{
vec<int, 8> x0 = get8i();
vec<int, 8> x1 = get8i();
vec<int, 8> x2 = get8i();
auto res = __esimd_bfn<0xff, int, 8>(x0, x1, x2);
// CHECK-LABEL: %{{[a-zA-Z0-9.]+}} = call <8 x i32> @llvm.genx.bfn.v8i32.v8i32(<8 x i32> %{{[a-zA-Z0-9.]+}}, <8 x i32> %{{[a-zA-Z0-9.]+}}, <8 x i32> %{{[a-zA-Z0-9.]+}}, i8 -1)
use(res);
}
}
11 changes: 11 additions & 0 deletions sycl/test/esimd/math_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using namespace sycl;
using namespace sycl::ext::intel;
using namespace sycl::ext::intel::esimd;
using namespace sycl::ext::intel::experimental::esimd;

// Math sin,cos,log,exp functions are translated into scalar __spirv_ocl_ calls
SYCL_ESIMD_FUNCTION SYCL_EXTERNAL simd<float, 16> sycl_math(simd<float, 16> x) {
Expand Down Expand Up @@ -52,3 +53,13 @@ esimd_math_emu(simd<float, 16> x) {
v = esimd::exp(v);
return v;
}

// Logical BFN function from the experimental::esimd namespace is translated
// into __esimd_ calls, which are later translated into GenX intrinsics.
SYCL_ESIMD_FUNCTION SYCL_EXTERNAL simd<int, 16>
esimd_bfn(simd<int, 16> x, simd<int, 16> y, simd<int, 16> z) {
simd<int, 16> v =
experimental::esimd::bfn<~bfn_t::x & ~bfn_t::y & ~bfn_t::z>(x, y, z);
//CHECK: call spir_func noundef <16 x i32> @_Z11__esimd_bfn
return v;
}