
Commit 1c73d4e

GGML map ops proof of concept.

1 parent 723dac5 · commit 1c73d4e

File tree: 2 files changed, +226 -3 lines changed


ggml.c

Lines changed: 212 additions & 3 deletions
@@ -2712,9 +2712,12 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
 
     "FLASH_ATTN",
     "FLASH_FF",
+
+    "MAP_UNARY",
+    "MAP_BINARY",
 };
 
-static_assert(GGML_OP_COUNT == 36, "GGML_OP_COUNT != 36");
+static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2757,9 +2760,12 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
 
     "flash_attn(x)",
     "flash_ff(x)",
+
+    "f(x)",
+    "f(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 36, "GGML_OP_COUNT != 36");
+static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -3671,6 +3677,92 @@ struct ggml_tensor * ggml_dup_inplace(
     return ggml_dup_impl(ctx, a, true);
 }
 
+
+// ggml_map_binary
+
+struct ggml_tensor * ggml_map_binary_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        void (*const fun)(int, float *, float *, float *),
+        bool inplace) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+
+    bool is_node = false;
+
+    if (!inplace && (a->grad || b->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
+    *((void **)addr_tensor->data) = fun;
+    struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op = GGML_OP_MAP_BINARY;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+    result->opt[0] = addr_tensor;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_map_binary(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        void (*const fun)(int, float *, float *, float *)) {
+    return ggml_map_binary_impl(ctx, a, b, fun, false);
+}
+
+struct ggml_tensor * ggml_map_binary_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        void (*const fun)(int, float *, float *, float *)) {
+    return ggml_map_binary_impl(ctx, a, b, fun, true);
+}
+
+// ggml_map_unary
+
+struct ggml_tensor * ggml_map_unary_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        void (*const fun)(int, float *, float *),
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && a->grad) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
+    *((void **)addr_tensor->data) = fun;
+    struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op = GGML_OP_MAP_UNARY;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->opt[0] = addr_tensor;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_map_unary(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        void (*const fun)(int, float *, float *)) {
+    return ggml_map_unary_impl(ctx, a, fun, false);
+}
+
+struct ggml_tensor * ggml_map_unary_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        void (*const fun)(int, float *, float *)) {
+    return ggml_map_unary_impl(ctx, a, fun, true);
+}
+
+
 // ggml_add
 
 struct ggml_tensor * ggml_add_impl(
@@ -5329,6 +5421,111 @@ static void ggml_compute_forward_dup(
     }
 }
 
+// ggml_compute_forward_map_unary
+
+static void ggml_compute_forward_map_unary_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst,
+        void (*const fun)(int, float *, float *)) {
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        fun(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+
+static void ggml_compute_forward_map_unary(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst,
+        void (*const fun)(int, float *, float *)) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_map_binary
+
+static void ggml_compute_forward_map_binary_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst,
+        void (*const fun)(int, float *, float *, float *)) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+    assert(src1->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        fun(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])),
+                (float *) ((char *) src1->data + i*(src1->nb[1])));
+    }
+}
+
+
+static void ggml_compute_forward_map_binary(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst,
+        void (*const fun)(int, float *, float *, float *)) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_add
 
 static void ggml_compute_forward_add_f32(
@@ -8877,7 +9074,19 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
             {
                 ggml_compute_forward_dup(params, tensor->src0, tensor);
             } break;
-        case GGML_OP_ADD:
+        case GGML_OP_MAP_UNARY:
+            {
+                void (*const fun)(int, float *, float *) = *((void **)tensor->opt[0]->data);
+                ggml_compute_forward_map_unary(params, tensor->src0, tensor, fun);
+            }
+            break;
+        case GGML_OP_MAP_BINARY:
+            {
+                void (*const fun)(int, float *, float *, float *) = *((void **)tensor->opt[0]->data);
+                ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
+            }
+            break;
+        case GGML_OP_ADD:
             {
                 ggml_compute_forward_add(params, tensor->src0, tensor->src1, tensor);
             } break;
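
How the pieces fit together: ggml_map_unary/ggml_map_binary smuggle the user's function pointer through the graph by storing it in a small GGML_TYPE_I32 tensor attached as opt[0]; at evaluation time ggml_compute_forward reads the pointer back and calls it once per row. Below is a minimal usage sketch, not part of the commit, written against the ggml API of this vintage (ggml_init, ggml_build_forward, ggml_graph_compute); the kernel name my_square, the arena size, and the tensor shape are illustrative only.

#include <stdio.h>
#include "ggml.h"

// Row kernel matching the ggml_map_unary callback signature:
// n is the row length, dst and src point at one contiguous F32 row.
static void my_square(int n, float * dst, float * src) {
    for (int i = 0; i < n; i++) {
        dst[i] = src[i] * src[i];
    }
}

int main(void) {
    // Illustrative 16 MB arena; real sizes depend on the graph.
    struct ggml_init_params params = {
        .mem_size   = 16*1024*1024,
        .mem_buffer = NULL,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    for (int i = 0; i < 4; i++) {
        ((float *) x->data)[i] = (float) i;
    }

    // y = my_square(x): builds a GGML_OP_MAP_UNARY node, then runs it.
    struct ggml_tensor * y  = ggml_map_unary(ctx, x, my_square);
    struct ggml_cgraph   gf = ggml_build_forward(y);
    ggml_graph_compute(ctx, &gf);

    for (int i = 0; i < 4; i++) {
        printf("%f\n", ((float *) y->data)[i]);  // expect 0 1 4 9
    }

    ggml_free(ctx);
    return 0;
}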

ggml.h

Lines changed: 14 additions & 0 deletions
@@ -253,6 +253,9 @@ enum ggml_op {
     GGML_OP_FLASH_ATTN,
     GGML_OP_FLASH_FF,
 
+    GGML_OP_MAP_UNARY,
+    GGML_OP_MAP_BINARY,
+
     GGML_OP_COUNT,
 };
 
@@ -419,6 +422,17 @@ struct ggml_tensor * ggml_dup(
         struct ggml_context * ctx,
         struct ggml_tensor * a);
 
+struct ggml_tensor *ggml_map_unary(
+        struct ggml_context *ctx,
+        struct ggml_tensor *a,
+        void (*const fun)(int, float *, float *));
+
+struct ggml_tensor *ggml_map_binary(
+        struct ggml_context *ctx,
+        struct ggml_tensor *a,
+        struct ggml_tensor *b,
+        void (*const fun)(int, float *, float *, float *));
+
 struct ggml_tensor * ggml_add(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
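
For the binary variant the callback receives the row length plus the destination row and both source rows, and a and b must have the same shape. A hedged sketch, again not from the commit, with an illustrative kernel name:

// Element-wise maximum, matching the ggml_map_binary callback signature.
static void my_max(int n, float * dst, float * a, float * b) {
    for (int i = 0; i < n; i++) {
        dst[i] = a[i] > b[i] ? a[i] : b[i];
    }
}

// Given an existing ggml_context * ctx and same-shape F32 tensors a, b:
//   struct ggml_tensor * out = ggml_map_binary(ctx, a, b, my_max);
//   struct ggml_cgraph   gf  = ggml_build_forward(out);
//   ggml_graph_compute(ctx, &gf);

Note that the dispatch is F32-only: both compute paths hit GGML_ASSERT(false) for every other tensor type, so these map ops are, as the commit message says, a proof of concept.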
