Skip to content

Commit e0de9c1

Browse files
committed
ggml : ggml_set_rows support broadcast
1 parent a32fc70 commit e0de9c1

File tree

3 files changed

+39
-15
lines changed

3 files changed

+39
-15
lines changed

ggml/include/ggml.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1379,6 +1379,15 @@ extern "C" {
13791379
struct ggml_tensor * b, // row indices
13801380
struct ggml_tensor * c); // data for ggml_get_rows, only used for its shape
13811381

1382+
// a TD [n_embd, ne1, ne2, ne3]
1383+
// b TS [n_embd, n_rows, ne02, ne03] | ne02 == ne2, ne03 == ne3
1384+
// c I64 [n_rows, ne11, ne12, 1] | c[i] in [0, ne1)
1385+
//
1386+
// broadcast (the index tensor c is repeated along dims 2 and 3 of a/b):
1387+
// ne2 % ne11 == 0
1388+
// ne3 % ne12 == 0
1389+
//
1390+
// return view(a)
13821391
GGML_API struct ggml_tensor * ggml_set_rows(
13831392
struct ggml_context * ctx,
13841393
struct ggml_tensor * a, // destination

ggml/src/ggml-cpu/ops.cpp

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4530,12 +4530,14 @@ static void ggml_compute_forward_set_rows_f32(
45304530
GGML_TENSOR_BINARY_OP_LOCALS
45314531

45324532
const int64_t nc = ne00;
4533-
const int64_t nr = ggml_nelements(src1);
4533+
const int64_t nr = ne01;
45344534

45354535
assert(ne0 == nc);
4536-
assert(ne02 == ne11);
4537-
assert(nb00 == sizeof(float));
4538-
assert(ggml_nrows(src0) == nr);
4536+
assert(ne2 == ne02);
4537+
assert(ne3 == ne03);
4538+
assert(src0->type == GGML_TYPE_F32);
4539+
assert(ne02 % ne11 == 0);
4540+
assert(ne03 % ne12 == 0);
45394541

45404542
const int ith = params->ith;
45414543
const int nth = params->nth;
@@ -4547,17 +4549,22 @@ static void ggml_compute_forward_set_rows_f32(
45474549
const int ir0 = dr*ith;
45484550
const int ir1 = MIN(ir0 + dr, nr);
45494551

4550-
for (int64_t i = ir0; i < ir1; ++i) {
4551-
const int64_t i12 = i/(ne11*ne10);
4552-
const int64_t i11 = (i - i12*ne11*ne10)/ne10;
4553-
const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
4554-
const int64_t i01 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
4552+
for (int64_t i03 = 0; i03 < ne03; ++i03) {
4553+
for (int64_t i02 = 0; i02 < ne02; ++i02) {
4554+
for (int64_t i = ir0; i < ir1; ++i) {
4555+
const int64_t i12 = i03%ne12;
4556+
const int64_t i11 = i02%ne11;
4557+
const int64_t i10 = i;
4558+
4559+
const int64_t i01 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
45554560

4556-
GGML_ASSERT(i01 >= 0 && i01 < ne1);
4561+
GGML_ASSERT(i01 >= 0 && i01 < ne1);
45574562

4558-
ggml_cpu_fp32_to_fp16(
4559-
(const float *) ((char *) src0->data + i10*nb01 + i11*nb02 + i12*nb03),
4560-
(ggml_fp16_t *) ((char *) dst->data + i01*nb1 + i11*nb2 + i12*nb3), nc);
4563+
ggml_cpu_fp32_to_fp16(
4564+
(const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03),
4565+
(ggml_fp16_t *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), nc);
4566+
}
4567+
}
45614568
}
45624569
}
45634570

ggml/src/ggml.c

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3410,12 +3410,20 @@ struct ggml_tensor * ggml_set_rows(
34103410
struct ggml_tensor * a,
34113411
struct ggml_tensor * b,
34123412
struct ggml_tensor * c) {
3413-
GGML_ASSERT(b->ne[2] == c->ne[1]);
3413+
GGML_ASSERT(a->ne[0] == b->ne[0]);
3414+
GGML_ASSERT(a->ne[2] == b->ne[2]);
3415+
GGML_ASSERT(a->ne[3] == b->ne[3]);
3416+
GGML_ASSERT(b->ne[1] == c->ne[0]);
3417+
GGML_ASSERT(b->ne[2] % c->ne[1] == 0);
3418+
GGML_ASSERT(b->ne[3] % c->ne[2] == 0);
34143419
GGML_ASSERT(c->ne[3] == 1);
3415-
GGML_ASSERT(a->type == GGML_TYPE_F16);
3420+
GGML_ASSERT(a->type == GGML_TYPE_F16); // TODO: relax
34163421
GGML_ASSERT(b->type == GGML_TYPE_F32);
34173422
GGML_ASSERT(c->type == GGML_TYPE_I64);
34183423

3424+
GGML_ASSERT(ggml_is_contiguous_rows(a));
3425+
GGML_ASSERT(ggml_is_contiguous_rows(b));
3426+
34193427
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
34203428

34213429
result->op = GGML_OP_SET_ROWS;

0 commit comments

Comments
 (0)