Skip to content

Commit fa47324

Browse files
committed
[SYCL] Add support for POD spec constants in RT
1 parent 5c4df90 commit fa47324

File tree

10 files changed

+199
-28
lines changed

10 files changed

+199
-28
lines changed

sycl/include/CL/sycl/ONEAPI/experimental/spec_constant.hpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#pragma once
1919

20+
#include <CL/sycl/detail/stl_type_traits.hpp>
2021
#include <CL/sycl/detail/sycl_fe_intrins.hpp>
2122
#include <CL/sycl/exception.hpp>
2223

@@ -45,7 +46,9 @@ template <typename T, typename ID = T> class spec_constant {
4546
friend class cl::sycl::program;
4647

4748
public:
48-
T get() const { // explicit access.
49+
template <typename V = T>
50+
typename sycl::detail::enable_if_t<std::is_arithmetic<V>::value, V>
51+
get() const { // explicit access.
4952
#ifdef __SYCL_DEVICE_ONLY__
5053
const char *TName = __builtin_unique_stable_name(ID);
5154
return __sycl_getSpecConstantValue<T>(TName);
@@ -54,6 +57,19 @@ template <typename T, typename ID = T> class spec_constant {
5457
#endif // __SYCL_DEVICE_ONLY__
5558
}
5659

60+
template <typename V = T>
61+
typename sycl::detail::enable_if_t<std::is_class<V>::value &&
62+
std::is_pod<V>::value,
63+
V>
64+
get() const { // explicit access.
65+
#ifdef __SYCL_DEVICE_ONLY__
66+
const char *TName = __builtin_unique_stable_name(ID);
67+
return __sycl_getCompositeSpecConstantValue<T>(TName);
68+
#else
69+
return Val;
70+
#endif // __SYCL_DEVICE_ONLY__
71+
}
72+
5773
operator T() const { // implicit conversion.
5874
return get();
5975
}

sycl/include/CL/sycl/detail/pi.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -640,7 +640,12 @@ static const uint8_t PI_DEVICE_BINARY_OFFLOAD_KIND_SYCL = 4;
640640
/// Name must be consistent with
641641
/// PropertySetRegistry::SYCL_SPECIALIZATION_CONSTANTS defined in
642642
/// PropertySetIO.h
643-
#define __SYCL_PI_PROPERTY_SET_SPEC_CONST_MAP "SYCL/specialization constants"
643+
#define __SYCL_PI_PROPERTY_SET_SCALAR_SPEC_CONST_MAP \
644+
"SYCL/specialization constants"
645+
/// PropertySetRegistry::SYCL_COMPOSITE_SPECIALIZATION_CONSTANTS defined in
646+
/// PropertySetIO.h
647+
#define __SYCL_PI_PROPERTY_SET_COMPOSITE_SPEC_CONST_MAP \
648+
"SYCL/composite specialization constants"
644649
/// PropertySetRegistry::SYCL_DEVICELIB_REQ_MASK defined in PropertySetIO.h
645650
#define __SYCL_PI_PROPERTY_SET_DEVICELIB_REQ_MASK "SYCL/devicelib req mask"
646651
/// PropertySetRegistry::SYCL_KERNEL_PARAM_OPT_INFO defined in PropertySetIO.h

sycl/include/CL/sycl/detail/pi.hpp

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -359,11 +359,23 @@ class DeviceBinaryImage {
359359
return Format;
360360
}
361361

362-
/// Gets the iterator range over specialization constants in this this binary
363-
/// image. For each property pointed to by an iterator within the range, the
364-
/// name of the property is the specializaion constant symbolic ID and the
365-
/// value is 32-bit unsigned integer ID.
366-
const PropertyRange &getSpecConstants() const { return SpecConstIDMap; }
362+
/// Gets the iterator range over scalar specialization constants in this this
363+
/// binary image. For each property pointed to by an iterator within the
364+
/// range, the name of the property is the specializaion constant symbolic ID
365+
/// and the value is 32-bit unsigned integer ID.
366+
const PropertyRange &getScalarSpecConstants() const {
367+
return ScalarSpecConstIDMap;
368+
}
369+
/// Gets the iterator range over composite specialization constants in this
370+
/// this binary image. For each property pointed to by an iterator within the
371+
/// range, the name of the property is the specializaion constant symbolic ID
372+
/// and the value is a list of tuples of 32-bit unsigned integer values, which
373+
/// encode scalar specialization constants, that form a composite one.
374+
/// Each tuple consist of ID of scalar specialization constant, its location
375+
/// within a composite (offset in bytes from the beginning) and its size.
376+
const PropertyRange &getCompositeSpecConstants() const {
377+
return CompositeSpecConstIDMap;
378+
}
367379
const PropertyRange &getDeviceLibReqMask() const { return DeviceLibReqMask; }
368380
const PropertyRange &getKernelParamOptInfo() const {
369381
return KernelParamOptInfo;
@@ -376,7 +388,8 @@ class DeviceBinaryImage {
376388

377389
pi_device_binary Bin;
378390
pi::PiDeviceBinaryType Format = PI_DEVICE_BINARY_TYPE_NONE;
379-
DeviceBinaryImage::PropertyRange SpecConstIDMap;
391+
DeviceBinaryImage::PropertyRange ScalarSpecConstIDMap;
392+
DeviceBinaryImage::PropertyRange CompositeSpecConstIDMap;
380393
DeviceBinaryImage::PropertyRange DeviceLibReqMask;
381394
DeviceBinaryImage::PropertyRange KernelParamOptInfo;
382395
};

sycl/include/CL/sycl/detail/sycl_fe_intrins.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,7 @@
1818
template <typename T>
1919
SYCL_EXTERNAL T __sycl_getSpecConstantValue(const char *ID);
2020

21+
template <typename T>
22+
SYCL_EXTERNAL T __sycl_getCompositeSpecConstantValue(const char *ID);
23+
2124
#endif

sycl/include/CL/sycl/program.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,8 +343,8 @@ class __SYCL_EXPORT program {
343343
template <typename ID, typename T>
344344
ONEAPI::experimental::spec_constant<T, ID> set_spec_constant(T Cst) {
345345
constexpr const char *Name = detail::SpecConstantInfo<ID>::getName();
346-
static_assert(std::is_integral<T>::value ||
347-
std::is_floating_point<T>::value,
346+
static_assert(std::is_arithmetic<T>::value ||
347+
(std::is_class<T>::value && std::is_pod<T>::value),
348348
"unsupported specialization constant type");
349349
#ifdef __SYCL_DEVICE_ONLY__
350350
(void)Cst;

sycl/source/detail/pi.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,9 @@ void DeviceBinaryImage::init(pi_device_binary Bin) {
591591
// try to determine the format; may remain "NONE"
592592
Format = getBinaryImageFormat(Bin->BinaryStart, getSize());
593593

594-
SpecConstIDMap.init(Bin, __SYCL_PI_PROPERTY_SET_SPEC_CONST_MAP);
594+
ScalarSpecConstIDMap.init(Bin, __SYCL_PI_PROPERTY_SET_SCALAR_SPEC_CONST_MAP);
595+
CompositeSpecConstIDMap.init(Bin,
596+
__SYCL_PI_PROPERTY_SET_COMPOSITE_SPEC_CONST_MAP);
595597
DeviceLibReqMask.init(Bin, __SYCL_PI_PROPERTY_SET_DEVICELIB_REQ_MASK);
596598
KernelParamOptInfo.init(Bin, __SYCL_PI_PROPERTY_SET_KERNEL_PARAM_OPT_INFO);
597599
}

sycl/source/detail/program_impl.cpp

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -523,26 +523,55 @@ void program_impl::flush_spec_constants(const RTDeviceBinaryImage &Img,
523523
RT::PiProgram NativePrg) const {
524524
// iterate via all specialization constants the program's image depends on,
525525
// and set each to current runtime value (if any)
526-
const pi::DeviceBinaryImage::PropertyRange &SCRange = Img.getSpecConstants();
526+
const pi::DeviceBinaryImage::PropertyRange &ScalarSCRange =
527+
Img.getScalarSpecConstants();
528+
const pi::DeviceBinaryImage::PropertyRange &CompositeSCRange =
529+
Img.getCompositeSpecConstants();
527530
ContextImplPtr Ctx = getSyclObjImpl(get_context());
528531
using SCItTy = pi::DeviceBinaryImage::PropertyRange::ConstIterator;
529532

530533
auto LockGuard = Ctx->getKernelProgramCache().acquireCachedPrograms();
534+
NativePrg = NativePrg ? NativePrg : getHandleRef();
531535

532-
for (SCItTy SCIt : SCRange) {
533-
const char *SCName = (*SCIt)->Name;
534-
auto SCEntry = SpecConstRegistry.find(SCName);
536+
for (SCItTy SCIt : ScalarSCRange) {
537+
auto SCEntry = SpecConstRegistry.find((*SCIt)->Name);
535538
if (SCEntry == SpecConstRegistry.end())
536539
// spec constant has not been set in user code - SPIR-V will use default
537540
continue;
538541
const spec_constant_impl &SC = SCEntry->second;
539542
assert(SC.isSet() && "uninitialized spec constant");
540-
pi_device_binary_property SCProp = *SCIt;
541-
pi_uint32 ID = pi::DeviceBinaryProperty(SCProp).asUint32();
542-
NativePrg = NativePrg ? NativePrg : getHandleRef();
543+
pi_uint32 ID = pi::DeviceBinaryProperty(*SCIt).asUint32();
543544
Ctx->getPlugin().call<PiApiKind::piextProgramSetSpecializationConstant>(
544545
NativePrg, ID, SC.getSize(), SC.getValuePtr());
545546
}
547+
548+
for (SCItTy SCIt : CompositeSCRange) {
549+
auto SCEntry = SpecConstRegistry.find((*SCIt)->Name);
550+
if (SCEntry == SpecConstRegistry.end())
551+
// spec constant has not been set in user code - SPIR-V will use default
552+
continue;
553+
const spec_constant_impl &SC = SCEntry->second;
554+
assert(SC.isSet() && "uninitialized spec constant");
555+
pi::ByteArray Descriptors = pi::DeviceBinaryProperty(*SCIt).asByteArray();
556+
// First 8 bytes are consumed by size of the property
557+
assert(Descriptors.size() > 8 && "Unexpected property size");
558+
// Expected layout is vector of 3-component tuples (flattened into a vector
559+
// of scalars), where each tuple consists of: ID of a scalar spec constant,
560+
// which is a member of the composite; offset, which is used to calculate
561+
// location of scalar member within the composite; size of a scalar member
562+
// of the composite.
563+
assert(((Descriptors.size() - 8) / sizeof(std::uint32_t)) % 3 == 0 &&
564+
"unexpected layout of composite spec const descriptors");
565+
auto *It = reinterpret_cast<const std::uint32_t *>(&Descriptors[8]);
566+
auto *End = reinterpret_cast<const std::uint32_t *>(&Descriptors[0] +
567+
Descriptors.size());
568+
while (It != End) {
569+
Ctx->getPlugin().call<PiApiKind::piextProgramSetSpecializationConstant>(
570+
NativePrg, /* ID */ It[0], /* Size */ It[2],
571+
SC.getValuePtr() + /* Offset */ It[1]);
572+
It += 3;
573+
}
574+
}
546575
}
547576

548577
pi_native_handle program_impl::getNative() const {

sycl/source/detail/spec_constant_impl.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@ namespace sycl {
2121
namespace detail {
2222

2323
void spec_constant_impl::set(size_t Size, const void *Val) {
24-
if ((Size > sizeof(Bytes)) || (Size == 0))
24+
if (0 == Size)
2525
throw sycl::runtime_error("invalid spec constant size", PI_INVALID_VALUE);
26-
this->Size = Size;
27-
std::memcpy(Bytes, Val, Size);
26+
auto *BytePtr = reinterpret_cast<const char *>(Val);
27+
this->Bytes.assign(BytePtr, BytePtr + Size);
2828
}
2929

3030
void stableSerializeSpecConstRegistry(const SpecConstRegistryT &Reg,

sycl/source/detail/spec_constant_impl.hpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <CL/sycl/stl.hpp>
1414

1515
#include <iostream>
16+
#include <vector>
1617
#include <map>
1718

1819
__SYCL_INLINE_NAMESPACE(cl) {
@@ -22,20 +23,18 @@ namespace detail {
2223
// Represents a specialization constant value in SYCL runtime.
2324
class spec_constant_impl {
2425
public:
25-
spec_constant_impl() : Size(0), Bytes{0} {};
26+
spec_constant_impl() = default;
2627

2728
spec_constant_impl(size_t Size, const void *Val) { set(Size, Val); }
2829

2930
void set(size_t Size, const void *Val);
3031

31-
size_t getSize() const { return Size; }
32-
const unsigned char *getValuePtr() const { return Bytes; }
33-
bool isSet() const { return Size != 0; }
32+
size_t getSize() const { return Bytes.size(); }
33+
const char *getValuePtr() const { return Bytes.data(); }
34+
bool isSet() const { return !Bytes.empty(); }
3435

3536
private:
36-
size_t Size; // the size of the spec constant value
37-
// TODO invent more flexible approach to support values of arbitrary type:
38-
unsigned char Bytes[8]; // memory to hold the value bytes
37+
std::vector<char> Bytes;
3938
};
4039

4140
std::ostream &operator<<(std::ostream &Out, const spec_constant_impl &V);

sycl/test/composite-spec-const.cpp

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
// UNSUPPORTED: cuda
2+
//
3+
// TODO: this test is disabled because we need two more patches (to
4+
// sycl-post-link and llvm-spirv) to appear in the repo in order to get this
5+
// feature working.
6+
// RUN: true
7+
// RUNx: %clangxx -fsycl %s -o %t.out
8+
// RUNx: env SYCL_DEVICE_TYPE=HOST %t.out
9+
// RUNx: %CPU_RUN_PLACEHOLDER %t.out
10+
// RUNx: %GPU_RUN_PLACEHOLDER %t.out
11+
//
12+
// The test checks that the specialization constant feature works correctly with
13+
// composite types: toolchain processes them correctly and runtime can correctly
14+
// execute the program.
15+
16+
#include <CL/sycl.hpp>
17+
18+
#include <iostream>
19+
#include <vector>
20+
21+
using namespace sycl;
22+
class Test;
23+
24+
struct A {
25+
int a;
26+
float b;
27+
};
28+
29+
struct POD {
30+
A a[2];
31+
int b;
32+
};
33+
34+
using MyPODConst = POD;
35+
36+
int global_val = 10;
37+
38+
// Fetch a value at runtime.
39+
int get_value() { return global_val; }
40+
41+
int main(int argc, char **argv) {
42+
cl::sycl::queue q(default_selector{}, [](exception_list l) {
43+
for (auto ep : l) {
44+
try {
45+
std::rethrow_exception(ep);
46+
} catch (cl::sycl::exception &e0) {
47+
std::cout << e0.what();
48+
} catch (std::exception &e1) {
49+
std::cout << e1.what();
50+
} catch (...) {
51+
std::cout << "*** catch (...)\n";
52+
}
53+
}
54+
});
55+
56+
std::cout << "Running on " << q.get_device().get_info<info::device::name>()
57+
<< "\n";
58+
std::cout << "global_val = " << global_val << "\n";
59+
cl::sycl::program program(q.get_context());
60+
61+
int goldi = (int)get_value();
62+
float goldf = (float)get_value();
63+
64+
POD gold = {{{goldi, goldf}, {goldi, goldf}}, goldi};
65+
66+
cl::sycl::ONEAPI::experimental::spec_constant<POD, MyPODConst> pod =
67+
program.set_spec_constant<MyPODConst>(gold);
68+
69+
program.build_with_kernel_type<Test>();
70+
71+
POD result;
72+
try {
73+
cl::sycl::buffer<POD, 1> bufi(&result, 1);
74+
75+
q.submit([&](cl::sycl::handler &cgh) {
76+
auto acci = bufi.get_access<cl::sycl::access::mode::write>(cgh);
77+
cgh.single_task<Test>(
78+
program.get_kernel<Test>(),
79+
[=]() {
80+
acci[0] = pod.get();
81+
});
82+
});
83+
} catch (cl::sycl::exception &e) {
84+
std::cout << "*** Exception caught: " << e.what() << "\n";
85+
return 1;
86+
}
87+
88+
bool passed = false;
89+
90+
std::cout << result.a[0].a << " " << result.a[0].b << "\n";
91+
std::cout << result.a[1].a << " " << result.a[1].b << "\n";
92+
std::cout << result.b << "\n\n";
93+
94+
std::cout << gold.a[0].a << " " << gold.a[0].b << "\n";
95+
std::cout << gold.a[1].a << " " << gold.a[1].b << "\n";
96+
std::cout << gold.b << "\n\n";
97+
98+
if (0 == std::memcmp(&result, &gold, sizeof(POD))) {
99+
passed = true;
100+
}
101+
102+
std::cout << (passed ? "passed\n" : "FAILED\n");
103+
return passed ? 0 : 1;
104+
}

0 commit comments

Comments
 (0)