Skip to content

Commit e286166

Browse files
[SYCL] Make builtins accept half pointers (#6596)
Some SYCL math builtins with pointer arguments, such as modf and sincos, do not currently accept pointers to halfs due to the conversion to OpenCL types not propagating through pointers. This commit fixes this by making a special case for pointers, applying the type conversion to the underlying types. Signed-off-by: Larsen, Steffen <[email protected]>
1 parent da7dcf8 commit e286166

File tree

3 files changed

+69
-33
lines changed

3 files changed

+69
-33
lines changed

sycl/include/sycl/access/access.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,8 +205,11 @@ template <class T> struct remove_AS {
205205

206206
#ifdef __SYCL_DEVICE_ONLY__
207207
template <class T> struct deduce_AS {
208-
static_assert(!std::is_same<typename detail::remove_AS<T>::type, T>::value,
209-
"Only types with address space attributes are supported");
208+
// Undecorated pointers are considered generic.
209+
// TODO: This assumes that the implementation uses generic as default. If
210+
// address space inference is used this may need to change.
211+
static const access::address_space value =
212+
access::address_space::generic_space;
210213
};
211214

212215
template <class T> struct remove_AS<__OPENCL_GLOBAL_AS__ T> {

sycl/include/sycl/detail/generic_type_traits.hpp

Lines changed: 44 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -335,25 +335,23 @@ template <typename T> class TryToGetPointerVecT {
335335
using type = decltype(check(T()));
336336
};
337337

338-
template <typename T, typename = typename detail::enable_if_t<
339-
TryToGetPointerT<T>::value, std::true_type>>
340-
typename TryToGetPointerVecT<T>::type TryToGetPointer(T &t) {
338+
template <
339+
typename To, typename From,
340+
typename = typename detail::enable_if_t<TryToGetPointerT<From>::value>>
341+
To ConvertNonVectorType(From &t) {
341342
// TODO find the better way to get the pointer to underlying data from vec
342343
// class
343-
return reinterpret_cast<typename TryToGetPointerVecT<T>::type>(t.get());
344+
return reinterpret_cast<To>(t.get());
344345
}
345346

346-
template <typename T>
347-
typename TryToGetPointerVecT<T *>::type TryToGetPointer(T *t) {
348-
// TODO find the better way to get the pointer to underlying data from vec
349-
// class
350-
return reinterpret_cast<typename TryToGetPointerVecT<T *>::type>(t);
347+
template <typename To, typename From> To ConvertNonVectorType(From *t) {
348+
return reinterpret_cast<To>(t);
351349
}
352350

353-
template <typename T, typename = typename detail::enable_if_t<
354-
!TryToGetPointerT<T>::value, std::false_type>>
355-
T TryToGetPointer(T &t) {
356-
return t;
351+
template <typename To, typename From>
352+
typename detail::enable_if_t<!TryToGetPointerT<From>::value, To>
353+
ConvertNonVectorType(From &t) {
354+
return static_cast<To>(t);
357355
}
358356

359357
// select_apply_cl_scalar_t selects from T8/T16/T32/T64 basing on
@@ -398,13 +396,14 @@ using select_cl_scalar_t = conditional_t<
398396
conditional_t<std::is_same<T, half>::value,
399397
sycl::detail::half_impl::BIsRepresentationT, T>>>;
400398

401-
// select_cl_vector_or_scalar does cl_* type selection for element type of
402-
// a vector type T and does scalar type substitution. If T is not
403-
// vector or scalar unmodified T is returned.
404-
template <typename T, typename Enable = void> struct select_cl_vector_or_scalar;
399+
// select_cl_vector_or_scalar_or_ptr does cl_* type selection for element type
400+
// of a vector type T, pointer type substitution, and scalar type substitution.
401+
// If T is not vector, scalar, or pointer unmodified T is returned.
402+
template <typename T, typename Enable = void>
403+
struct select_cl_vector_or_scalar_or_ptr;
405404

406405
template <typename T>
407-
struct select_cl_vector_or_scalar<
406+
struct select_cl_vector_or_scalar_or_ptr<
408407
T, typename detail::enable_if_t<is_vgentype<T>::value>> {
409408
using type =
410409
// select_cl_scalar_t returns _Float16, so, we try to instantiate vec
@@ -417,17 +416,31 @@ struct select_cl_vector_or_scalar<
417416
};
418417

419418
template <typename T>
420-
struct select_cl_vector_or_scalar<
421-
T, typename detail::enable_if_t<!is_vgentype<T>::value>> {
419+
struct select_cl_vector_or_scalar_or_ptr<
420+
T, typename detail::enable_if_t<!is_vgentype<T>::value &&
421+
!std::is_pointer<T>::value>> {
422422
using type = select_cl_scalar_t<T>;
423423
};
424424

425-
// select_cl_mptr_or_vector_or_scalar does cl_* type selection for type
426-
// pointed by multi_ptr or for element type of a vector type T and does
427-
// scalar type substitution. If T is not mutlti_ptr or vector or scalar
428-
// unmodified T is returned.
425+
template <typename T>
426+
struct select_cl_vector_or_scalar_or_ptr<
427+
T, typename detail::enable_if_t<!is_vgentype<T>::value &&
428+
std::is_pointer<T>::value>> {
429+
using elem_ptr_type = typename select_cl_vector_or_scalar_or_ptr<
430+
std::remove_pointer_t<T>>::type *;
431+
#ifdef __SYCL_DEVICE_ONLY__
432+
using type = typename DecoratedType<elem_ptr_type, deduce_AS<T>::value>::type;
433+
#else
434+
using type = elem_ptr_type;
435+
#endif
436+
};
437+
438+
// select_cl_mptr_or_vector_or_scalar_or_ptr does cl_* type selection for type
439+
// pointed by multi_ptr, for raw pointers, for element type of a vector type T,
440+
// and does scalar type substitution. If T is not mutlti_ptr or vector or
441+
// scalar or pointer unmodified T is returned.
429442
template <typename T, typename Enable = void>
430-
struct select_cl_mptr_or_vector_or_scalar;
443+
struct select_cl_mptr_or_vector_or_scalar_or_ptr;
431444

432445
// this struct helps to use std::uint8_t instead of std::byte,
433446
// which is not supported on device
@@ -444,25 +457,25 @@ template <> struct TypeHelper<std::byte> {
444457
template <typename T> using type_helper = typename TypeHelper<T>::RetType;
445458

446459
template <typename T>
447-
struct select_cl_mptr_or_vector_or_scalar<
460+
struct select_cl_mptr_or_vector_or_scalar_or_ptr<
448461
T, typename detail::enable_if_t<is_genptr<T>::value &&
449462
!std::is_pointer<T>::value>> {
450-
using type = multi_ptr<typename select_cl_vector_or_scalar<
463+
using type = multi_ptr<typename select_cl_vector_or_scalar_or_ptr<
451464
type_helper<typename T::element_type>>::type,
452465
T::address_space>;
453466
};
454467

455468
template <typename T>
456-
struct select_cl_mptr_or_vector_or_scalar<
469+
struct select_cl_mptr_or_vector_or_scalar_or_ptr<
457470
T, typename detail::enable_if_t<!is_genptr<T>::value ||
458471
std::is_pointer<T>::value>> {
459-
using type = typename select_cl_vector_or_scalar<T>::type;
472+
using type = typename select_cl_vector_or_scalar_or_ptr<T>::type;
460473
};
461474

462475
// All types converting shortcut.
463476
template <typename T>
464477
using SelectMatchingOpenCLType_t =
465-
typename select_cl_mptr_or_vector_or_scalar<T>::type;
478+
typename select_cl_mptr_or_vector_or_scalar_or_ptr<T>::type;
466479

467480
// Converts T to OpenCL friendly
468481
//
@@ -492,7 +505,7 @@ typename detail::enable_if_t<!(is_vgentype<FROM>::value &&
492505
sizeof(TO) == sizeof(FROM),
493506
TO>
494507
convertDataToType(FROM t) {
495-
return TryToGetPointer(t);
508+
return ConvertNonVectorType<TO>(t);
496509
}
497510

498511
// Used for all, any and select relational built-in functions
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//==--------------- half_ptr_builtins.cpp ----------------------------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
// Verifies that builtins with pointer arguments accept pointers to sycl::half.
10+
// RUN: %clangxx -fsycl -fsyntax-only -Xclang -verify %s -Xclang -verify-ignore-unexpected=note,warning -Wno-sycl-strict
11+
// expected-no-diagnostics
12+
13+
#include <sycl/sycl.hpp>
14+
15+
int main() {
16+
sycl::half x;
17+
sycl::modf(sycl::half{1.0}, &x);
18+
sycl::sincos(sycl::half{1.0}, &x);
19+
return 0;
20+
}

0 commit comments

Comments
 (0)