Skip to content

Commit 0fa7542

Browse files
[SYCL] group operations update to use size() (#7589)
`get_size()` returns byte size. it is `size()` that is wanted in these group operations.
1 parent 65b7501 commit 0fa7542

File tree

3 files changed

+13
-13
lines changed

3 files changed

+13
-13
lines changed

sycl/include/sycl/detail/spirv.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,7 @@ EnableIfNativeShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
569569
template <typename T>
570570
EnableIfVectorShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
571571
T result;
572-
for (int s = 0; s < x.get_size(); ++s) {
572+
for (int s = 0; s < x.size(); ++s) {
573573
result[s] = SubgroupShuffle(x[s], local_id);
574574
}
575575
return result;
@@ -578,7 +578,7 @@ EnableIfVectorShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
578578
template <typename T>
579579
EnableIfVectorShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
580580
T result;
581-
for (int s = 0; s < x.get_size(); ++s) {
581+
for (int s = 0; s < x.size(); ++s) {
582582
result[s] = SubgroupShuffleXor(x[s], local_id);
583583
}
584584
return result;
@@ -587,7 +587,7 @@ EnableIfVectorShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
587587
template <typename T>
588588
EnableIfVectorShuffle<T> SubgroupShuffleDown(T x, uint32_t delta) {
589589
T result;
590-
for (int s = 0; s < x.get_size(); ++s) {
590+
for (int s = 0; s < x.size(); ++s) {
591591
result[s] = SubgroupShuffleDown(x[s], delta);
592592
}
593593
return result;
@@ -596,7 +596,7 @@ EnableIfVectorShuffle<T> SubgroupShuffleDown(T x, uint32_t delta) {
596596
template <typename T>
597597
EnableIfVectorShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
598598
T result;
599-
for (int s = 0; s < x.get_size(); ++s) {
599+
for (int s = 0; s < x.size(); ++s) {
600600
result[s] = SubgroupShuffleUp(x[s], delta);
601601
}
602602
return result;

sycl/include/sycl/ext/oneapi/group_algorithm.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ detail::enable_if_t<(detail::is_generic_group<Group>::value &&
167167
typename Group::id_type local_id) {
168168
#ifdef __SYCL_DEVICE_ONLY__
169169
T result;
170-
for (int s = 0; s < x.get_size(); ++s) {
170+
for (int s = 0; s < x.size(); ++s) {
171171
result[s] = broadcast(g, x[s], local_id);
172172
}
173173
return result;
@@ -212,7 +212,7 @@ detail::enable_if_t<(detail::is_generic_group<Group>::value &&
212212
linear_local_id) {
213213
#ifdef __SYCL_DEVICE_ONLY__
214214
T result;
215-
for (int s = 0; s < x.get_size(); ++s) {
215+
for (int s = 0; s < x.size(); ++s) {
216216
result[s] = broadcast(g, x[s], linear_local_id);
217217
}
218218
return result;
@@ -250,7 +250,7 @@ detail::enable_if_t<(detail::is_generic_group<Group>::value &&
250250
T> broadcast(Group g, T x) {
251251
#ifdef __SYCL_DEVICE_ONLY__
252252
T result;
253-
for (int s = 0; s < x.get_size(); ++s) {
253+
for (int s = 0; s < x.size(); ++s) {
254254
result[s] = broadcast(g, x[s]);
255255
}
256256
return result;

sycl/include/sycl/group_algorithm.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ reduce_over_group(Group g, T x, BinaryOperation binary_op) {
229229
std::is_same<decltype(binary_op(x[0], x[0])), float>::value),
230230
"Result type of binary_op must match reduction accumulation type.");
231231
T result;
232-
for (int s = 0; s < x.get_size(); ++s) {
232+
for (int s = 0; s < x.size(); ++s) {
233233
result[s] = reduce_over_group(g, x[s], binary_op);
234234
}
235235
return result;
@@ -280,7 +280,7 @@ reduce_over_group(Group g, V x, T init, BinaryOperation binary_op) {
280280
"Result type of binary_op must match reduction accumulation type.");
281281
#ifdef __SYCL_DEVICE_ONLY__
282282
T result = init;
283-
for (int s = 0; s < x.get_size(); ++s) {
283+
for (int s = 0; s < x.size(); ++s) {
284284
result[s] = binary_op(init[s], reduce_over_group(g, x[s], binary_op));
285285
}
286286
return result;
@@ -656,7 +656,7 @@ exclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) {
656656
std::is_same<decltype(binary_op(x[0], x[0])), float>::value),
657657
"Result type of binary_op must match scan accumulation type.");
658658
T result;
659-
for (int s = 0; s < x.get_size(); ++s) {
659+
for (int s = 0; s < x.size(); ++s) {
660660
result[s] = exclusive_scan_over_group(g, x[s], binary_op);
661661
}
662662
return result;
@@ -680,7 +680,7 @@ exclusive_scan_over_group(Group g, V x, T init, BinaryOperation binary_op) {
680680
std::is_same<decltype(binary_op(init[0], x[0])), float>::value),
681681
"Result type of binary_op must match scan accumulation type.");
682682
T result;
683-
for (int s = 0; s < x.get_size(); ++s) {
683+
for (int s = 0; s < x.size(); ++s) {
684684
result[s] = exclusive_scan_over_group(g, x[s], init[s], binary_op);
685685
}
686686
return result;
@@ -823,7 +823,7 @@ inclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) {
823823
std::is_same<decltype(binary_op(x[0], x[0])), float>::value),
824824
"Result type of binary_op must match scan accumulation type.");
825825
T result;
826-
for (int s = 0; s < x.get_size(); ++s) {
826+
for (int s = 0; s < x.size(); ++s) {
827827
result[s] = inclusive_scan_over_group(g, x[s], binary_op);
828828
}
829829
return result;
@@ -917,7 +917,7 @@ inclusive_scan_over_group(Group g, V x, BinaryOperation binary_op, T init) {
917917
std::is_same<decltype(binary_op(init[0], x[0])), float>::value),
918918
"Result type of binary_op must match scan accumulation type.");
919919
T result;
920-
for (int s = 0; s < x.get_size(); ++s) {
920+
for (int s = 0; s < x.size(); ++s) {
921921
result[s] = inclusive_scan_over_group(g, x[s], binary_op, init[s]);
922922
}
923923
return result;

0 commit comments

Comments
 (0)