@@ -150,35 +150,33 @@ __gpu_shuffle_idx_f64(uint64_t __lane_mask, uint32_t __idx, double __x,
                                     __builtin_bit_cast(uint64_t, __x), __width));
 }
 
-// Gets the sum of all lanes inside the warp or wavefront.
-#define __DO_LANE_SUM(__type, __suffix)                                        \
-  _DEFAULT_FN_ATTRS static __inline__ __type __gpu_lane_sum_##__suffix(        \
-      uint64_t __lane_mask, __type __x) {                                      \
-    for (uint32_t __step = __gpu_num_lanes() / 2; __step > 0; __step /= 2) {   \
-      uint32_t __index = __step + __gpu_lane_id();                             \
-      __x += __gpu_shuffle_idx_##__suffix(__lane_mask, __index, __x,           \
-                                          __gpu_num_lanes());                  \
-    }                                                                          \
-    return __gpu_read_first_lane_##__suffix(__lane_mask, __x);                 \
-  }
-__DO_LANE_SUM(uint32_t, u32); // uint32_t __gpu_lane_sum_u32(m, x)
-__DO_LANE_SUM(uint64_t, u64); // uint64_t __gpu_lane_sum_u64(m, x)
-__DO_LANE_SUM(float, f32);    // float __gpu_lane_sum_f32(m, x)
-__DO_LANE_SUM(double, f64);   // double __gpu_lane_sum_f64(m, x)
-#undef __DO_LANE_SUM
-
 // Gets the accumulator scan of the threads in the warp or wavefront.
 #define __DO_LANE_SCAN(__type, __bitmask_type, __suffix)                       \
   _DEFAULT_FN_ATTRS static __inline__ __type __gpu_lane_scan_##__suffix(       \
       uint64_t __lane_mask, __type __x) {                                      \
-    for (uint32_t __step = 1; __step < __gpu_num_lanes(); __step *= 2) {       \
-      uint32_t __index = __gpu_lane_id() - __step;                             \
-      __bitmask_type bitmask = __gpu_lane_id() >= __step;                      \
-      __x += __builtin_bit_cast(                                               \
-          __type, -bitmask & __builtin_bit_cast(__bitmask_type,                \
-                                                __gpu_shuffle_idx_##__suffix(  \
-                                                    __lane_mask, __index, __x, \
-                                                    __gpu_num_lanes())));      \
+    uint64_t __first = __lane_mask >> __builtin_ctzll(__lane_mask);            \
+    bool __divergent = __gpu_read_first_lane_##__suffix(                       \
+        __lane_mask, __first & (__first + 1));                                 \
+    if (__divergent) {                                                         \
+      __type __accum = 0;                                                      \
+      for (uint64_t __mask = __lane_mask; __mask; __mask &= __mask - 1) {      \
+        __type __index = __builtin_ctzll(__mask);                              \
+        __type __tmp = __gpu_shuffle_idx_##__suffix(__lane_mask, __index, __x, \
+                                                    __gpu_num_lanes());        \
+        __x = __gpu_lane_id() == __index ? __accum + __tmp : __x;              \
+        __accum += __tmp;                                                      \
+      }                                                                        \
+    } else {                                                                   \
+      for (uint32_t __step = 1; __step < __gpu_num_lanes(); __step *= 2) {     \
+        uint32_t __index = __gpu_lane_id() - __step;                           \
+        __bitmask_type bitmask = __gpu_lane_id() >= __step;                    \
+        __x += __builtin_bit_cast(                                             \
+            __type,                                                            \
+            -bitmask & __builtin_bit_cast(__bitmask_type,                      \
+                                          __gpu_shuffle_idx_##__suffix(        \
+                                              __lane_mask, __index, __x,       \
+                                              __gpu_num_lanes())));            \
+      }                                                                        \
     }                                                                          \
     return __x;                                                                \
   }
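
Both macros now branch on whether the lane mask is one contiguous run of set bits. A minimal host-side sketch of that check and of the serial fallback loop, using the hypothetical names `mask_is_divergent` and `scan_divergent`; the real header runs this per lane and replaces the array accesses with `__gpu_shuffle_idx_*` broadcasts:

```c
#include <assert.h>
#include <stdint.h>

// A mask is contiguous when its set bits form one unbroken run. After
// shifting out the trailing zeros, `m & (m + 1)` is zero only for a solid
// run of ones, because the +1 carries past every set bit.
static int mask_is_divergent(uint64_t lane_mask) {
  uint64_t first = lane_mask >> __builtin_ctzll(lane_mask);
  return (first & (first + 1)) != 0;
}

// Serial inclusive scan over the set bits, mirroring the
// `__mask &= __mask - 1` loop taken on the divergent path.
static void scan_divergent(uint64_t lane_mask, uint32_t x[64]) {
  uint32_t accum = 0;
  for (uint64_t mask = lane_mask; mask; mask &= mask - 1) {
    uint32_t index = __builtin_ctzll(mask); // lowest remaining active lane
    uint32_t tmp = x[index];                // models the shuffle broadcast
    x[index] = accum + tmp;                 // lane takes accum + its value
    accum += tmp;
  }
}

int main(void) {
  assert(!mask_is_divergent(0xf0)); // one contiguous run: lanes 4-7
  assert(mask_is_divergent(0x91)); // holes between lanes 0, 4, 7

  uint32_t x[64] = {[0] = 1, [4] = 2, [7] = 3};
  scan_divergent(0x91, x);
  assert(x[0] == 1 && x[4] == 3 && x[7] == 6); // inclusive prefix sums
  return 0;
}
```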
@@ -188,6 +186,32 @@ __DO_LANE_SCAN(float, uint32_t, f32); // float __gpu_lane_scan_f32(m, x)
 __DO_LANE_SCAN(double, uint64_t, f64); // double __gpu_lane_scan_f64(m, x)
 #undef __DO_LANE_SCAN
 
+// Gets the sum of all lanes inside the warp or wavefront.
+#define __DO_LANE_SUM(__type, __suffix)                                        \
+  _DEFAULT_FN_ATTRS static __inline__ __type __gpu_lane_sum_##__suffix(        \
+      uint64_t __lane_mask, __type __x) {                                      \
+    uint64_t __first = __lane_mask >> __builtin_ctzll(__lane_mask);            \
+    bool __divergent = __gpu_read_first_lane_##__suffix(                       \
+        __lane_mask, __first & (__first + 1));                                 \
+    if (__divergent) {                                                         \
+      return __gpu_shuffle_idx_##__suffix(                                     \
+          __lane_mask, 63 - __builtin_clzll(__lane_mask),                      \
+          __gpu_lane_scan_##__suffix(__lane_mask, __x), __gpu_num_lanes());    \
+    } else {                                                                   \
+      for (uint32_t __step = 1; __step < __gpu_num_lanes(); __step *= 2) {     \
+        uint32_t __index = __step + __gpu_lane_id();                           \
+        __x += __gpu_shuffle_idx_##__suffix(__lane_mask, __index, __x,         \
+                                            __gpu_num_lanes());                \
+      }                                                                        \
+      return __gpu_read_first_lane_##__suffix(__lane_mask, __x);               \
+    }                                                                          \
+  }
+__DO_LANE_SUM(uint32_t, u32); // uint32_t __gpu_lane_sum_u32(m, x)
+__DO_LANE_SUM(uint64_t, u64); // uint64_t __gpu_lane_sum_u64(m, x)
+__DO_LANE_SUM(float, f32);    // float __gpu_lane_sum_f32(m, x)
+__DO_LANE_SUM(double, f64);   // double __gpu_lane_sum_f64(m, x)
+#undef __DO_LANE_SUM
+
 _Pragma("omp end declare variant");
 _Pragma("omp end declare target");
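
The relocated `__DO_LANE_SUM` follows the same split: a contiguous mask takes the log2(N)-step shuffle reduction, while a divergent mask reuses `__gpu_lane_scan_*` and broadcasts the scan value held by the highest active lane, `63 - __builtin_clzll(__lane_mask)`. A host-side sketch of the converged reduction, using the hypothetical names `NUM_LANES` and `sum_converged` and assuming out-of-range shuffle reads contribute nothing:

```c
#include <assert.h>
#include <stdint.h>

#define NUM_LANES 8 // stand-in for __gpu_num_lanes()

// Each step adds the value held `step` lanes above, so after log2(NUM_LANES)
// steps lane 0 holds the total, which __gpu_read_first_lane_* broadcasts.
static uint32_t sum_converged(uint32_t x[NUM_LANES]) {
  for (uint32_t step = 1; step < NUM_LANES; step *= 2)
    for (uint32_t lane = 0; lane < NUM_LANES; ++lane) {
      uint32_t index = step + lane; // source lane, as in the else branch
      if (index < NUM_LANES)        // assumed: OOB shuffles add nothing
        x[lane] += x[index];        // models __gpu_shuffle_idx_u32
    }
  return x[0]; // models __gpu_read_first_lane_u32
}

int main(void) {
  uint32_t x[NUM_LANES] = {1, 2, 3, 4, 5, 6, 7, 8};
  assert(sum_converged(x) == 36);
  // Divergent path: the last lane of the inclusive scan holds the total.
  assert(63 - __builtin_clzll(0x91) == 7); // highest active lane of 0x91
  return 0;
}
```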