Skip to content

Commit a8cd49b

Browse files
rgerganov authored and ggerganov committed
ggml : add ggml_set_rows
Add ggml_set_rows(a, b, c) which copies rows from 'b' into 'a' using indices from 'c'. ref: #8366
1 parent 1e86597 commit a8cd49b

File tree

5 files changed

+96
-2
lines changed

5 files changed

+96
-2
lines changed

ggml/include/ggml.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,7 @@ extern "C" {
470470
GGML_OP_TRANSPOSE,
471471
GGML_OP_GET_ROWS,
472472
GGML_OP_GET_ROWS_BACK,
473+
GGML_OP_SET_ROWS,
473474
GGML_OP_DIAG,
474475
GGML_OP_DIAG_MASK_INF,
475476
GGML_OP_DIAG_MASK_ZERO,
@@ -1374,6 +1375,12 @@ extern "C" {
13741375
struct ggml_tensor * b, // row indices
13751376
struct ggml_tensor * c); // data for ggml_get_rows, only used for its shape
13761377

1378+
GGML_API struct ggml_tensor * ggml_set_rows(
1379+
struct ggml_context * ctx,
1380+
struct ggml_tensor * a, // destination
1381+
struct ggml_tensor * b, // source
1382+
struct ggml_tensor * c); // row indices
1383+
13771384
GGML_API struct ggml_tensor * ggml_diag(
13781385
struct ggml_context * ctx,
13791386
struct ggml_tensor * a);

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1891,6 +1891,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
18911891
{
18921892
ggml_compute_forward_get_rows_back(params, tensor);
18931893
} break;
1894+
case GGML_OP_SET_ROWS:
1895+
{
1896+
ggml_compute_forward_set_rows(params, tensor);
1897+
} break;
18941898
case GGML_OP_DIAG:
18951899
{
18961900
ggml_compute_forward_diag(params, tensor);
@@ -2240,6 +2244,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
22402244
n_tasks = n_threads;
22412245
} break;
22422246
case GGML_OP_GET_ROWS:
2247+
case GGML_OP_SET_ROWS:
22432248
{
22442249
// FIXME: get_rows can use additional threads, but the cost of launching additional threads
22452250
// decreases performance with GPU offloading

ggml/src/ggml-cpu/ops.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4470,6 +4470,65 @@ void ggml_compute_forward_get_rows(
44704470
//}
44714471
}
44724472

4473+
static void ggml_compute_forward_set_rows_f32(
4474+
const ggml_compute_params * params,
4475+
ggml_tensor * dst) {
4476+
4477+
const ggml_tensor * src0 = dst->src[0];
4478+
const ggml_tensor * src1 = dst->src[1];
4479+
4480+
GGML_TENSOR_BINARY_OP_LOCALS
4481+
4482+
const int64_t nc = ne00;
4483+
const int64_t nr = ggml_nelements(src1);
4484+
4485+
assert(ne0 == nc);
4486+
assert(ne02 == ne11);
4487+
assert(nb00 == sizeof(float));
4488+
assert(ggml_nrows(src0) == nr);
4489+
4490+
const int ith = params->ith;
4491+
const int nth = params->nth;
4492+
4493+
// rows per thread
4494+
const int dr = (nr + nth - 1)/nth;
4495+
4496+
// row range for this thread
4497+
const int ir0 = dr*ith;
4498+
const int ir1 = MIN(ir0 + dr, nr);
4499+
4500+
for (int64_t i = ir0; i < ir1; ++i) {
4501+
const int64_t i12 = i/(ne11*ne10);
4502+
const int64_t i11 = (i - i12*ne11*ne10)/ne10;
4503+
const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
4504+
const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
4505+
4506+
GGML_ASSERT(i01 >= 0 && i01 < ne1);
4507+
4508+
ggml_cpu_fp32_to_fp16(
4509+
(const float *) ((char *) src0->data + i10*nb01 + i11*nb02 + i12*nb03),
4510+
(ggml_fp16_t *) ((char *) dst->data + i01*nb1 + i11*nb2 + i12*nb3), nc);
4511+
}
4512+
}
4513+
4514+
void ggml_compute_forward_set_rows(
4515+
const ggml_compute_params * params,
4516+
ggml_tensor * dst) {
4517+
4518+
const ggml_tensor * src0 = dst->src[0];
4519+
4520+
switch (src0->type) {
4521+
case GGML_TYPE_F32:
4522+
{
4523+
ggml_compute_forward_set_rows_f32(params, dst);
4524+
} break;
4525+
default:
4526+
{
4527+
GGML_ABORT("fatal error");
4528+
}
4529+
}
4530+
}
4531+
44734532
// ggml_compute_forward_get_rows_back
44744533

44754534
static void ggml_compute_forward_get_rows_back_f32_f16(

ggml/src/ggml-cpu/ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ void ggml_compute_forward_permute(const struct ggml_compute_params * params, str
5353
void ggml_compute_forward_transpose(const struct ggml_compute_params * params, struct ggml_tensor * dst);
5454
void ggml_compute_forward_get_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
5555
void ggml_compute_forward_get_rows_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
56+
void ggml_compute_forward_set_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
5657
void ggml_compute_forward_diag(const struct ggml_compute_params * params, struct ggml_tensor * dst);
5758
void ggml_compute_forward_diag_mask_inf(const struct ggml_compute_params * params, struct ggml_tensor * dst);
5859
void ggml_compute_forward_diag_mask_zero(const struct ggml_compute_params * params, struct ggml_tensor * dst);

ggml/src/ggml.c

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -985,7 +985,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
985985
"OPT_STEP_ADAMW",
986986
};
987987

988-
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
988+
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
989989

990990
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
991991
"none",
@@ -1080,7 +1080,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
10801080
"adamw(x)",
10811081
};
10821082

1083-
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
1083+
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
10841084

10851085
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
10861086

@@ -3393,6 +3393,28 @@ struct ggml_tensor * ggml_get_rows_back(
33933393
return result;
33943394
}
33953395

3396+
// ggml_set_rows
3397+
3398+
struct ggml_tensor * ggml_set_rows(
3399+
struct ggml_context * ctx,
3400+
struct ggml_tensor * a,
3401+
struct ggml_tensor * b,
3402+
struct ggml_tensor * c) {
3403+
GGML_ASSERT(b->ne[2] == c->ne[1]);
3404+
GGML_ASSERT(c->ne[3] == 1);
3405+
GGML_ASSERT(a->type == GGML_TYPE_F16);
3406+
GGML_ASSERT(b->type == GGML_TYPE_F32);
3407+
GGML_ASSERT(c->type == GGML_TYPE_I32);
3408+
3409+
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
3410+
3411+
result->op = GGML_OP_SET_ROWS;
3412+
result->src[0] = b;
3413+
result->src[1] = c;
3414+
3415+
return result;
3416+
}
3417+
33963418
// ggml_diag
33973419

33983420
struct ggml_tensor * ggml_diag(

0 commit comments

Comments
 (0)