-
Notifications
You must be signed in to change notification settings - Fork 12.2k
cuda : optimize argmax #10441
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
cuda : optimize argmax #10441
Changes from all commits
35386e8
0a737d2
1e9447a
a734da7
316f3d3
48f94d4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,57 +1,69 @@ | ||
#include "common.cuh" | ||
#include <algorithm> | ||
#include <cstdint> | ||
|
||
#include "argmax.cuh" | ||
#include "common.cuh" | ||
#include "sum.cuh" | ||
|
||
#include <cstdint> | ||
static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __restrict__ dst, const int64_t ncols) { | ||
const int64_t row = blockIdx.x; | ||
|
||
static __global__ void argmax_f32( | ||
const float * x, int32_t * dst, const int64_t ncols, const int64_t nrows) { | ||
float maxval = -FLT_MAX; | ||
int argmax = -1; | ||
const float * rowx = x + row * ncols; | ||
|
||
int argmax_thread = 0; | ||
const int64_t row0 = (int64_t)blockIdx.x*WARP_SIZE; | ||
for (int32_t col = threadIdx.x; col < ncols; col += blockDim.x) { | ||
const float val = rowx[col]; | ||
if (val > maxval) { | ||
maxval = val; | ||
argmax = col; | ||
} | ||
} | ||
|
||
#pragma unroll | ||
for (int64_t row1 = 0; row1 < WARP_SIZE; ++row1) { | ||
const int64_t row = row0 + row1; | ||
|
||
if (row >= nrows) { | ||
break; | ||
for (int offset = 16; offset > 0; offset >>= 1) { | ||
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE); | ||
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE); | ||
if (val > maxval) { | ||
maxval = val; | ||
argmax = col; | ||
} | ||
} | ||
Comment on lines
+27
to
30
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. In retrospect it probably makes more sense to do it like this; conditional statements are problematic for code optimization since they prevent the compiler from reordering instructions but there isn't much to do in one loop iteration anyways. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I couldn't measure a meaningful difference in performance and this should be easier to understand and maintain. Maybe in some hardware it would make a difference? I also would expect the compiler to be able to optimize simple conditionals like this, but that may be expecting too much. |
||
|
||
float maxval = -FLT_MAX; | ||
int argmax = -1; | ||
|
||
for (int32_t col = threadIdx.x; col < ncols; col += WARP_SIZE) { | ||
const float val = x[row*ncols + col]; | ||
const int bigger = val > maxval; | ||
const int not_bigger = bigger ^ 0x00000001; | ||
|
||
maxval = maxval*not_bigger + val*bigger; | ||
argmax = argmax*not_bigger + col*bigger; | ||
const int n_warps = blockDim.x / WARP_SIZE; | ||
const int lane_id = threadIdx.x % WARP_SIZE; | ||
const int warp_id = threadIdx.x / WARP_SIZE; | ||
if (n_warps > 1) { | ||
constexpr int max_warps = 1024 / WARP_SIZE; | ||
__shared__ float shared_maxval[max_warps]; | ||
__shared__ int shared_argmax[max_warps]; | ||
if (lane_id == 0) { | ||
shared_maxval[warp_id] = maxval; | ||
shared_argmax[warp_id] = argmax; | ||
} | ||
|
||
__syncthreads(); | ||
|
||
if (warp_id == 0) { | ||
if (lane_id < n_warps) { | ||
maxval = shared_maxval[lane_id]; | ||
argmax = shared_argmax[lane_id]; | ||
} | ||
#pragma unroll | ||
for (int mask = 16; mask > 0; mask >>= 1) { | ||
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, mask, WARP_SIZE); | ||
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, mask, WARP_SIZE); | ||
const int bigger = val > maxval; | ||
const int not_bigger = bigger ^ 0x00000001; | ||
|
||
maxval = maxval*not_bigger + val*bigger; | ||
argmax = argmax*not_bigger + col*bigger; | ||
for (int offset = 16; offset > 0; offset >>= 1) { | ||
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE); | ||
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE); | ||
if (val > maxval) { | ||
maxval = val; | ||
argmax = col; | ||
} | ||
} | ||
} | ||
|
||
const int store = row1 == threadIdx.x; | ||
argmax_thread += store*argmax; | ||
} | ||
|
||
const int row = row0 + threadIdx.x; | ||
|
||
if (row >= nrows) { | ||
return; | ||
if (warp_id == 0 && lane_id == 0) { | ||
dst[row] = argmax; | ||
} | ||
|
||
Comment on lines
+64
to
66
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. My experience is that conditional returns/continues are faster than conditional writes but it probably doesn't matter much. |
||
dst[row] = argmax_thread; | ||
} | ||
|
||
void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | ||
|
@@ -70,10 +82,10 @@ void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | |
|
||
cudaStream_t stream = ctx.stream(); | ||
|
||
const int64_t num_blocks = (nrows + WARP_SIZE - 1) / WARP_SIZE; | ||
|
||
const dim3 blocks_dim(WARP_SIZE, 1, 1); | ||
const int64_t num_blocks = nrows; | ||
const int64_t num_threads = std::min<int64_t>(1024, (ne00 + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE); | ||
const dim3 blocks_dim(num_threads, 1, 1); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is going to be efficient for 32 <= ne00 <= 1024 and ne00 >> 1024 but inefficient for 1024 < ne00 <= 4096. And in general, if you have a variable block size you should make it a template parameter. |
||
const dim3 blocks_num(num_blocks, 1, 1); | ||
|
||
argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00, nrows); | ||
argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1155,6 +1155,26 @@ struct test_argmax : public test_case { | |
return out; | ||
} | ||
|
||
void initialize_tensors(ggml_context * ctx) override { | ||
std::random_device rd; | ||
std::default_random_engine rng(rd()); | ||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { | ||
if (t->type == GGML_TYPE_F32) { | ||
// initialize with unique values to avoid ties | ||
for (int64_t r = 0; r < ggml_nrows(t); r++) { | ||
std::vector<float> data(t->ne[0]); | ||
for (int i = 0; i < t->ne[0]; i++) { | ||
data[i] = i; | ||
} | ||
std::shuffle(data.begin(), data.end(), rng); | ||
ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float)); | ||
} | ||
} else { | ||
init_tensor_uniform(t); | ||
} | ||
} | ||
} | ||
|
||
double max_nmse_err() override { | ||
return 0.0; | ||
} | ||
|
@@ -3441,6 +3461,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() { | |
test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1)); | ||
|
||
test_cases.emplace_back(new test_argmax()); | ||
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 1, 1, 1})); | ||
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100, 10, 1, 1})); | ||
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1})); | ||
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {2000, 10, 1, 1})); | ||
Comment on lines
+3464
to
+3467
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You may want to also check the case with ne01 and ne00 flipped where whether or not the writes are coalesced makes a comparatively larger difference. But that would be the case with a very large batch size and few classes and especially with language models that have large vocabulary sizes I think it's not an important use case. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do you mean test for correctness or performance? These cases are the ones used in eval mode only. I also tested the performance with [512,32000], and it drops to 480GB/s (compared to 730GB/s with [32000,512]). There are surely more optimization opportunities, but I don't think it is worth spending more time on this at moment. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I only meant performance. I wrote the code on master in the context of the ggml MNIST example with an input shape of |
||
|
||
test_cases.emplace_back(new test_count_equal()); | ||
|
||
for (int ne3 : {1, 3}) { // CUDA backward pass only supports ne3 == 1 | ||
|
@@ -3831,6 +3856,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() { | |
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, 1.0f, 0.0f)); | ||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, 1.0f, 0.0f)); | ||
|
||
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 10, 1, 1})); | ||
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1})); | ||
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32000, 512, 1, 1})); | ||
|
||
for (int bs : {1, 512}) { | ||
for (ggml_type type_a : all_types) { | ||
for (ggml_type type_b : {GGML_TYPE_F32}) { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking at the code again, I think either 64 bit should be used for the
ne00
dimension or there should be an assert that 32 bit is enough. Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The output is int32, so it would definitely not work with
ne00
larger than INT_MAX
. In that case it might make more sense to add the assert to ggml_argmax
instead. Other arg* functions will have the same issue.