@@ -335,25 +335,23 @@ template <typename T> class TryToGetPointerVecT {
335
335
using type = decltype (check(T()));
336
336
};
337
337
338
- template <typename T, typename = typename detail::enable_if_t <
339
- TryToGetPointerT<T>::value, std::true_type>>
340
- typename TryToGetPointerVecT<T>::type TryToGetPointer (T &t) {
338
+ template <
339
+ typename To, typename From,
340
+ typename = typename detail::enable_if_t <TryToGetPointerT<From>::value>>
341
+ To ConvertNonVectorType (From &t) {
341
342
// TODO find the better way to get the pointer to underlying data from vec
342
343
// class
343
- return reinterpret_cast <typename TryToGetPointerVecT<T>::type >(t.get ());
344
+ return reinterpret_cast <To >(t.get ());
344
345
}
345
346
346
- template <typename T>
347
- typename TryToGetPointerVecT<T *>::type TryToGetPointer (T *t) {
348
- // TODO find the better way to get the pointer to underlying data from vec
349
- // class
350
- return reinterpret_cast <typename TryToGetPointerVecT<T *>::type>(t);
347
+ template <typename To, typename From> To ConvertNonVectorType (From *t) {
348
+ return reinterpret_cast <To>(t);
351
349
}
352
350
353
- template <typename T , typename = typename detail:: enable_if_t <
354
- !TryToGetPointerT<T >::value, std::false_type> >
355
- T TryToGetPointer (T &t) {
356
- return t ;
351
+ template <typename To , typename From>
352
+ typename detail:: enable_if_t < !TryToGetPointerT<From >::value, To >
353
+ ConvertNonVectorType (From &t) {
354
+ return static_cast <To>(t) ;
357
355
}
358
356
359
357
// select_apply_cl_scalar_t selects from T8/T16/T32/T64 basing on
@@ -398,13 +396,14 @@ using select_cl_scalar_t = conditional_t<
398
396
conditional_t <std::is_same<T, half>::value,
399
397
sycl::detail::half_impl::BIsRepresentationT, T>>>;
400
398
401
- // select_cl_vector_or_scalar does cl_* type selection for element type of
402
- // a vector type T and does scalar type substitution. If T is not
403
- // vector or scalar unmodified T is returned.
404
- template <typename T, typename Enable = void > struct select_cl_vector_or_scalar ;
399
+ // select_cl_vector_or_scalar_or_ptr does cl_* type selection for element type
400
+ // of a vector type T, pointer type substitution, and scalar type substitution.
401
+ // If T is not vector, scalar, or pointer unmodified T is returned.
402
+ template <typename T, typename Enable = void >
403
+ struct select_cl_vector_or_scalar_or_ptr ;
405
404
406
405
template <typename T>
407
- struct select_cl_vector_or_scalar <
406
+ struct select_cl_vector_or_scalar_or_ptr <
408
407
T, typename detail::enable_if_t <is_vgentype<T>::value>> {
409
408
using type =
410
409
// select_cl_scalar_t returns _Float16, so, we try to instantiate vec
@@ -417,17 +416,31 @@ struct select_cl_vector_or_scalar<
417
416
};
418
417
419
418
template <typename T>
420
- struct select_cl_vector_or_scalar <
421
- T, typename detail::enable_if_t <!is_vgentype<T>::value>> {
419
+ struct select_cl_vector_or_scalar_or_ptr <
420
+ T, typename detail::enable_if_t <!is_vgentype<T>::value &&
421
+ !std::is_pointer<T>::value>> {
422
422
using type = select_cl_scalar_t <T>;
423
423
};
424
424
425
- // select_cl_mptr_or_vector_or_scalar does cl_* type selection for type
426
- // pointed by multi_ptr or for element type of a vector type T and does
427
- // scalar type substitution. If T is not mutlti_ptr or vector or scalar
428
- // unmodified T is returned.
425
+ template <typename T>
426
+ struct select_cl_vector_or_scalar_or_ptr <
427
+ T, typename detail::enable_if_t <!is_vgentype<T>::value &&
428
+ std::is_pointer<T>::value>> {
429
+ using elem_ptr_type = typename select_cl_vector_or_scalar_or_ptr<
430
+ std::remove_pointer_t <T>>::type *;
431
+ #ifdef __SYCL_DEVICE_ONLY__
432
+ using type = typename DecoratedType<elem_ptr_type, deduce_AS<T>::value>::type;
433
+ #else
434
+ using type = elem_ptr_type;
435
+ #endif
436
+ };
437
+
438
+ // select_cl_mptr_or_vector_or_scalar_or_ptr does cl_* type selection for type
439
+ // pointed by multi_ptr, for raw pointers, for element type of a vector type T,
440
+ // and does scalar type substitution. If T is not mutlti_ptr or vector or
441
+ // scalar or pointer unmodified T is returned.
429
442
template <typename T, typename Enable = void >
430
- struct select_cl_mptr_or_vector_or_scalar ;
443
+ struct select_cl_mptr_or_vector_or_scalar_or_ptr ;
431
444
432
445
// this struct helps to use std::uint8_t instead of std::byte,
433
446
// which is not supported on device
@@ -444,25 +457,25 @@ template <> struct TypeHelper<std::byte> {
444
457
template <typename T> using type_helper = typename TypeHelper<T>::RetType;
445
458
446
459
template <typename T>
447
- struct select_cl_mptr_or_vector_or_scalar <
460
+ struct select_cl_mptr_or_vector_or_scalar_or_ptr <
448
461
T, typename detail::enable_if_t <is_genptr<T>::value &&
449
462
!std::is_pointer<T>::value>> {
450
- using type = multi_ptr<typename select_cl_vector_or_scalar <
463
+ using type = multi_ptr<typename select_cl_vector_or_scalar_or_ptr <
451
464
type_helper<typename T::element_type>>::type,
452
465
T::address_space>;
453
466
};
454
467
455
468
template <typename T>
456
- struct select_cl_mptr_or_vector_or_scalar <
469
+ struct select_cl_mptr_or_vector_or_scalar_or_ptr <
457
470
T, typename detail::enable_if_t <!is_genptr<T>::value ||
458
471
std::is_pointer<T>::value>> {
459
- using type = typename select_cl_vector_or_scalar <T>::type;
472
+ using type = typename select_cl_vector_or_scalar_or_ptr <T>::type;
460
473
};
461
474
462
475
// All types converting shortcut.
463
476
template <typename T>
464
477
using SelectMatchingOpenCLType_t =
465
- typename select_cl_mptr_or_vector_or_scalar <T>::type;
478
+ typename select_cl_mptr_or_vector_or_scalar_or_ptr <T>::type;
466
479
467
480
// Converts T to OpenCL friendly
468
481
//
@@ -492,7 +505,7 @@ typename detail::enable_if_t<!(is_vgentype<FROM>::value &&
492
505
sizeof (TO) == sizeof (FROM),
493
506
TO>
494
507
convertDataToType (FROM t) {
495
- return TryToGetPointer (t);
508
+ return ConvertNonVectorType<TO> (t);
496
509
}
497
510
498
511
// Used for all, any and select relational built-in functions
0 commit comments