
Commit 4523d10

ggml : add ggml_pool_1d and ggml_pool_2d
1 parent 680e6f9 commit 4523d10

2 files changed: +308 additions, -2 deletions


ggml.c

Lines changed: 281 additions & 2 deletions
@@ -3787,6 +3787,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CLAMP",
     "CONV_1D",
     "CONV_2D",
+    "POOL_1D",
+    "POOL_2D",
 
     "FLASH_ATTN",
     "FLASH_FF",
@@ -3805,7 +3807,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 66, "GGML_OP_COUNT != 66");
+static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3865,6 +3867,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "clamp(x)",
     "conv_1d(x)",
     "conv_2d(x)",
+    "pool_1d(x)",
+    "pool_2d(x)",
 
     "flash_attn(x)",
     "flash_ff(x)",
@@ -3883,7 +3887,9 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 66, "GGML_OP_COUNT != 66");
+static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
+
+static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -7214,6 +7220,98 @@ struct ggml_tensor* ggml_conv_1d_ph(
     return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
 }
 
+
+// ggml_pool_*
+
+static int64_t ggml_calc_pool_output_size(int64_t ins, int ks, int s, int p) {
+    return (ins + 2 * p - ks) / s + 1;
+}
+
+// ggml_pool_1d
+
+struct ggml_tensor* ggml_pool_1d(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        enum ggml_op_pool op,
+        int k0,
+        int s0,
+        int p0) {
+
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    const int64_t ne[3] = {
+        ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
+        a->ne[1],
+    };
+    struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
+
+    ggml_scratch_save(ctx);
+    struct ggml_tensor* c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
+    ((int32_t*)c->data)[0] = op;
+    ((int32_t*)c->data)[1] = k0;
+    ((int32_t*)c->data)[2] = s0;
+    ((int32_t*)c->data)[3] = p0;
+    ggml_scratch_load(ctx);
+
+    result->op = GGML_OP_POOL_1D;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+    result->src[1] = c;
+
+    return result;
+}
+
+// ggml_pool_2d
+
+struct ggml_tensor* ggml_pool_2d(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        enum ggml_op_pool op,
+        int k0,
+        int k1,
+        int s0,
+        int s1,
+        int p0,
+        int p1) {
+
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    const int64_t ne[3] = {
+        ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
+        ggml_calc_pool_output_size(a->ne[1], k1, s1, p1),
+        a->ne[2],
+    };
+    struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
+
+    ggml_scratch_save(ctx);
+    struct ggml_tensor* c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 7);
+    ((int32_t*)c->data)[0] = op;
+    ((int32_t*)c->data)[1] = k0;
+    ((int32_t*)c->data)[2] = k1;
+    ((int32_t*)c->data)[3] = s0;
+    ((int32_t*)c->data)[4] = s1;
+    ((int32_t*)c->data)[5] = p0;
+    ((int32_t*)c->data)[6] = p1;
+    ggml_scratch_load(ctx);
+
+    result->op = GGML_OP_POOL_2D;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+    result->src[1] = c;
+
+    return result;
+}
+
 // ggml_flash_attn
 
 struct ggml_tensor * ggml_flash_attn(
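The output size used above follows the standard pooling formula (ins + 2*p - ks) / s + 1; for example, an input row of length 10 with kernel 2, stride 2, and no padding gives (10 + 0 - 2) / 2 + 1 = 5 pooled elements. A minimal standalone sketch of that calculation; the helper name and values below are hypothetical, chosen only for illustration:

#include <stdint.h>
#include <stdio.h>

// same formula as ggml_calc_pool_output_size above
static int64_t pool_output_size(int64_t ins, int ks, int s, int p) {
    return (ins + 2 * p - ks) / s + 1;
}

int main(void) {
    // hypothetical sizes: length-10 input, 2-wide kernel, stride 2, no padding
    printf("%lld\n", (long long) pool_output_size(10, 2, 2, 0)); // prints 5
    return 0;
}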
@@ -13013,6 +13111,166 @@ static void ggml_compute_forward_conv_2d(
     };
 }
 
+// ggml_compute_forward_pool_1d_sk_p0
+
+static void ggml_compute_forward_pool_1d_sk_p0(
+        const struct ggml_compute_params * params,
+        const enum ggml_op_pool op,
+        const struct ggml_tensor * src,
+        const int k,
+        struct ggml_tensor * dst) {
+    assert(src->type == GGML_TYPE_F32);
+    assert(params->ith == 0);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const char * cdata = (const char *)src->data;
+    const char * const data_end = cdata + ggml_nbytes(src);
+    float * drow = (float *)dst->data;
+
+    const int64_t rs = dst->ne[0];
+
+    while (cdata < data_end) {
+        const float * const srow = (const float *)cdata;
+
+        int j = 0;
+
+        for (int64_t i = 0; i < rs; ++i) {
+            switch (op) {
+                case GGML_OP_POOL_AVG: drow[i] = 0; break;
+                case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break;
+                case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
+            }
+            for (int ki = 0; ki < k; ++ki) {
+                switch (op) {
+                    case GGML_OP_POOL_AVG: drow[i] += srow[j]; break;
+                    case GGML_OP_POOL_MAX: if (srow[j] > drow[i]) drow[i] = srow[j]; break;
+                    case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
+                }
+                ++j;
+            }
+            switch (op) {
+                case GGML_OP_POOL_AVG: drow[i] /= k; break;
+                case GGML_OP_POOL_MAX: break;
+                case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
+            }
+        }
+
+        cdata += src->nb[1];
+        drow += rs;
+    }
+}
+
+// ggml_compute_forward_pool_1d
+
+static void ggml_compute_forward_pool_1d(
+        const struct ggml_compute_params* params,
+        const struct ggml_tensor* src0,
+        const struct ggml_tensor* opt0,
+        struct ggml_tensor* dst) {
+    GGML_ASSERT(opt0->ne[0] == 4);
+    const int* opts = (const int*)opt0->data;
+    enum ggml_op_pool op = opts[0];
+    const int k0 = opts[1];
+    const int s0 = opts[2];
+    const int p0 = opts[3];
+    GGML_ASSERT(p0 == 0); // padding not supported
+    GGML_ASSERT(k0 == s0); // only s = k supported
+
+    ggml_compute_forward_pool_1d_sk_p0(params, op, src0, k0, dst);
+}
+
+// ggml_compute_forward_pool_2d_sk_p0
+
+static void ggml_compute_forward_pool_2d_sk_p0(
+        const struct ggml_compute_params * params,
+        const enum ggml_op_pool op,
+        const struct ggml_tensor * src,
+        const int k0,
+        const int k1,
+        struct ggml_tensor * dst) {
+    assert(src->type == GGML_TYPE_F32);
+    assert(params->ith == 0);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const char * cdata = (const char*)src->data;
+    const char * const data_end = cdata + ggml_nbytes(src);
+
+    const int64_t px = dst->ne[0];
+    const int64_t py = dst->ne[1];
+    const int64_t pa = px * py;
+
+    float * dplane = (float *)dst->data;
+
+    const int ka = k0 * k1;
+
+    while (cdata < data_end) {
+        for (int oy = 0; oy < py; ++oy) {
+            float * const drow = dplane + oy * px;
+            for (int ox = 0; ox < px; ++ox) {
+                float * const out = drow + ox;
+                switch (op) {
+                    case GGML_OP_POOL_AVG: *out = 0; break;
+                    case GGML_OP_POOL_MAX: *out = -FLT_MAX; break;
+                    case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
+                }
+
+                const int ix = ox * k0;
+                const int iy = oy * k1;
+
+                for (int ky = 0; ky < k1; ++ky) {
+                    const float * const srow = (const float *)(cdata + src->nb[1] * (iy + ky));
+                    for (int kx = 0; kx < k0; ++kx) {
+                        int j = ix + kx;
+                        switch (op) {
+                            case GGML_OP_POOL_AVG: *out += srow[j]; break;
+                            case GGML_OP_POOL_MAX: if (srow[j] > *out) *out = srow[j]; break;
+                            case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
+                        }
+                    }
+                }
+                switch (op) {
+                    case GGML_OP_POOL_AVG: *out /= ka; break;
+                    case GGML_OP_POOL_MAX: break;
+                    case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
+                }
+            }
+        }
+
+        cdata += src->nb[2];
+        dplane += pa;
+    }
+}
+
+// ggml_compute_forward_pool_2d
+
+static void ggml_compute_forward_pool_2d(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(opt0->ne[0] == 7);
+    const int* opts = (const int*)opt0->data;
+    enum ggml_op_pool op = opts[0];
+    const int k0 = opts[1];
+    const int k1 = opts[2];
+    const int s0 = opts[3];
+    const int s1 = opts[4];
+    const int p0 = opts[5];
+    const int p1 = opts[6];
+    GGML_ASSERT(p0 == 0);
+    GGML_ASSERT(p1 == 0); // padding not supported
+    GGML_ASSERT(k0 == s0);
+    GGML_ASSERT(k1 == s1); // only s = k supported
+
+    ggml_compute_forward_pool_2d_sk_p0(params, op, src0, k0, k1, dst);
+}
+
 
 // ggml_compute_forward_flash_attn
 
@@ -14794,6 +15052,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
             } break;
+        case GGML_OP_POOL_1D:
+            {
+                ggml_compute_forward_pool_1d(params, tensor->src[0], tensor->src[1], tensor);
+            } break;
+        case GGML_OP_POOL_2D:
+            {
+                ggml_compute_forward_pool_2d(params, tensor->src[0], tensor->src[1], tensor);
+            } break;
         case GGML_OP_FLASH_ATTN:
             {
                 const int32_t t = ggml_get_i32_1d(tensor->src[3], 0);
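The pool kernels dispatched above handle the restricted case that the asserts enforce: stride equal to kernel size and zero padding, applied row by row (1D) or plane by plane (2D). A minimal standalone sketch of that s == k, p == 0 scheme on a plain float array may help illustrate the loop structure; the function below is hypothetical and only mirrors the idea, it is not the ggml code path:

#include <float.h>
#include <stdio.h>

// Hypothetical sketch of stride == kernel, no-padding pooling:
// output element i covers the k input elements starting at i*k.
static void pool_1d_sk_p0(const float *src, int n, int k, float *dst, int use_max) {
    const int rs = n / k; // output size, same as (n - k)/k + 1 when k divides n
    int j = 0;
    for (int i = 0; i < rs; ++i) {
        dst[i] = use_max ? -FLT_MAX : 0.0f;
        for (int ki = 0; ki < k; ++ki, ++j) {
            if (use_max) {
                if (src[j] > dst[i]) dst[i] = src[j];
            } else {
                dst[i] += src[j];
            }
        }
        if (!use_max) dst[i] /= k; // average over the window
    }
}

int main(void) {
    const float x[6] = {1, 2, 3, 4, 5, 6};
    float avg[3], mx[3];
    pool_1d_sk_p0(x, 6, 2, avg, 0); // avg: 1.5 3.5 5.5
    pool_1d_sk_p0(x, 6, 2, mx,  1); // max: 2 4 6
    printf("%g %g %g | %g %g %g\n", avg[0], avg[1], avg[2], mx[0], mx[1], mx[2]);
    return 0;
}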
@@ -15494,6 +15760,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_POOL_1D:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_POOL_2D:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_FLASH_ATTN:
             {
                 struct ggml_tensor * flash_grad = NULL;
@@ -16315,6 +16589,11 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
 
                     work_size = MAX(work_size, cur);
                 } break;
+            case GGML_OP_POOL_1D:
+            case GGML_OP_POOL_2D:
+                {
+                    n_tasks = 1;
+                } break;
             case GGML_OP_FLASH_ATTN:
                 {
                     n_tasks = n_threads;

ggml.h

Lines changed: 27 additions & 0 deletions
@@ -368,6 +368,8 @@ extern "C" {
         GGML_OP_CLAMP,
         GGML_OP_CONV_1D,
         GGML_OP_CONV_2D,
+        GGML_OP_POOL_1D,
+        GGML_OP_POOL_2D,
 
         GGML_OP_FLASH_ATTN,
         GGML_OP_FLASH_FF,
@@ -1173,6 +1175,31 @@ extern "C" {
             int s,
             int d);
 
+    enum ggml_op_pool {
+        GGML_OP_POOL_MAX,
+        GGML_OP_POOL_AVG,
+        GGML_OP_POOL_COUNT,
+    };
+
+    GGML_API struct ggml_tensor* ggml_pool_1d(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            enum ggml_op_pool op,
+            int k0, // kernel size
+            int s0, // stride
+            int p0); // padding
+
+    GGML_API struct ggml_tensor* ggml_pool_2d(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            enum ggml_op_pool op,
+            int k0,
+            int k1,
+            int s0,
+            int s1,
+            int p0,
+            int p1);
+
     GGML_API struct ggml_tensor * ggml_flash_attn(
             struct ggml_context * ctx,
             struct ggml_tensor * q,
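For context, a sketch of how the new public API declared above might be called, assuming the usual ggml context/graph flow. The graph-execution helper ggml_graph_compute_with_ctx used below is an assumption and may differ between ggml revisions; everything else follows the declarations in this diff:

#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params ip = {
        .mem_size   = 16 * 1024 * 1024, // scratch for tensors and graph
        .mem_buffer = NULL,
        .no_alloc   = false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // a 4x4 single-channel input filled with 0..15
    struct ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 4, 4, 1);
    for (int i = 0; i < 16; ++i) {
        ((float *) a->data)[i] = (float) i;
    }

    // 2x2 max pooling, stride 2, no padding -> 2x2 output
    struct ggml_tensor * p = ggml_pool_2d(ctx, a, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);

    struct ggml_cgraph gf = ggml_build_forward(p);
    // assumption: this helper exists in ggml revisions that have ggml_graph_plan
    ggml_graph_compute_with_ctx(ctx, &gf, 1);

    for (int i = 0; i < 4; ++i) {
        printf("%g\n", ggml_get_f32_1d(p, i)); // expected: 5 7 13 15
    }

    ggml_free(ctx);
    return 0;
}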
