Skip to content

Commit a5e4759

Browse files
cuda : optimize argmax (#10441)
* cuda : optimize argmax * remove unused parameter ggml-ci * fixup : use full warps ggml-ci * Apply suggestions from code review Co-authored-by: Johannes Gäßler <[email protected]> * fix ub * ggml : check ne00 <= INT32_MAX in argmax and argsort --------- Co-authored-by: Johannes Gäßler <[email protected]>
1 parent 1bb30bf commit a5e4759

File tree

5 files changed

+104
-61
lines changed

5 files changed

+104
-61
lines changed

ggml/src/ggml-cuda/argmax.cu

Lines changed: 54 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,69 @@
1-
#include "common.cuh"
1+
#include <algorithm>
2+
#include <cstdint>
3+
24
#include "argmax.cuh"
5+
#include "common.cuh"
36
#include "sum.cuh"
47

5-
#include <cstdint>
8+
static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __restrict__ dst, const int64_t ncols) {
9+
const int64_t row = blockIdx.x;
610

7-
static __global__ void argmax_f32(
8-
const float * x, int32_t * dst, const int64_t ncols, const int64_t nrows) {
11+
float maxval = -FLT_MAX;
12+
int argmax = -1;
13+
const float * rowx = x + row * ncols;
914

10-
int argmax_thread = 0;
11-
const int64_t row0 = (int64_t)blockIdx.x*WARP_SIZE;
15+
for (int32_t col = threadIdx.x; col < ncols; col += blockDim.x) {
16+
const float val = rowx[col];
17+
if (val > maxval) {
18+
maxval = val;
19+
argmax = col;
20+
}
21+
}
1222

1323
#pragma unroll
14-
for (int64_t row1 = 0; row1 < WARP_SIZE; ++row1) {
15-
const int64_t row = row0 + row1;
16-
17-
if (row >= nrows) {
18-
break;
24+
for (int offset = 16; offset > 0; offset >>= 1) {
25+
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
26+
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
27+
if (val > maxval) {
28+
maxval = val;
29+
argmax = col;
1930
}
31+
}
2032

21-
float maxval = -FLT_MAX;
22-
int argmax = -1;
23-
24-
for (int32_t col = threadIdx.x; col < ncols; col += WARP_SIZE) {
25-
const float val = x[row*ncols + col];
26-
const int bigger = val > maxval;
27-
const int not_bigger = bigger ^ 0x00000001;
28-
29-
maxval = maxval*not_bigger + val*bigger;
30-
argmax = argmax*not_bigger + col*bigger;
33+
const int n_warps = blockDim.x / WARP_SIZE;
34+
const int lane_id = threadIdx.x % WARP_SIZE;
35+
const int warp_id = threadIdx.x / WARP_SIZE;
36+
if (n_warps > 1) {
37+
constexpr int max_warps = 1024 / WARP_SIZE;
38+
__shared__ float shared_maxval[max_warps];
39+
__shared__ int shared_argmax[max_warps];
40+
if (lane_id == 0) {
41+
shared_maxval[warp_id] = maxval;
42+
shared_argmax[warp_id] = argmax;
3143
}
3244

45+
__syncthreads();
46+
47+
if (warp_id == 0) {
48+
if (lane_id < n_warps) {
49+
maxval = shared_maxval[lane_id];
50+
argmax = shared_argmax[lane_id];
51+
}
3352
#pragma unroll
34-
for (int mask = 16; mask > 0; mask >>= 1) {
35-
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, mask, WARP_SIZE);
36-
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, mask, WARP_SIZE);
37-
const int bigger = val > maxval;
38-
const int not_bigger = bigger ^ 0x00000001;
39-
40-
maxval = maxval*not_bigger + val*bigger;
41-
argmax = argmax*not_bigger + col*bigger;
53+
for (int offset = 16; offset > 0; offset >>= 1) {
54+
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
55+
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
56+
if (val > maxval) {
57+
maxval = val;
58+
argmax = col;
59+
}
60+
}
4261
}
43-
44-
const int store = row1 == threadIdx.x;
45-
argmax_thread += store*argmax;
4662
}
4763

48-
const int row = row0 + threadIdx.x;
49-
50-
if (row >= nrows) {
51-
return;
64+
if (warp_id == 0 && lane_id == 0) {
65+
dst[row] = argmax;
5266
}
53-
54-
dst[row] = argmax_thread;
5567
}
5668

5769
void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -70,10 +82,10 @@ void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
7082

7183
cudaStream_t stream = ctx.stream();
7284

73-
const int64_t num_blocks = (nrows + WARP_SIZE - 1) / WARP_SIZE;
74-
75-
const dim3 blocks_dim(WARP_SIZE, 1, 1);
85+
const int64_t num_blocks = nrows;
86+
const int64_t num_threads = std::min<int64_t>(1024, (ne00 + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE);
87+
const dim3 blocks_dim(num_threads, 1, 1);
7688
const dim3 blocks_num(num_blocks, 1, 1);
7789

78-
argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00, nrows);
90+
argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00);
7991
}

ggml/src/ggml-cuda/common.cuh

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -180,26 +180,26 @@ static __device__ __forceinline__ int warp_reduce_sum(int x) {
180180
return __reduce_add_sync(0xffffffff, x);
181181
#else
182182
#pragma unroll
183-
for (int mask = 16; mask > 0; mask >>= 1) {
184-
x += __shfl_xor_sync(0xffffffff, x, mask, 32);
183+
for (int offset = 16; offset > 0; offset >>= 1) {
184+
x += __shfl_xor_sync(0xffffffff, x, offset, 32);
185185
}
186186
return x;
187187
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
188188
}
189189

190190
static __device__ __forceinline__ float warp_reduce_sum(float x) {
191191
#pragma unroll
192-
for (int mask = 16; mask > 0; mask >>= 1) {
193-
x += __shfl_xor_sync(0xffffffff, x, mask, 32);
192+
for (int offset = 16; offset > 0; offset >>= 1) {
193+
x += __shfl_xor_sync(0xffffffff, x, offset, 32);
194194
}
195195
return x;
196196
}
197197

198198
static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
199199
#pragma unroll
200-
for (int mask = 16; mask > 0; mask >>= 1) {
201-
a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
202-
a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
200+
for (int offset = 16; offset > 0; offset >>= 1) {
201+
a.x += __shfl_xor_sync(0xffffffff, a.x, offset, 32);
202+
a.y += __shfl_xor_sync(0xffffffff, a.y, offset, 32);
203203
}
204204
return a;
205205
}
@@ -209,16 +209,16 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
209209

210210
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
211211
#pragma unroll
212-
for (int mask = 16; mask > 0; mask >>= 1) {
213-
const half2 a_other = __shfl_xor_sync(0xffffffff, a, mask, 32);
212+
for (int offset = 16; offset > 0; offset >>= 1) {
213+
const half2 a_other = __shfl_xor_sync(0xffffffff, a, offset, 32);
214214
reinterpret_cast<half&>(a.x) += __low2half(a_other);
215215
reinterpret_cast<half&>(a.y) += __high2half(a_other);
216216
}
217217
return a;
218218
#else
219219
#pragma unroll
220-
for (int mask = 16; mask > 0; mask >>= 1) {
221-
a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
220+
for (int offset = 16; offset > 0; offset >>= 1) {
221+
a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, 32));
222222
}
223223
return a;
224224
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
@@ -231,8 +231,8 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
231231

232232
static __device__ __forceinline__ float warp_reduce_max(float x) {
233233
#pragma unroll
234-
for (int mask = 16; mask > 0; mask >>= 1) {
235-
x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
234+
for (int offset = 16; offset > 0; offset >>= 1) {
235+
x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
236236
}
237237
return x;
238238
}
@@ -275,8 +275,8 @@ static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const hal
275275
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
276276
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
277277
#pragma unroll
278-
for (int mask = 16; mask > 0; mask >>= 1) {
279-
x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
278+
for (int offset = 16; offset > 0; offset >>= 1) {
279+
x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
280280
}
281281
return x;
282282
#else

ggml/src/ggml-cuda/quantize.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ static __global__ void quantize_mmq_q8_1(
6969

7070
// Exchange max. abs. value between vals_per_scale/4 threads.
7171
#pragma unroll
72-
for (int mask = vals_per_scale/8; mask > 0; mask >>= 1) {
73-
amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
72+
for (int offset = vals_per_scale/8; offset > 0; offset >>= 1) {
73+
amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, offset, WARP_SIZE));
7474
}
7575

7676
float sum;
@@ -79,8 +79,8 @@ static __global__ void quantize_mmq_q8_1(
7979

8080
// Exchange calculate sum across vals_per_sum/4 threads.
8181
#pragma unroll
82-
for (int mask = vals_per_sum/8; mask > 0; mask >>= 1) {
83-
sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, WARP_SIZE);
82+
for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) {
83+
sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE);
8484
}
8585
}
8686

ggml/src/ggml.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2255,6 +2255,7 @@ struct ggml_tensor * ggml_argmax(
22552255
struct ggml_context * ctx,
22562256
struct ggml_tensor * a) {
22572257
GGML_ASSERT(ggml_is_matrix(a));
2258+
GGML_ASSERT(a->ne[0] <= INT32_MAX);
22582259

22592260
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, a->ne[1]);
22602261

@@ -4138,6 +4139,7 @@ struct ggml_tensor * ggml_argsort(
41384139
struct ggml_context * ctx,
41394140
struct ggml_tensor * a,
41404141
enum ggml_sort_order order) {
4142+
GGML_ASSERT(a->ne[0] <= INT32_MAX);
41414143
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);
41424144

41434145
ggml_set_op_params_i32(result, 0, (int32_t) order);

tests/test-backend-ops.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1154,6 +1154,26 @@ struct test_argmax : public test_case {
11541154
return out;
11551155
}
11561156

1157+
void initialize_tensors(ggml_context * ctx) override {
1158+
std::random_device rd;
1159+
std::default_random_engine rng(rd());
1160+
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
1161+
if (t->type == GGML_TYPE_F32) {
1162+
// initialize with unique values to avoid ties
1163+
for (int64_t r = 0; r < ggml_nrows(t); r++) {
1164+
std::vector<float> data(t->ne[0]);
1165+
for (int i = 0; i < t->ne[0]; i++) {
1166+
data[i] = i;
1167+
}
1168+
std::shuffle(data.begin(), data.end(), rng);
1169+
ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
1170+
}
1171+
} else {
1172+
init_tensor_uniform(t);
1173+
}
1174+
}
1175+
}
1176+
11571177
double max_nmse_err() override {
11581178
return 0.0;
11591179
}
@@ -3440,6 +3460,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
34403460
test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
34413461

34423462
test_cases.emplace_back(new test_argmax());
3463+
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 1, 1, 1}));
3464+
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100, 10, 1, 1}));
3465+
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
3466+
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {2000, 10, 1, 1}));
3467+
34433468
test_cases.emplace_back(new test_count_equal());
34443469

34453470
for (int ne3 : {1, 3}) { // CUDA backward pass only supports ne3 == 1
@@ -3830,6 +3855,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
38303855
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, 1.0f, 0.0f));
38313856
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, 1.0f, 0.0f));
38323857

3858+
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 10, 1, 1}));
3859+
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
3860+
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32000, 512, 1, 1}));
3861+
38333862
for (int bs : {1, 512}) {
38343863
for (ggml_type type_a : all_types) {
38353864
for (ggml_type type_b : {GGML_TYPE_F32}) {

0 commit comments

Comments
 (0)