@@ -386,19 +386,7 @@ class __SYCL_EXPORT handler {
386
386
static_cast <int >(AccessTarget), ArgIndex);
387
387
}
388
388
389
- template <typename T> struct ShouldEnableSetArgHelper {
390
- static constexpr bool value = std::is_trivially_copyable<T>::value
391
- #ifdef CL_SYCL_LANGUAGE_VERSION
392
- #if CL_SYCL_LANGUAGE_VERSION <= 121
393
- && std::is_standard_layout<T>::value
394
- #endif
395
- #endif
396
- ;
397
- };
398
-
399
- template <typename T>
400
- typename std::enable_if<ShouldEnableSetArgHelper<T>::value, void >::type
401
- setArgHelper (int ArgIndex, T &&Arg) {
389
+ template <typename T> void setArgHelper (int ArgIndex, T &&Arg) {
402
390
void *StoredArg = (void *)storePlainArg (Arg);
403
391
404
392
if (!std::is_same<cl_mem, T>::value && std::is_pointer<T>::value) {
@@ -808,13 +796,39 @@ class __SYCL_EXPORT handler {
808
796
}
809
797
}
810
798
799
+ template <typename T>
800
+ using remove_cv_ref_t =
801
+ typename std::remove_cv<detail::remove_reference_t <T>>::type;
802
+
803
+ template <typename T> struct ShouldEnableSetArg {
804
+ static constexpr bool value =
805
+ std::is_trivially_copyable<T>::value
806
+ #if CL_SYCL_LANGUAGE_VERSION && CL_SYCL_LANGUAGE_VERSION <= 121
807
+ && std::is_standard_layout<T>::value
808
+ #endif
809
+ || std::is_same<sampler, remove_cv_ref_t <T>>::value // Sampler
810
+ || (!std::is_same<cl_mem, remove_cv_ref_t <T>>::value &&
811
+ std::is_pointer<remove_cv_ref_t <T>>::value) // USM
812
+ || std::is_same<cl_mem, remove_cv_ref_t <T>>::value; // Interop
813
+ };
814
+
811
815
// / Sets argument for OpenCL interoperability kernels.
812
816
// /
813
817
// / Registers Arg passed as argument # ArgIndex.
814
818
// /
815
819
// / \param ArgIndex is a positional number of argument to be set.
816
820
// / \param Arg is an argument value to be set.
817
- template <typename T> void set_arg (int ArgIndex, T &&Arg) {
821
+ template <typename T>
822
+ typename std::enable_if<ShouldEnableSetArg<T>::value, void >::type
823
+ set_arg (int ArgIndex, T &&Arg) {
824
+ setArgHelper (ArgIndex, std::move (Arg));
825
+ }
826
+
827
+ template <typename DataT, int Dims, access::mode AccessMode,
828
+ access::target AccessTarget, access::placeholder IsPlaceholder>
829
+ void
830
+ set_arg (int ArgIndex,
831
+ accessor<DataT, Dims, AccessMode, AccessTarget, IsPlaceholder> Arg) {
818
832
setArgHelper (ArgIndex, std::move (Arg));
819
833
}
820
834
0 commit comments