ggml : broadcast mul_mat + conv batch support #2199

Merged
merged 2 commits on Jul 12, 2023
ggml.c (152 changes: 79 additions & 73 deletions)
@@ -4168,10 +4168,9 @@ static inline bool ggml_is_matrix(const struct ggml_tensor * tensor) {
static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");

return
(t0->ne[0] == t1->ne[0]) &&
(t0->ne[2] == t1->ne[2]) &&
(t0->ne[3] == t1->ne[3]);
return (t0->ne[0] == t1->ne[0]) &&
(t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable
(t1->ne[3]%t0->ne[3] == 0);
}

static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
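
For reference, a minimal standalone sketch of the relaxed compatibility rule above (the shapes are made up for illustration): src0 no longer has to match src1 exactly in dims 2 and 3, it only has to divide them evenly so it can be broadcast.

```c
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

// mirrors the relaxed ggml_can_mul_mat() check on plain ne[] shape arrays
static bool can_mul_mat_bcast(const int64_t t0[4], const int64_t t1[4]) {
    return (t0[0] == t1[0]) &&      // shared inner (dot-product) dimension
           (t1[2] % t0[2] == 0) &&  // t0 is broadcastable across dim 2
           (t1[3] % t0[3] == 0);    // t0 is broadcastable across dim 3
}

int main(void) {
    const int64_t w[4] = {64, 128,  1, 1};  // a single weight matrix
    const int64_t x[4] = {64,  32, 12, 2};  // activations: 12 heads, batch of 2
    const int64_t k[4] = {64, 128,  2, 1};  // two stacked weight matrices
    const int64_t y[4] = {64,  32,  5, 2};  // activations: 5 heads

    printf("%d\n", can_mul_mat_bcast(w, x)); // 1: 1 divides both 12 and 2
    printf("%d\n", can_mul_mat_bcast(k, y)); // 0: 5 is not a multiple of 2

    return 0;
}
```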
@@ -6036,8 +6035,8 @@ struct ggml_tensor * ggml_mul_mat(
is_node = true;
}

const int64_t ne[4] = { a->ne[1], b->ne[1], a->ne[2], b->ne[3] };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne);

result->op = GGML_OP_MUL_MAT;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
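
A hedged build-time sketch of what the new result shape enables (tensor sizes are hypothetical, and the graph is only constructed, not computed): a single 2-D weight can be multiplied against a 4-D activation without first ggml_repeat-ing it, and the result now takes its 2nd/3rd dimensions from b.

```c
#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // one shared weight matrix: 64 x 128
    struct ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 128);
    // 4-D activations: 32 columns, 12 heads, batch of 2
    struct ggml_tensor * x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64, 32, 12, 2);

    // previously w had to be repeated across dims 2 and 3 to match x
    struct ggml_tensor * y = ggml_mul_mat(ctx, w, x);

    // inferred result shape: [128, 32, 12, 2] -- dims 2/3 follow b
    printf("%lld %lld %lld %lld\n",
           (long long) y->ne[0], (long long) y->ne[1],
           (long long) y->ne[2], (long long) y->ne[3]);

    ggml_free(ctx);
    return 0;
}
```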
@@ -7173,7 +7172,6 @@ struct ggml_tensor* ggml_conv_2d(
int d0,
int d1) {

GGML_ASSERT(b->ne[3] == 1);
GGML_ASSERT(a->ne[2] == b->ne[2]);
bool is_node = false;

@@ -7185,7 +7183,7 @@
const int64_t ne[4] = {
ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0),
ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1),
a->ne[3], 1,
a->ne[3], b->ne[3],
};
struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);

@@ -10641,7 +10639,6 @@ static void ggml_compute_forward_rms_norm_back(
}
}


// ggml_compute_forward_mul_mat

#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
@@ -10685,17 +10682,17 @@ static void ggml_compute_forward_mul_mat(
const int ith = params->ith;
const int nth = params->nth;

GGML_ASSERT(ne02 == ne12);
GGML_ASSERT(ne03 == ne13);
GGML_ASSERT(ne2 == ne12);
GGML_ASSERT(ne3 == ne13);

const enum ggml_type type = src0->type;

ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;

GGML_ASSERT(ne0 == ne01);
GGML_ASSERT(ne1 == ne11);
GGML_ASSERT(ne2 == ne12);
GGML_ASSERT(ne3 == ne13);

// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == GGML_TYPE_SIZE[type]);
GGML_ASSERT(nb10 == sizeof(float));
@@ -10706,16 +10703,16 @@ static void ggml_compute_forward_mul_mat(
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);

GGML_ASSERT(ne0 == ne01);
GGML_ASSERT(ne1 == ne11);
GGML_ASSERT(ne2 == ne02);
GGML_ASSERT(ne3 == ne03);

// nb01 >= nb00 - src0 is not transposed
// compute by src0 rows

#if defined(GGML_USE_CLBLAST)
if (ggml_cl_can_mul_mat(src0, src1, dst)) {
// TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
// ref: https://github.com/ggerganov/ggml/pull/224
GGML_ASSERT(ne02 == ne12);
GGML_ASSERT(ne03 == ne13);

if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
}
@@ -10725,6 +10722,11 @@

#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
// TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
// ref: https://github.com/ggerganov/ggml/pull/224
GGML_ASSERT(ne02 == ne12);
GGML_ASSERT(ne03 == ne13);

if (params->ith != 0) {
return;
}
@@ -10794,41 +10796,44 @@ static void ggml_compute_forward_mul_mat(
return;
}

// parallelize by src0 rows using ggml_vec_dot_q
// parallelize by src0 rows
const int64_t dr = (ne01 + nth - 1)/nth;

// total rows in src0
const int nr = ne01*ne02*ne03;
const int64_t ir10 = dr*ith;
const int64_t ir11 = MIN(ir10 + dr, ne01);

// rows per thread
const int dr = (nr + nth - 1)/nth;

// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
// src1 rows
const int64_t nr1 = ne11*ne12*ne13;

void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];

for (int ir = ir0; ir < ir1; ++ir) {
// src0 indices
const int i03 = ir/(ne02*ne01);
const int i02 = (ir - i03*ne02*ne01)/ne01;
const int i01 = (ir - i03*ne02*ne01 - i02*ne01);

const int i13 = i03;
const int i12 = i02;

const int i0 = i01;
const int i2 = i02;
const int i3 = i03;

void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*row_size));

float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));

for (int64_t ic = 0; ic < ne11; ++ic) {
vec_dot(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];

for (int64_t ir1 = 0; ir1 < nr1; ++ir1) {
const int64_t i13 = (ir1/(ne12*ne11));
const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);

const int64_t ir0 = (ir1/ne11)%(ne02*ne03);
const int64_t i03 = (ir0/(ne02));
// Hack for "Falcon multi-query-attention key stutter" / alternative to ggml_repeat2.
// See https://github.com/ggerganov/llama.cpp/issues/1602#issuecomment-1606087470:
// GG: this is likely the correct way to broadcast, though need some more thought
// therefore leaving the comments to remind us for now
const int64_t i02 = (i12 / (ne12 / ne02));
// Original from PR/224 (and also essential/correct for non-broadcast matmuls in Falcon)
// const int64_t i02 = (ir0 - i03*ne02);

const int64_t i1 = i11;
const int64_t i2 = i12;
const int64_t i3 = i13;

const char * src0_row = (const char *) src0->data + ( 0 + i02*nb02 + i03*nb03 );
const char * src1_col = (const char *) wdata + (i11 + i12*ne11 + i13*ne12*ne11)*row_size;

float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));

for (int64_t ir = ir10; ir < ir11; ++ir) {
vec_dot(ne00, &dst_col[ir], src0_row + ir*nb01, src1_col);
}
}
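
To make the broadcast indexing above concrete, a small standalone sketch of the i02 = i12 / (ne12 / ne02) mapping (head counts are made up): each src0 slice along dim 2 is reused for a contiguous run of src1 slices, which is the pattern needed for Falcon-style multi-query attention where several query heads share one KV head.

```c
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const int64_t ne02 = 2;  // src0 slices along dim 2 (e.g. KV heads)
    const int64_t ne12 = 8;  // src1 slices along dim 2 (e.g. query heads)

    // each src0 slice serves ne12/ne02 consecutive src1 slices:
    // i12 = 0..3 -> i02 = 0, i12 = 4..7 -> i02 = 1
    for (int64_t i12 = 0; i12 < ne12; ++i12) {
        const int64_t i02 = i12 / (ne12 / ne02);
        printf("src1 slice %lld -> src0 slice %lld\n",
               (long long) i12, (long long) i02);
    }

    return 0;
}
```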

@@ -13013,16 +13018,18 @@ static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
{
ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;

for (int i12 = 0; i12 < ne12; i12++) {
const float * const src = (float *)((char *) src1->data + i12*nb12);
ggml_fp16_t * dst_data = wdata;

for (int i1 = 0; i1 < ne1; i1++) {
for (int i0 = 0; i0 < ne0; i0++) {
for (int ik1 = 0; ik1 < nk1; ik1++) {
for (int ik0 = 0; ik0 < nk0; ik0++) {
dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
GGML_FP32_TO_FP16(src[(i1*nk1 + ik1)*ne10 + (i0*nk0 + ik0)]);
for (int i13 = 0; i13 < ne13; i13++) {
for (int i12 = 0; i12 < ne12; i12++) {
const float * const src = (float *)((char *) src1->data + i13*nb13 + i12*nb12);
ggml_fp16_t * dst_data = wdata + i13*(ne1*ne0*ew0);

for (int i1 = 0; i1 < ne1; i1++) {
for (int i0 = 0; i0 < ne0; i0++) {
for (int ik1 = 0; ik1 < nk1; ik1++) {
for (int ik0 = 0; ik0 < nk0; ik0++) {
dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
GGML_FP32_TO_FP16(src[(i1*nk1 + ik1)*ne10 + (i0*nk0 + ik0)]);
}
}
}
}
Expand All @@ -13049,14 +13056,16 @@ static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(

ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;

for (int i2 = ip0; i2 < ip1; i2++) {
float * dst_data = (float *)((char *) dst->data + i2*nb2);

for (int i1 = 0; i1 < ne1; ++i1) {
for (int i0 = 0; i0 < ne0; ++i0) {
ggml_vec_dot_f16(ew0, dst_data + i1*ne0 + i0,
(ggml_fp16_t *) ((char *) src0->data + i2*nb03),
(ggml_fp16_t *) wdata + (i1*ne0 + i0)*ew0);
for (int i3 = 0; i3 < ne3; i3++) {
for (int i2 = ip0; i2 < ip1; i2++) {
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2);

for (int i1 = 0; i1 < ne1; ++i1) {
for (int i0 = 0; i0 < ne0; ++i0) {
ggml_vec_dot_f16(ew0, dst_data + i1*ne0 + i0,
(ggml_fp16_t *) ((char *) src0->data + i2*nb03),
(ggml_fp16_t *) wdata + i3*nb3 + (i1*ne0 + i0)*ew0);
}
}
}
}
@@ -13105,10 +13114,9 @@ static void ggml_compute_forward_conv_2d(

if (s0 == src0->ne[0] && s1 == src0->ne[1]) {
ggml_compute_forward_conv_2d_sk_p0(params, src0, src1, dst);
}
else {
} else {
GGML_ASSERT(false); // only stride equal to kernel size is supported
};
}
}

// ggml_compute_forward_pool_1d_sk_p0
@@ -16558,8 +16566,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
{
n_tasks = n_threads;

GGML_ASSERT(node->src[1]->ne[3] == 1);

const int64_t ne00 = node->src[0]->ne[0]; // W
const int64_t ne01 = node->src[0]->ne[1]; // H
const int64_t ne02 = node->src[0]->ne[2]; // C