Skip to content

Commit 4993646

Browse files
authored
[SYCL][USM] Initial commit of USM fill operation (#2305)
Implements the fill operation for USM as defined in SYCL 2020 provisional. Signed-off-by: James Brodman <[email protected]>
1 parent 3084982 commit 4993646

File tree

3 files changed

+68
-0
lines changed

3 files changed

+68
-0
lines changed

sycl/include/CL/sycl/handler.hpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ template <typename DataT, int Dimensions, cl::sycl::access::mode AccessMode,
3838
cl::sycl::access::placeholder IsPlaceholder>
3939
class __fill;
4040

41+
template <typename T> class __usmfill;
42+
4143
template <typename T_Src, typename T_Dst, int Dims,
4244
cl::sycl::access::mode AccessMode,
4345
cl::sycl::access::target AccessTarget,
@@ -1713,6 +1715,22 @@ class __SYCL_EXPORT handler {
17131715
}
17141716
}
17151717

1718+
/// Fills the specified memory with the specified pattern.
1719+
///
1720+
/// \param Ptr is the pointer to the memory to fill
1721+
/// \param Pattern is the pattern to fill into the memory. T should be
1722+
/// trivially copyable.
1723+
/// \param Count is the number of times to fill Pattern into Ptr.
1724+
template <typename T> void fill(void *Ptr, const T &Pattern, size_t Count) {
1725+
throwIfActionIsCreated();
1726+
static_assert(std::is_trivially_copyable<T>::value,
1727+
"Pattern must be trivially copyable");
1728+
parallel_for<class __usmfill<T>>(range<1>(Count), [=](id<1> Index) {
1729+
T *CastedPtr = static_cast<T *>(Ptr);
1730+
CastedPtr[Index] = Pattern;
1731+
});
1732+
}
1733+
17161734
/// Prevents any commands submitted afterward to this queue from executing
17171735
/// until all commands previously submitted to this queue have entered the
17181736
/// complete state.

sycl/include/CL/sycl/queue.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,16 @@ class __SYCL_EXPORT queue {
319319
/// property, an invalid_object_error SYCL exception.
320320
template <typename PropertyT> PropertyT get_property() const;
321321

322+
/// Fills the specified memory with the specified pattern.
323+
///
324+
/// \param Ptr is the pointer to the memory to fill
325+
/// \param Pattern is the pattern to fill into the memory. T should be
326+
/// trivially copyable.
327+
/// \param Count is the number of times to fill Pattern into Ptr.
328+
template <typename T> event fill(void *Ptr, const T &Pattern, size_t Count) {
329+
return submit([&](handler &CGH) { CGH.fill<T>(Ptr, Pattern, Count); });
330+
}
331+
322332
/// Fills the memory pointed by a USM pointer with the value specified.
323333
/// No operations is done if \param Count is zero. An exception is thrown
324334
/// if \param Dest is nullptr. The behavior is undefined if \param Ptr

sycl/test/usm/fill.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
//==---- fill.cpp - USM fill test ------------------------------------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// XFAIL: cuda
9+
// piextUSM*Alloc functions for CUDA are not behaving as described in
10+
//
11+
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t1.out
12+
// RUN: env SYCL_DEVICE_TYPE=HOST %t1.out
13+
// RUN: %CPU_RUN_PLACEHOLDER %t1.out
14+
// RUN: %GPU_RUN_PLACEHOLDER %t1.out
15+
16+
#include <CL/sycl.hpp>
17+
18+
using namespace cl::sycl;
19+
20+
constexpr int count = 100;
21+
constexpr int pattern = 42;
22+
23+
int main() {
24+
queue q;
25+
if (q.get_device().get_info<info::device::usm_shared_allocations>()) {
26+
int *mem = malloc_shared<int>(count, q);
27+
28+
for (int i = 0; i < count; i++)
29+
mem[i] = 0;
30+
31+
q.fill(mem, pattern, count);
32+
q.wait();
33+
34+
for (int i = 0; i < count; i++) {
35+
assert(mem[i] == pattern);
36+
}
37+
}
38+
std::cout << "Passed\n";
39+
return 0;
40+
}

0 commit comments

Comments
 (0)