
Commit efb6ae9

PABannier authored and slaren committed
feat: add GGML_UNARY_OP_ARGMAX Metal kernel (ggml/1019)
* implemented argmax kernel
* tpig -> tgpig
* change to strides
* contiguous assertions
* kernel working and tested
* argmax simd parallel implementation
* added 2 new tests for argmax in test-backend-ops
* cosmit
* added 3 tests cases for perf eval
* add test_argmax in make_test_cases_perf
* Update test-backend-ops.cpp

Co-authored-by: Diego Devesa <[email protected]>

---------

Co-authored-by: Diego Devesa <[email protected]>
1 parent 667d70d commit efb6ae9
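
For context, GGML_OP_ARGMAX produces, for each row of an F32 tensor, the int32 index of that row's largest element; this commit adds a Metal implementation so the op no longer needs the CPU path on Apple GPUs. Below is a minimal host-side sketch of the op's semantics through the public ggml C API, hedged: compute entry points have moved between ggml revisions (newer trees declare ggml_graph_compute_with_ctx alongside the CPU backend headers), so treat this as illustrative rather than exact.

    // Hedged sketch of GGML_OP_ARGMAX semantics via the public ggml C API;
    // exact headers/compute entry points vary across ggml revisions.
    #include "ggml.h"
    #include <cstring>
    #include <cstdio>

    int main() {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        // two rows of four floats; argmax reduces along each row
        struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 2);
        const float data[8] = { 0.1f, 3.0f, 2.0f, 1.0f,
                                5.0f, 4.0f, 6.0f, 0.0f };
        memcpy(x->data, data, sizeof(data));

        // result is an I32 tensor holding one index per row
        struct ggml_tensor * am = ggml_argmax(ctx, x);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, am);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

        const int32_t * out = (const int32_t *) am->data;
        printf("argmax per row: %d %d\n", out[0], out[1]); // expected: 1 2

        ggml_free(ctx);
        return 0;
    }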

File tree: 3 files changed, +91 -5 lines changed

ggml/src/ggml-metal/ggml-metal.m (28 additions, 0 deletions)
@@ -392,6 +392,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
     GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
     GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
+    GGML_METAL_KERNEL_TYPE_ARGMAX,

     GGML_METAL_KERNEL_TYPE_COUNT
 };
@@ -956,6 +957,7 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN,      sin,      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS,      cos,      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX,   argmax,   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
     }
@@ -1086,6 +1088,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
             return has_simdgroup_reduction;
         case GGML_OP_RMS_NORM:
             return has_simdgroup_reduction && (op->ne[0] % 4 == 0);
+        case GGML_OP_ARGMAX:
         case GGML_OP_NORM:
         case GGML_OP_ROPE:
             return true;
@@ -3845,6 +3848,31 @@ static void ggml_metal_encode_node(

                [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
            } break;
+        case GGML_OP_ARGMAX:
+            {
+                GGML_ASSERT(src0->type == GGML_TYPE_F32);
+                GGML_ASSERT(ggml_is_contiguous_1(src0));
+                GGML_ASSERT(nb00 == ggml_type_size(src0->type));
+
+                const int64_t nrows = ggml_nrows(src0);
+
+                int nth = 32; // SIMD width
+                while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
+                    nth *= 2;
+                }
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGMAX].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+                [encoder setThreadgroupMemoryLength:32*sizeof(float)   atIndex:0];
+                [encoder setThreadgroupMemoryLength:32*sizeof(int32_t) atIndex:1];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
         default:
             {
                 GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
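
A note on the dispatch above: every row gets its own threadgroup (nrows threadgroups on the grid), and the threadgroup width nth starts at one SIMD width and doubles only while the row still has more columns than threads and the grid-wide thread count stays under 256; in other words, per-row parallelism is widened only when there are too few rows to keep the GPU busy. A standalone sketch of that heuristic follows (argmax_threads_per_row is a hypothetical helper name, not part of ggml; ne00..ne03 mirror the tensor dimensions used above):

    // standalone C++ sketch of the threadgroup-width heuristic in the diff above
    #include <cstdint>
    #include <cstdio>

    static int argmax_threads_per_row(int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03) {
        int nth = 32; // one Metal SIMD group
        // widen while the row has more columns than threads, but stop once the
        // grid as a whole (nth threads per row, ne01*ne02*ne03 rows) reaches ~256 threads
        while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
            nth *= 2;
        }
        return nth;
    }

    int main() {
        printf("%d\n", argmax_threads_per_row(1024, 10, 1, 1)); // 32: plenty of rows already
        printf("%d\n", argmax_threads_per_row(5438,  3, 1, 1)); // 128: few rows, widen per-row work
    }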

ggml/src/ggml-metal/ggml-metal.metal (57 additions, 0 deletions)
@@ -1366,6 +1366,63 @@ kernel void kernel_ssm_scan_f32(
     }
 }

+kernel void kernel_argmax(
+        device   const void * x,
+        device      int32_t * dst,
+        constant    int64_t & ncols,
+        constant   uint64_t & nb01,
+        threadgroup   float * shared_maxval [[threadgroup(0)]],
+        threadgroup int32_t * shared_argmax [[threadgroup(1)]],
+        uint  tgpig[[threadgroup_position_in_grid]],
+        uint  tpitg[[thread_position_in_threadgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint    ntg[[threads_per_threadgroup]]) {
+    device const float * x_row = (device const float *) ((device const char *) x + tgpig * nb01);
+
+    float   lmax = -INFINITY;
+    int32_t larg = -1;
+
+    for (int i00 = tpitg; i00 < ncols; i00 += ntg) {
+        if (x_row[i00] > lmax) {
+            lmax = x_row[i00];
+            larg = i00;
+        }
+    }
+
+    // find the argmax value in the block
+    float   max_val = simd_max(lmax);
+    int32_t arg_val = simd_max(select(-1, larg, lmax == max_val));
+
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            shared_maxval[tiisg] = -INFINITY;
+            shared_argmax[tiisg] = -1;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            shared_maxval[sgitg] = max_val;
+            shared_argmax[sgitg] = arg_val;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        max_val = shared_maxval[tiisg];
+        arg_val = shared_argmax[tiisg];
+
+        float   max_val_reduced = simd_max(max_val);
+        int32_t arg_val_reduced = simd_max(select(-1, arg_val, max_val == max_val_reduced));
+
+        dst[tgpig] = arg_val_reduced;
+
+        return;
+    }
+
+    dst[tgpig] = arg_val;
+}
+
 kernel void kernel_norm(
         constant ggml_metal_kargs_norm & args,
         device const char * src0,
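
The kernel reduces in two stages. Each thread first scans a strided slice of its row, keeping a local best value lmax and its index larg. Then simd_max(lmax) gives every lane the simdgroup-wide maximum, and select(-1, larg, lmax == max_val) replaces the index with -1 on every lane that does not hold that maximum, so a second simd_max over the indices recovers a valid argmax. If the threadgroup is wider than one simdgroup (ntg > N_SIMDWIDTH), per-simdgroup winners are staged in threadgroup memory and the same trick runs once more. Here is a CPU emulation of the single-simdgroup step, with plain arrays standing in for the 32 SIMD lanes (illustrative only, not ggml code):

    // CPU emulation of the simd_max/select argmax trick used above
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int W = 32; // lanes in one simdgroup
        float   lmax[W]; // per-lane running maximum
        int32_t larg[W]; // per-lane index of that maximum
        for (int i = 0; i < W; ++i) { lmax[i] = -1e30f; larg[i] = -1; }
        lmax[5] = 7.0f; larg[5] =  37; // pretend lane 5 saw the row maximum ...
        lmax[9] = 7.0f; larg[9] = 201; // ... and lane 9 saw an exact tie later in the row

        // simd_max(lmax): every lane learns the group-wide maximum value
        float max_val = *std::max_element(lmax, lmax + W);

        // simd_max(select(-1, larg, lmax == max_val)): lanes not holding the
        // maximum contribute -1, so the reduction returns a real index
        int32_t arg_val = -1;
        for (int i = 0; i < W; ++i) {
            const int32_t cand = (lmax[i] == max_val) ? larg[i] : -1;
            arg_val = std::max(arg_val, cand);
        }

        printf("max = %g, argmax = %d\n", max_val, arg_val); // max = 7, argmax = 201
    }

One observable consequence: on exact ties the index reduction keeps the highest candidate (201 above), whereas a sequential left-to-right scan with a strict > keeps the first occurrence, so tied inputs may legitimately produce different indices on different backends.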

tests/test-backend-ops.cpp (6 additions, 5 deletions)
@@ -3460,13 +3460,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
     test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));

-    test_cases.emplace_back(new test_argmax());
-    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 1, 1, 1}));
-    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100, 10, 1, 1}));
+    test_cases.emplace_back(new test_count_equal());
+
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 1, 1, 1}));
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100, 10, 1, 1}));
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 12, 1, 1}));
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {2000, 10, 1, 1}));
-
-    test_cases.emplace_back(new test_count_equal());
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {5438, 3, 1, 1}));

     for (int ne3 : {1, 3}) { // CUDA backward pass only supports ne3 == 1
         test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 1}));
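
Shape rationale, per the dispatch heuristic above: {32, 1, 1, 1} fills exactly one simdgroup; {100, 10, 1, 1}, {1024, 10, 1, 1}, {1024, 12, 1, 1}, and {2000, 10, 1, 1} have enough rows that each row stays on a single 32-thread simdgroup, covering row lengths both divisible (1024) and not divisible (100, 2000) by the SIMD width; {5438, 3, 1, 1} has so few rows that nth grows past N_SIMDWIDTH and exercises the two-stage threadgroup-memory reduction. Each case is conceptually validated against plain row-wise argmax, as in this sketch (argmax_rows_ref is a hypothetical reference helper, not the actual harness code):

    // hedged sketch of the reference result the backend output is compared
    // against: first-match argmax, one int32 index per row
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    static std::vector<int32_t> argmax_rows_ref(const std::vector<float> & x, int64_t ne0, int64_t nrows) {
        std::vector<int32_t> out((size_t) nrows);
        for (int64_t r = 0; r < nrows; ++r) {
            const float * row = x.data() + r*ne0;
            int64_t best = 0;
            for (int64_t i = 1; i < ne0; ++i) {
                if (row[i] > row[best]) {
                    best = i; // strict '>' keeps the first occurrence on ties
                }
            }
            out[(size_t) r] = (int32_t) best;
        }
        return out;
    }

    int main() {
        const std::vector<float> x = { 0.1f, 3.0f, 2.0f, 1.0f,
                                       5.0f, 4.0f, 6.0f, 0.0f };
        const auto idx = argmax_rows_ref(x, 4, 2);
        printf("%d %d\n", idx[0], idx[1]); // 1 2
    }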
