23
23
24
24
class MyInt32Const ;
25
25
class MyFloatConst ;
26
+ class MyConst ;
26
27
27
28
using namespace sycl ;
28
29
29
30
class KernelAAAi ;
30
31
class KernelBBBf ;
31
32
32
- int val = 10 ;
33
+ int global_val = 10 ;
33
34
34
35
// Fetch a value at runtime.
35
- int get_value () { return val ; }
36
+ int get_value () { return global_val ; }
36
37
37
38
float foo (
38
39
const cl::sycl::ONEAPI::experimental::spec_constant<float , MyFloatConst>
@@ -49,8 +50,22 @@ struct SCWrapper {
49
50
cl::sycl::ONEAPI::experimental::spec_constant<int , class sc_name2 > SC2;
50
51
};
51
52
53
+ // MyKernel is used to test default constructor
54
+ using AccT = sycl::accessor<int , 1 , sycl::access::mode::write>;
55
+ using ScT = sycl::ONEAPI::experimental::spec_constant<int , MyConst>;
56
+
57
+ struct MyKernel {
58
+ MyKernel (AccT &Acc) : Acc(Acc) {}
59
+
60
+ void setConst (ScT Sc) { this ->Sc = Sc; }
61
+
62
+ void operator ()() const { Acc[0 ] = Sc.get (); }
63
+ AccT Acc;
64
+ ScT Sc;
65
+ };
66
+
52
67
int main (int argc, char **argv) {
53
- val = argc + 16 ;
68
+ global_val = argc + 16 ;
54
69
55
70
cl::sycl::queue q (default_selector{}, [](exception_list l) {
56
71
for (auto ep : l) {
@@ -68,10 +83,11 @@ int main(int argc, char **argv) {
68
83
69
84
std::cout << " Running on " << q.get_device ().get_info <info::device::name>()
70
85
<< " \n " ;
71
- std::cout << " val = " << val << " \n " ;
86
+ std::cout << " global_val = " << global_val << " \n " ;
72
87
cl::sycl::program program1 (q.get_context ());
73
88
cl::sycl::program program2 (q.get_context ());
74
89
cl::sycl::program program3 (q.get_context ());
90
+ cl::sycl::program program4 (q.get_context ());
75
91
76
92
int goldi = (int )get_value ();
77
93
// TODO make this floating point once supported by the compiler
@@ -83,22 +99,30 @@ int main(int argc, char **argv) {
83
99
cl::sycl::ONEAPI::experimental::spec_constant<float , MyFloatConst> f32 =
84
100
program2.set_spec_constant <MyFloatConst>(goldf);
85
101
102
+ cl::sycl::ONEAPI::experimental::spec_constant<int , MyConst> sc =
103
+ program4.set_spec_constant <MyConst>(goldi);
104
+
86
105
program1.build_with_kernel_type <KernelAAAi>();
87
106
// Use an option (does not matter which exactly) to test different internal
88
107
// SYCL RT execution path
89
108
program2.build_with_kernel_type <KernelBBBf>(" -cl-fast-relaxed-math" );
90
109
91
110
SCWrapper W (program3);
92
111
program3.build_with_kernel_type <class KernelWrappedSC >();
112
+
113
+ program4.build_with_kernel_type <MyKernel>();
114
+
93
115
int goldw = 6 ;
94
116
95
117
std::vector<int > veci (1 );
96
118
std::vector<float > vecf (1 );
97
119
std::vector<int > vecw (1 );
120
+ std::vector<int > vec (1 );
98
121
try {
99
122
cl::sycl::buffer<int , 1 > bufi (veci.data (), veci.size ());
100
123
cl::sycl::buffer<float , 1 > buff (vecf.data (), vecf.size ());
101
124
cl::sycl::buffer<int , 1 > bufw (vecw.data (), vecw.size ());
125
+ cl::sycl::buffer<int , 1 > buf (vec.data (), vec.size ());
102
126
103
127
q.submit ([&](cl::sycl::handler &cgh) {
104
128
auto acci = bufi.get_access <cl::sycl::access::mode::write>(cgh);
@@ -123,6 +147,19 @@ int main(int argc, char **argv) {
123
147
program3.get_kernel <KernelWrappedSC>(),
124
148
[=]() { accw[0 ] = W.SC1 .get () + W.SC2 .get (); });
125
149
});
150
+ // Check spec_constant default construction with subsequent initialization
151
+ q.submit ([&](cl::sycl::handler &cgh) {
152
+ auto acc = buf.get_access <cl::sycl::access::mode::write>(cgh);
153
+ // Specialization constants specification says:
154
+ // cl::sycl::experimental::spec_constant is default constructible,
155
+ // although the object is not considered initialized until the result of
156
+ // the call to cl::sycl::program::set_spec_constant is assigned to it.
157
+ MyKernel Kernel (acc); // default construct inside MyKernel instance
158
+ Kernel.setConst (sc); // initialize to sc, returned by set_spec_constant
159
+
160
+ cgh.single_task <MyKernel>(program4.get_kernel <MyKernel>(), Kernel);
161
+ });
162
+
126
163
} catch (cl::sycl::exception &e) {
127
164
std::cout << " *** Exception caught: " << e.what () << " \n " ;
128
165
return 1 ;
@@ -146,6 +183,12 @@ int main(int argc, char **argv) {
146
183
std::cout << " *** ERROR: " << valw << " != " << goldw << " (gold)\n " ;
147
184
passed = false ;
148
185
}
186
+ int val = vec[0 ];
187
+
188
+ if (val != goldi) {
189
+ std::cout << " *** ERROR: " << val << " != " << goldi << " (gold)\n " ;
190
+ passed = false ;
191
+ }
149
192
std::cout << (passed ? " passed\n " : " FAILED\n " );
150
193
return passed ? 0 : 1 ;
151
194
}
0 commit comments