Skip to content

Commit 5dec21d

Browse files
[SYCL] Add a new test case for functor with multiple call operators defined. (#10602)
1 parent 9718192 commit 5dec21d

File tree

1 file changed

+46
-3
lines changed

1 file changed

+46
-3
lines changed

sycl/test-e2e/Functor/kernel_functor.cpp

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,25 @@ class Functor2 {
5656
};
5757
} // namespace ns
5858

59-
// Case 2:
59+
// Case 3:
60+
// - functor class is defined in the translation unit scope.
61+
// - the functor has two call operators defined.
62+
63+
class FunctorMulti {
64+
public:
65+
FunctorMulti(int X_,
66+
sycl::accessor<int, 1, sycl_read_write, sycl_device> &Acc_)
67+
: X(X_), Acc(Acc_) {}
68+
69+
void operator()(sycl::id<1> id = 0) const { Acc[id] += X; }
70+
void operator()(sycl::id<2> id) const {}
71+
72+
private:
73+
int X;
74+
sycl::accessor<int, 1, sycl_read_write, sycl_device> Acc;
75+
};
76+
77+
// Case 4:
6078
// - functor class is templated and defined in the translation unit scope
6179
// - the '()' operator:
6280
// * has a parameter of type sycl::id<1> (to be used in 'parallel_for').
@@ -73,7 +91,7 @@ template <typename T> class TmplFunctor {
7391
sycl::accessor<T, 1, sycl_read_write, sycl_device> Acc;
7492
};
7593

76-
// Case 3:
94+
// Case 5:
7795
// - functor class is templated and defined in the translation unit scope
7896
// - the '()' operator:
7997
// * has a parameter of type sycl::id<1> (to be used in 'parallel_for').
@@ -156,14 +174,39 @@ template <typename T> T bar(T X) {
156174
return res;
157175
}
158176

177+
int multi(int X) {
178+
int A[] = {10};
179+
{
180+
sycl::queue Q;
181+
sycl::buffer<int, 1> Buf(A, 1);
182+
183+
Q.submit([&](sycl::handler &cgh) {
184+
auto Acc = Buf.get_access<sycl_read_write, sycl_device>(cgh);
185+
FunctorMulti F(X, Acc);
186+
cgh.parallel_for(sycl::range<1>(X), F);
187+
});
188+
}
189+
return A[0];
190+
}
191+
159192
int main() {
160193
const int Res1 = foo(10);
161194
const int Res2 = bar(10);
162195
const int Gold1 = 40;
163196
const int Gold2 = 80;
164-
165197
assert(Res1 == Gold1);
166198
assert(Res2 == Gold2);
167199

200+
sycl::queue deviceQueue;
201+
// This test case is currently enabled only for GPUs, and fails on CPU and
202+
// Accelerator RT.
203+
// TODO: Remove this conditional check after the RT issues in CPU and
204+
// Accelerator are fixed.
205+
if (deviceQueue.get_device().is_gpu()) {
206+
const int Res3 = multi(10);
207+
const int Gold3 = 20;
208+
assert(Res3 == Gold3);
209+
}
210+
168211
return 0;
169212
}

0 commit comments

Comments
 (0)