Skip to content

Commit 728b904

Browse files
rolandschulzbader
authored andcommitted
[SYCL] Add multi_ptr deduction guide
Signed-off-by: Roland Schulz <[email protected]>
1 parent 0c32410 commit 728b904

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

sycl/include/CL/sycl/multi_ptr.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,23 @@ class multi_ptr<const void, Space> {
503503
pointer_t m_Pointer;
504504
};
505505

506+
#ifdef __cpp_deduction_guides
507+
template <int dimensions, access::mode Mode, access::placeholder isPlaceholder,
508+
class T>
509+
multi_ptr(
510+
accessor<T, dimensions, Mode, access::target::global_buffer, isPlaceholder>)
511+
->multi_ptr<T, access::address_space::global_space>;
512+
template <int dimensions, access::mode Mode, access::placeholder isPlaceholder,
513+
class T>
514+
multi_ptr(accessor<T, dimensions, Mode, access::target::constant_buffer,
515+
isPlaceholder>)
516+
->multi_ptr<T, access::address_space::constant_space>;
517+
template <int dimensions, access::mode Mode, access::placeholder isPlaceholder,
518+
class T>
519+
multi_ptr(accessor<T, dimensions, Mode, access::target::local, isPlaceholder>)
520+
->multi_ptr<T, access::address_space::local_space>;
521+
#endif
522+
506523
template <typename ElementType, access::address_space Space>
507524
multi_ptr<ElementType, Space>
508525
make_ptr(typename multi_ptr<ElementType, Space>::pointer_t pointer) {

sycl/test/multi_ptr/ctad.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: %clangxx -std=c++17 -fsyntax-only -Xclang -verify %s
2+
// expected-no-diagnostics
3+
//==--------------- ctad.cpp - SYCL multi_ptr CTAD test --------------------==//
4+
//
5+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6+
// See https://llvm.org/LICENSE.txt for license information.
7+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8+
//
9+
//===----------------------------------------------------------------------===//
10+
#include <CL/sycl.hpp>
11+
12+
namespace sycl = cl::sycl;
13+
14+
int main() {
15+
using sycl::access::address_space;
16+
using sycl::access::mode;
17+
using sycl::access::target;
18+
using globlAcc = sycl::accessor<int, 1, mode::read, target::global_buffer>;
19+
using constAcc = sycl::accessor<int, 1, mode::read, target::constant_buffer>;
20+
using localAcc = sycl::accessor<int, 1, mode::read, target::local>;
21+
using globlCTAD = decltype(sycl::multi_ptr(std::declval<globlAcc>()));
22+
using constCTAD = decltype(sycl::multi_ptr(std::declval<constAcc>()));
23+
using localCTAD = decltype(sycl::multi_ptr(std::declval<localAcc>()));
24+
using globlMPtr = sycl::multi_ptr<int, address_space::global_space>;
25+
using constMPtr = sycl::multi_ptr<int, address_space::constant_space>;
26+
using localMPtr = sycl::multi_ptr<int, address_space::local_space>;
27+
static_assert(std::is_same_v<globlCTAD, globlMPtr>);
28+
static_assert(std::is_same_v<constCTAD, constMPtr>);
29+
static_assert(std::is_same_v<localCTAD, localMPtr>);
30+
}

0 commit comments

Comments
 (0)