Skip to content

Commit 894f40a

Browse files
authored
[SYCL][ESIMD] Fix incorrect handling of non native floating types (#15181)
1 parent 5b02c4c commit 894f40a

File tree

2 files changed

+123
-11
lines changed

2 files changed

+123
-11
lines changed

sycl/include/sycl/ext/intel/esimd/math.hpp

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,12 @@ template <typename TRes, typename TArg, int SZ>
9898
ESIMD_NODEBUG ESIMD_INLINE simd<TRes, SZ>
9999
__esimd_abs_common_internal(simd<TArg, SZ> src0) {
100100
simd<TArg, SZ> Result;
101-
if constexpr (detail::is_generic_floating_point_v<TArg>)
102-
Result = simd<TArg, SZ>(__spirv_ocl_fabs<TArg, SZ>(src0.data()));
103-
else
101+
if constexpr (detail::is_generic_floating_point_v<TArg>) {
102+
using CppT = __ESIMD_DNS::element_type_traits<TArg>::EnclosingCppT;
103+
Result =
104+
__ESIMD_DNS::convert_vector<TArg, CppT, SZ>(__spirv_ocl_fabs<CppT, SZ>(
105+
__ESIMD_DNS::convert_vector<CppT, TArg, SZ>(src0.data())));
106+
} else
104107
Result = simd<TArg, SZ>(__spirv_ocl_s_abs<TArg, SZ>(src0.data()));
105108
return convert<TRes>(Result);
106109
}
@@ -184,8 +187,12 @@ template <typename T, int SZ, class Sat = saturation_off_tag>
184187
__ESIMD_API simd<T, SZ>(max)(simd<T, SZ> src0, simd<T, SZ> src1, Sat sat = {}) {
185188
constexpr bool is_sat = std::is_same_v<Sat, saturation_on_tag>;
186189

187-
if constexpr (std::is_floating_point<T>::value) {
188-
auto Result = __spirv_ocl_fmax<T, SZ>(src0.data(), src1.data());
190+
if constexpr (detail::is_generic_floating_point_v<T>) {
191+
using CppT = __ESIMD_DNS::element_type_traits<T>::EnclosingCppT;
192+
auto Result =
193+
__ESIMD_DNS::convert_vector<T, CppT, SZ>(__spirv_ocl_fmax<CppT, SZ>(
194+
__ESIMD_DNS::convert_vector<CppT, T, SZ>(src0.data()),
195+
__ESIMD_DNS::convert_vector<CppT, T, SZ>(src1.data())));
189196
if constexpr (is_sat)
190197
Result = __esimd_sat<T, T, SZ>(Result);
191198
return simd<T, SZ>(Result);
@@ -269,8 +276,12 @@ template <typename T, int SZ, class Sat = saturation_off_tag>
269276
__ESIMD_API simd<T, SZ>(min)(simd<T, SZ> src0, simd<T, SZ> src1, Sat sat = {}) {
270277
constexpr bool is_sat = std::is_same_v<Sat, saturation_on_tag>;
271278

272-
if constexpr (std::is_floating_point<T>::value) {
273-
auto Result = __spirv_ocl_fmin<T, SZ>(src0.data(), src1.data());
279+
if constexpr (detail::is_generic_floating_point_v<T>) {
280+
using CppT = __ESIMD_DNS::element_type_traits<T>::EnclosingCppT;
281+
auto Result =
282+
__ESIMD_DNS::convert_vector<T, CppT, SZ>(__spirv_ocl_fmin<CppT, SZ>(
283+
__ESIMD_DNS::convert_vector<CppT, T, SZ>(src0.data()),
284+
__ESIMD_DNS::convert_vector<CppT, T, SZ>(src1.data())));
274285
if constexpr (is_sat)
275286
Result = __esimd_sat<T, T, SZ>(Result);
276287
return simd<T, SZ>(Result);
@@ -1465,8 +1476,12 @@ template <typename T0, typename T1, int SZ> struct esimd_apply_prod {
14651476
template <typename T0, typename T1, int SZ> struct esimd_apply_reduced_max {
14661477
template <typename... T>
14671478
simd<T0, SZ> operator()(simd<T1, SZ> v1, simd<T1, SZ> v2) {
1468-
if constexpr (std::is_floating_point<T1>::value) {
1469-
return __spirv_ocl_fmax<T1, SZ>(v1.data(), v2.data());
1479+
if constexpr (detail::is_generic_floating_point_v<T1>) {
1480+
using CppT = __ESIMD_DNS::element_type_traits<T1>::EnclosingCppT;
1481+
return __ESIMD_DNS::convert_vector<T1, CppT, SZ>(
1482+
__spirv_ocl_fmax<CppT, SZ>(
1483+
__ESIMD_DNS::convert_vector<CppT, T1, SZ>(v1.data()),
1484+
__ESIMD_DNS::convert_vector<CppT, T1, SZ>(v2.data())));
14701485
} else if constexpr (std::is_unsigned<T1>::value) {
14711486
return __esimd_umax<T1, SZ>(v1.data(), v2.data());
14721487
} else {
@@ -1478,8 +1493,13 @@ template <typename T0, typename T1, int SZ> struct esimd_apply_reduced_max {
14781493
template <typename T0, typename T1, int SZ> struct esimd_apply_reduced_min {
14791494
template <typename... T>
14801495
simd<T0, SZ> operator()(simd<T1, SZ> v1, simd<T1, SZ> v2) {
1481-
if constexpr (std::is_floating_point<T1>::value) {
1482-
return __spirv_ocl_fmin<T1, SZ>(v1.data(), v2.data());
1496+
1497+
if constexpr (detail::is_generic_floating_point_v<T1>) {
1498+
using CppT = __ESIMD_DNS::element_type_traits<T1>::EnclosingCppT;
1499+
return __ESIMD_DNS::convert_vector<T1, CppT, SZ>(
1500+
__spirv_ocl_fmin<CppT, SZ>(
1501+
__ESIMD_DNS::convert_vector<CppT, T1, SZ>(v1.data()),
1502+
__ESIMD_DNS::convert_vector<CppT, T1, SZ>(v2.data())));
14831503
} else if constexpr (std::is_unsigned<T1>::value) {
14841504
return __esimd_umin<T1, SZ>(v1.data(), v2.data());
14851505
} else {

sycl/test-e2e/ESIMD/spirv_fp_test.cpp

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
//==- spirv_fp_test.cpp - Test for abs, min and max functions -==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: arch-intel_gpu_pvc
9+
// RUN: %{build} -o %t.out
10+
// RUN: %{run} %t.out
11+
12+
#include <sycl/detail/core.hpp>
13+
#include <sycl/ext/intel/esimd.hpp>
14+
15+
#include <sycl/usm.hpp>
16+
#include <sycl/usm/usm_allocator.hpp>
17+
18+
using namespace sycl;
19+
using namespace sycl::ext::intel::esimd;
20+
using bf16 = sycl::ext::oneapi::bfloat16;
21+
using tfloat32 = sycl::ext::intel::experimental::esimd::tfloat32;
22+
23+
// USM allocator that hands out sycl::usm::alloc::shared allocations of DataT.
template <typename DataT>
24+
using shared_allocator = sycl::usm_allocator<DataT, sycl::usm::alloc::shared>;
25+
// std::vector whose storage comes from shared_allocator; in this test the
// kernel writes into these vectors and the host reads them back.
template <typename DataT>
26+
using shared_vector = std::vector<DataT, shared_allocator<DataT>>;
27+
28+
// Runs ESIMD abs/min/max on simd<T, N> vectors built from two scalar inputs
// and checks every output element against std::abs/std::min/std::max applied
// to the same scalars on the host.
//
// Template parameters:
//   T - element type under test (main() instantiates this with bf16 and
//       tfloat32).
//   N - number of elements in each simd vector and in each output buffer.
// Parameters:
//   Queue      - queue the kernel is submitted to and the USM allocator is
//                bound to.
//   testValue1 - scalar used to initialize the first input vector; also the
//                argument of abs.
//   testValue2 - scalar used to initialize the second input vector.
// Returns true when all N elements of all three outputs match the host
// results; prints a diagnostic and returns false on the first mismatch.
template <typename T, int N>
29+
bool test(sycl::queue &Queue, T testValue1, T testValue2) {
30+
shared_allocator<T> Allocator(Queue);
31+
32+
// One N-element output buffer per operation, initialized to 0.
shared_vector<T> OutputAbs(N, 0, Allocator);
33+
shared_vector<T> OutputMin(N, 0, Allocator);
34+
shared_vector<T> OutputMax(N, 0, Allocator);
35+
36+
// Raw pointers are taken so the kernel lambda can capture them by value.
auto *OutputAbsPtr = OutputAbs.data();
37+
auto *OutputMinPtr = OutputMin.data();
38+
auto *OutputMaxPtr = OutputMax.data();
39+
40+
Queue.submit([&](sycl::handler &cgh) {
41+
auto Kernel = ([=]() SYCL_ESIMD_KERNEL {
42+
// Inputs are initialized from the scalars; the host check below expects
// every element to equal the scalar result, i.e. assumes the scalar
// fills all N lanes.
simd<T, N> Input1 = testValue1;
43+
simd<T, N> Input2 = testValue2;
44+
simd<T, N> ResultAbs = __ESIMD_NS::abs(Input1);
45+
simd<T, N> ResultMin = __ESIMD_NS::min(Input1, Input2);
46+
simd<T, N> ResultMax = __ESIMD_NS::max(Input1, Input2);
47+
ResultAbs.copy_to(OutputAbsPtr);
48+
ResultMin.copy_to(OutputMinPtr);
49+
ResultMax.copy_to(OutputMaxPtr);
50+
});
51+
cgh.single_task(Kernel);
52+
});
53+
// Wait for the kernel before reading the output buffers on the host.
Queue.wait();
54+
55+
// Compare each element with the host reference; stop at the first mismatch.
for (int I = 0; I < N; I++) {
56+
if (std::abs(testValue1) != OutputAbs[I]) {
57+
std::cout << "Incorrect value for abs at index " << I << " "
58+
<< std::abs(testValue1) << " != " << OutputAbs[I] << std::endl;
59+
return false;
60+
}
61+
if (std::min(testValue1, testValue2) != OutputMin[I]) {
62+
std::cout << "Incorrect value for min at index " << I << " "
63+
<< std::min(testValue1, testValue2) << " != " << OutputMin[I]
64+
<< std::endl;
65+
return false;
66+
}
67+
68+
if (std::max(testValue1, testValue2) != OutputMax[I]) {
69+
std::cout << "Incorrect value for max at index " << I << " "
70+
<< std::max(testValue1, testValue2) << " != " << OutputMax[I]
71+
<< std::endl;
72+
return false;
73+
}
74+
}
75+
76+
return true;
77+
}
78+
79+
// Test driver: runs test<T, 8> for bf16 and tfloat32 with inputs -1 and -2 on
// a default-constructed queue, prints "Pass" or "Fail", and returns 0 only
// when both instantiations passed.
int main() {
80+
81+
bool Pass = true;
82+
sycl::queue Q;
83+
// Both inputs are negative, so abs must change the value and min/max must
// pick different operands.
Pass &= test<bf16, 8>(Q, -1, -2);
84+
Pass &= test<tfloat32, 8>(Q, -1, -2);
85+
86+
if (Pass)
87+
std::cout << "Pass" << std::endl;
88+
else
89+
std::cout << "Fail" << std::endl;
90+
91+
// Exit code 0 on success, 1 on failure.
return !Pass;
92+
}

0 commit comments

Comments
 (0)