Skip to content

Commit c9a59b7

Browse files
authored
ggml : add unary and binary map operations (#874)
* GGML map ops proof of concept. * Various cleanups. Add handling for task setting. Add handling for ggml_compute_backward. Rename functions to ggml_map_unary_f32 and ggml_map_binary_f32 Fix compiler warnings related to casting function pointers and `void *` Reorder functions and definitions based on the GGML op number. Use typedefs for map op function pointer types. * Fix position of map ops cases in ggml_compute_forward
1 parent a32f7ac commit c9a59b7

File tree

2 files changed

+237
-2
lines changed

2 files changed

+237
-2
lines changed

ggml.c

Lines changed: 219 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2712,9 +2712,12 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
27122712

27132713
"FLASH_ATTN",
27142714
"FLASH_FF",
2715+
2716+
"MAP_UNARY",
2717+
"MAP_BINARY",
27152718
};
27162719

2717-
static_assert(GGML_OP_COUNT == 36, "GGML_OP_COUNT != 36");
2720+
static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
27182721

27192722
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
27202723
"none",
@@ -2757,9 +2760,12 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
27572760

27582761
"flash_attn(x)",
27592762
"flash_ff(x)",
2763+
2764+
"f(x)",
2765+
"f(x,y)",
27602766
};
27612767

2762-
static_assert(GGML_OP_COUNT == 36, "GGML_OP_COUNT != 36");
2768+
static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
27632769

27642770
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
27652771
static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -4907,6 +4913,90 @@ struct ggml_tensor * ggml_flash_ff(
49074913
return result;
49084914
}
49094915

4916+
// ggml_map_unary
4917+
4918+
struct ggml_tensor * ggml_map_unary_impl_f32(
4919+
struct ggml_context * ctx,
4920+
struct ggml_tensor * a,
4921+
const ggml_unary_op_f32_t fun,
4922+
bool inplace) {
4923+
bool is_node = false;
4924+
4925+
if (!inplace && a->grad) {
4926+
is_node = true;
4927+
}
4928+
4929+
struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
4930+
*((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
4931+
struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
4932+
4933+
result->op = GGML_OP_MAP_UNARY;
4934+
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
4935+
result->src0 = a;
4936+
result->opt[0] = addr_tensor;
4937+
4938+
return result;
4939+
}
4940+
4941+
struct ggml_tensor * ggml_map_unary_f32(
4942+
struct ggml_context * ctx,
4943+
struct ggml_tensor * a,
4944+
const ggml_unary_op_f32_t fun) {
4945+
return ggml_map_unary_impl_f32(ctx, a, fun, false);
4946+
}
4947+
4948+
struct ggml_tensor * ggml_map_unary_inplace_f32(
4949+
struct ggml_context * ctx,
4950+
struct ggml_tensor * a,
4951+
const ggml_unary_op_f32_t fun) {
4952+
return ggml_map_unary_impl_f32(ctx, a, fun, true);
4953+
}
4954+
4955+
// ggml_map_binary
4956+
4957+
struct ggml_tensor * ggml_map_binary_impl_f32(
4958+
struct ggml_context * ctx,
4959+
struct ggml_tensor * a,
4960+
struct ggml_tensor * b,
4961+
const ggml_binary_op_f32_t fun,
4962+
bool inplace) {
4963+
GGML_ASSERT(ggml_are_same_shape(a, b));
4964+
4965+
bool is_node = false;
4966+
4967+
if (!inplace && (a->grad || b->grad)) {
4968+
is_node = true;
4969+
}
4970+
4971+
struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
4972+
*((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
4973+
struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
4974+
4975+
result->op = GGML_OP_MAP_BINARY;
4976+
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
4977+
result->src0 = a;
4978+
result->src1 = b;
4979+
result->opt[0] = addr_tensor;
4980+
4981+
return result;
4982+
}
4983+
4984+
struct ggml_tensor * ggml_map_binary_f32(
4985+
struct ggml_context * ctx,
4986+
struct ggml_tensor * a,
4987+
struct ggml_tensor * b,
4988+
const ggml_binary_op_f32_t fun) {
4989+
return ggml_map_binary_impl_f32(ctx, a, b, fun, false);
4990+
}
4991+
4992+
struct ggml_tensor * ggml_map_binary_inplace_f32(
4993+
struct ggml_context * ctx,
4994+
struct ggml_tensor * a,
4995+
struct ggml_tensor * b,
4996+
const ggml_binary_op_f32_t fun) {
4997+
return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
4998+
}
4999+
49105000
////////////////////////////////////////////////////////////////////////////////
49115001

49125002
void ggml_set_param(
@@ -8875,6 +8965,111 @@ static void ggml_compute_forward_flash_ff(
88758965
}
88768966
}
88778967

8968+
// ggml_compute_forward_map_unary
8969+
8970+
static void ggml_compute_forward_map_unary_f32(
8971+
const struct ggml_compute_params * params,
8972+
const struct ggml_tensor * src0,
8973+
struct ggml_tensor * dst,
8974+
const ggml_unary_op_f32_t fun) {
8975+
GGML_ASSERT(ggml_are_same_shape(src0, dst));
8976+
8977+
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
8978+
return;
8979+
}
8980+
8981+
const int n = ggml_nrows(src0);
8982+
const int nc = src0->ne[0];
8983+
8984+
assert( dst->nb[0] == sizeof(float));
8985+
assert(src0->nb[0] == sizeof(float));
8986+
8987+
for (int i = 0; i < n; i++) {
8988+
fun(nc,
8989+
(float *) ((char *) dst->data + i*( dst->nb[1])),
8990+
(float *) ((char *) src0->data + i*(src0->nb[1])));
8991+
}
8992+
}
8993+
8994+
8995+
static void ggml_compute_forward_map_unary(
8996+
const struct ggml_compute_params * params,
8997+
const struct ggml_tensor * src0,
8998+
struct ggml_tensor * dst,
8999+
const ggml_unary_op_f32_t fun) {
9000+
switch (src0->type) {
9001+
case GGML_TYPE_F32:
9002+
{
9003+
ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
9004+
} break;
9005+
case GGML_TYPE_Q4_0:
9006+
case GGML_TYPE_Q4_1:
9007+
case GGML_TYPE_I8:
9008+
case GGML_TYPE_I16:
9009+
case GGML_TYPE_I32:
9010+
case GGML_TYPE_F16:
9011+
case GGML_TYPE_COUNT:
9012+
{
9013+
GGML_ASSERT(false);
9014+
} break;
9015+
}
9016+
}
9017+
9018+
// ggml_compute_forward_map_binary
9019+
9020+
static void ggml_compute_forward_map_binary_f32(
9021+
const struct ggml_compute_params * params,
9022+
const struct ggml_tensor * src0,
9023+
const struct ggml_tensor * src1,
9024+
struct ggml_tensor * dst,
9025+
const ggml_binary_op_f32_t fun) {
9026+
assert(params->ith == 0);
9027+
assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
9028+
9029+
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
9030+
return;
9031+
}
9032+
9033+
const int n = ggml_nrows(src0);
9034+
const int nc = src0->ne[0];
9035+
9036+
assert( dst->nb[0] == sizeof(float));
9037+
assert(src0->nb[0] == sizeof(float));
9038+
assert(src1->nb[0] == sizeof(float));
9039+
9040+
for (int i = 0; i < n; i++) {
9041+
fun(nc,
9042+
(float *) ((char *) dst->data + i*( dst->nb[1])),
9043+
(float *) ((char *) src0->data + i*(src0->nb[1])),
9044+
(float *) ((char *) src1->data + i*(src1->nb[1])));
9045+
}
9046+
}
9047+
9048+
9049+
static void ggml_compute_forward_map_binary(
9050+
const struct ggml_compute_params * params,
9051+
const struct ggml_tensor * src0,
9052+
const struct ggml_tensor * src1,
9053+
struct ggml_tensor * dst,
9054+
const ggml_binary_op_f32_t fun) {
9055+
switch (src0->type) {
9056+
case GGML_TYPE_F32:
9057+
{
9058+
ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
9059+
} break;
9060+
case GGML_TYPE_Q4_0:
9061+
case GGML_TYPE_Q4_1:
9062+
case GGML_TYPE_I8:
9063+
case GGML_TYPE_I16:
9064+
case GGML_TYPE_I32:
9065+
case GGML_TYPE_F16:
9066+
case GGML_TYPE_COUNT:
9067+
{
9068+
GGML_ASSERT(false);
9069+
} break;
9070+
}
9071+
}
9072+
88789073
/////////////////////////////////
88799074

88809075
static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
@@ -9024,6 +9219,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
90249219
{
90259220
ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
90269221
} break;
9222+
case GGML_OP_MAP_UNARY:
9223+
{
9224+
const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data);
9225+
ggml_compute_forward_map_unary(params, tensor->src0, tensor, fun);
9226+
}
9227+
break;
9228+
case GGML_OP_MAP_BINARY:
9229+
{
9230+
const ggml_binary_op_f32_t fun = *((ggml_binary_op_f32_t *)tensor->opt[0]->data);
9231+
ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
9232+
}
9233+
break;
90279234
case GGML_OP_NONE:
90289235
{
90299236
// nop
@@ -9283,6 +9490,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
92839490
{
92849491
GGML_ASSERT(false); // not supported
92859492
} break;
9493+
case GGML_OP_MAP_UNARY:
9494+
case GGML_OP_MAP_BINARY:
9495+
{
9496+
GGML_ASSERT(false); // not supported
9497+
} break;
92869498
case GGML_OP_NONE:
92879499
{
92889500
// nop
@@ -9775,6 +9987,11 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
97759987

97769988
work_size = MAX(work_size, cur);
97779989
} break;
9990+
case GGML_OP_MAP_UNARY:
9991+
case GGML_OP_MAP_BINARY:
9992+
{
9993+
node->n_tasks = 1;
9994+
} break;
97789995
case GGML_OP_NONE:
97799996
{
97809997
node->n_tasks = 1;

ggml.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,9 @@ enum ggml_op {
253253
GGML_OP_FLASH_ATTN,
254254
GGML_OP_FLASH_FF,
255255

256+
GGML_OP_MAP_UNARY,
257+
GGML_OP_MAP_BINARY,
258+
256259
GGML_OP_COUNT,
257260
};
258261

@@ -652,6 +655,21 @@ struct ggml_tensor * ggml_flash_ff(
652655
struct ggml_tensor * c0,
653656
struct ggml_tensor * c1);
654657

658+
// Mapping operations
//
// User-supplied callbacks applied row-by-row to f32 tensor data:
//   unary:  fun(n, dst, src)        — n elements per row
//   binary: fun(n, dst, src0, src1) — n elements per row
typedef void (*ggml_unary_op_f32_t)(const int, float *, const float *);
typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *);

// Build a graph node that applies `fun` to every row of `a`.
struct ggml_tensor * ggml_map_unary_f32(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
        const ggml_unary_op_f32_t fun);

// Build a graph node that applies `fun` to corresponding rows of `a` and `b`;
// the two tensors must have the same shape.
struct ggml_tensor * ggml_map_binary_f32(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
        struct ggml_tensor * b,
        const ggml_binary_op_f32_t fun);
672+
655673
//
656674
// automatic differentiation
657675
//

0 commit comments

Comments
 (0)