Skip to content

Commit 95735a6

Browse files
committed
Unify mul_mat
1 parent 5ea88de commit 95735a6

File tree

1 file changed: 18 additions and 248 deletions

ggml.c

Lines changed: 18 additions & 248 deletions
Original file line number | Diff line number | Diff line change
@@ -1968,6 +1968,7 @@ static void f16_from_float(const float * restrict x, ggml_fp16_t * restrict y, i
19681968
}
19691969
}
19701970

1971+
static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y);
19711972
static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y);
19721973
static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
19731974
static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
@@ -1977,6 +1978,10 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
19771978
static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
19781979

19791980
static const ggml_type_handling_t type_handling[GGML_TYPE_COUNT] = {
1981+
[GGML_TYPE_F32] = {
1982+
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
1983+
.vec_dot_type = GGML_TYPE_F32,
1984+
},
19801985
[GGML_TYPE_F16] = {
19811986
.to_float = (ggml_to_float_t) f16_to_float,
19821987
.from_float = (ggml_from_float_t) f16_from_float,
@@ -2561,7 +2566,7 @@ inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x)
25612566
inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; }
25622567
inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; }
25632568

2564-
inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) {
2569+
static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) {
25652570
#ifdef GGML_SIMD
25662571
float sumf = 0.0f;
25672572
const int np = (n & ~(GGML_F32_STEP - 1));
@@ -7941,215 +7946,7 @@ static bool ggml_compute_forward_mul_mat_use_blas(
79417946

79427947
#endif
79437948

7944-
static void ggml_compute_forward_mul_mat_f32(
7945-
const struct ggml_compute_params * params,
7946-
const struct ggml_tensor * src0,
7947-
const struct ggml_tensor * src1,
7948-
struct ggml_tensor * dst) {
7949-
int64_t t0 = ggml_perf_time_us();
7950-
UNUSED(t0);
7951-
7952-
const int64_t ne00 = src0->ne[0];
7953-
const int64_t ne01 = src0->ne[1];
7954-
const int64_t ne02 = src0->ne[2];
7955-
const int64_t ne03 = src0->ne[3];
7956-
7957-
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
7958-
const int64_t ne10 = src1->ne[0];
7959-
#endif
7960-
const int64_t ne11 = src1->ne[1];
7961-
#ifndef NDEBUG
7962-
const int64_t ne12 = src1->ne[2];
7963-
const int64_t ne13 = src1->ne[3];
7964-
7965-
const int64_t ne0 = dst->ne[0];
7966-
const int64_t ne1 = dst->ne[1];
7967-
const int64_t ne2 = dst->ne[2];
7968-
const int64_t ne3 = dst->ne[3];
7969-
7970-
const int nb00 = src0->nb[0];
7971-
#endif
7972-
const int nb01 = src0->nb[1];
7973-
const int nb02 = src0->nb[2];
7974-
const int nb03 = src0->nb[3];
7975-
7976-
#ifndef NDEBUG
7977-
const int nb10 = src1->nb[0];
7978-
#endif
7979-
const int nb11 = src1->nb[1];
7980-
const int nb12 = src1->nb[2];
7981-
const int nb13 = src1->nb[3];
7982-
7983-
const int nb0 = dst->nb[0];
7984-
const int nb1 = dst->nb[1];
7985-
const int nb2 = dst->nb[2];
7986-
const int nb3 = dst->nb[3];
7987-
7988-
const int ith = params->ith;
7989-
const int nth = params->nth;
7990-
7991-
assert(ne02 == ne12);
7992-
assert(ne03 == ne13);
7993-
assert(ne2 == ne12);
7994-
assert(ne3 == ne13);
7995-
7996-
// we don't support permuted src0 or src1
7997-
assert(nb00 == sizeof(float));
7998-
assert(nb10 == sizeof(float));
7999-
8000-
// dst cannot be transposed or permuted
8001-
assert(nb0 == sizeof(float));
8002-
assert(nb0 <= nb1);
8003-
assert(nb1 <= nb2);
8004-
assert(nb2 <= nb3);
8005-
8006-
assert(ne0 == ne01);
8007-
assert(ne1 == ne11);
8008-
assert(ne2 == ne02);
8009-
assert(ne3 == ne03);
8010-
8011-
// nb01 >= nb00 - src0 is not transposed
8012-
// compute by src0 rows
8013-
8014-
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
8015-
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
8016-
if (params->ith != 0) {
8017-
return;
8018-
}
8019-
8020-
if (params->type == GGML_TASK_INIT) {
8021-
return;
8022-
}
8023-
8024-
if (params->type == GGML_TASK_FINALIZE) {
8025-
return;
8026-
}
8027-
8028-
#if defined(GGML_USE_CUBLAS)
8029-
const float alpha = 1.0f;
8030-
const float beta = 0.0f;
8031-
const int x_ne = ne01 * ne00;
8032-
const int y_ne = ne11 * ne10;
8033-
const int d_ne = ne11 * ne01;
8034-
8035-
size_t x_size, y_size, d_size;
8036-
float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
8037-
float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
8038-
float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
8039-
#endif
8040-
8041-
for (int64_t i03 = 0; i03 < ne03; i03++) {
8042-
for (int64_t i02 = 0; i02 < ne02; i02++) {
8043-
#if !defined(GGML_USE_CUBLAS)
8044-
const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
8045-
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
8046-
#endif
8047-
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
8048-
8049-
#if defined(GGML_USE_CUBLAS)
8050-
// copy data to device
8051-
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_X, src0, i03, i02, g_cudaStream));
8052-
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Y, src1, i03, i02, g_cudaStream));
8053-
8054-
// compute
8055-
CUBLAS_CHECK(
8056-
cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
8057-
ne01, ne11, ne10,
8058-
&alpha, d_X, ne00,
8059-
d_Y, ne10,
8060-
&beta, d_D, ne01));
8061-
8062-
// copy data to host
8063-
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
8064-
#elif defined(GGML_USE_CLBLAST)
8065-
// zT = y * xT
8066-
ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
8067-
ne11, ne01, ne10,
8068-
1.0f, y, ne10,
8069-
x, ne10,
8070-
0.0f, d, ne01,
8071-
GGML_TYPE_F32);
8072-
#else
8073-
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
8074-
ne11, ne01, ne10,
8075-
1.0f, y, ne10,
8076-
x, ne00,
8077-
0.0f, d, ne01);
8078-
#endif
8079-
}
8080-
}
8081-
#if defined(GGML_USE_CUBLAS)
8082-
CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
8083-
ggml_cuda_pool_free(d_X, x_size);
8084-
ggml_cuda_pool_free(d_Y, y_size);
8085-
ggml_cuda_pool_free(d_D, d_size);
8086-
#endif
8087-
//printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
8088-
8089-
return;
8090-
}
8091-
#endif
8092-
8093-
if (params->type == GGML_TASK_INIT) {
8094-
return;
8095-
}
8096-
8097-
if (params->type == GGML_TASK_FINALIZE) {
8098-
return;
8099-
}
8100-
8101-
// parallelize by src0 rows using ggml_vec_dot_f32
8102-
8103-
// total rows in src0
8104-
const int nr = ne01*ne02*ne03;
8105-
8106-
// rows per thread
8107-
const int dr = (nr + nth - 1)/nth;
8108-
8109-
// row range for this thread
8110-
const int ir0 = dr*ith;
8111-
const int ir1 = MIN(ir0 + dr, nr);
8112-
8113-
for (int ir = ir0; ir < ir1; ++ir) {
8114-
// src0 indices
8115-
const int i03 = ir/(ne02*ne01);
8116-
const int i02 = (ir - i03*ne02*ne01)/ne01;
8117-
const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
8118-
8119-
for (int64_t ic = 0; ic < ne11; ++ic) {
8120-
// src1 indices
8121-
const int i13 = i03;
8122-
const int i12 = i02;
8123-
const int i11 = ic;
8124-
8125-
// dst indices
8126-
const int i0 = i01;
8127-
const int i1 = i11;
8128-
const int i2 = i02;
8129-
const int i3 = i03;
8130-
8131-
ggml_vec_dot_f32(ne00,
8132-
(float *) ((char *) dst->data + (i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
8133-
(float *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)),
8134-
(float *) ((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13)));
8135-
}
8136-
}
8137-
8138-
//int64_t t1 = ggml_perf_time_us();
8139-
//static int64_t acc = 0;
8140-
//acc += t1 - t0;
8141-
//if (t1 - t0 > 10) {
8142-
// printf("\n");
8143-
// printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
8144-
// printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
8145-
// printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
8146-
// printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
8147-
8148-
// printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
8149-
//}
8150-
}
8151-
8152-
static void ggml_compute_forward_mul_mat_q_f32(
7949+
static void ggml_compute_forward_mul_mat(
81537950
const struct ggml_compute_params * params,
81547951
const struct ggml_tensor * src0,
81557952
const struct ggml_tensor * src1,
@@ -8330,18 +8127,19 @@ static void ggml_compute_forward_mul_mat_q_f32(
83308127
#endif
83318128

83328129
if (params->type == GGML_TASK_INIT) {
8333-
char * wdata = params->wdata;
8334-
const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
8335-
8336-
for (int64_t i13 = 0; i13 < ne13; ++i13) {
8337-
for (int64_t i12 = 0; i12 < ne12; ++i12) {
8338-
for (int64_t i11 = 0; i11 < ne11; ++i11) {
8339-
from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
8340-
wdata += row_size;
8130+
if (vec_dot_type != GGML_TYPE_F32) {
8131+
char * wdata = params->wdata;
8132+
const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
8133+
8134+
for (int64_t i13 = 0; i13 < ne13; ++i13) {
8135+
for (int64_t i12 = 0; i12 < ne12; ++i12) {
8136+
for (int64_t i11 = 0; i11 < ne11; ++i11) {
8137+
from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
8138+
wdata += row_size;
8139+
}
83418140
}
83428141
}
83438142
}
8344-
83458143
return;
83468144
}
83478145

@@ -8361,7 +8159,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
83618159
const int ir0 = dr*ith;
83628160
const int ir1 = MIN(ir0 + dr, nr);
83638161

8364-
void * wdata = params->wdata;
8162+
void * wdata = (vec_dot_type == GGML_TYPE_F32) ? src1->data : params->wdata;
83658163
const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
83668164

83678165
for (int ir = ir0; ir < ir1; ++ir) {
@@ -8402,34 +8200,6 @@ static void ggml_compute_forward_mul_mat_q_f32(
84028200
//}
84038201
}
84048202

8405-
static void ggml_compute_forward_mul_mat(
8406-
const struct ggml_compute_params * params,
8407-
const struct ggml_tensor * src0,
8408-
const struct ggml_tensor * src1,
8409-
struct ggml_tensor * dst) {
8410-
switch (src0->type) {
8411-
case GGML_TYPE_Q4_0:
8412-
case GGML_TYPE_Q4_1:
8413-
case GGML_TYPE_Q4_2:
8414-
case GGML_TYPE_Q5_0:
8415-
case GGML_TYPE_Q5_1:
8416-
case GGML_TYPE_Q8_0:
8417-
case GGML_TYPE_Q8_1:
8418-
case GGML_TYPE_F16:
8419-
{
8420-
ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
8421-
} break;
8422-
case GGML_TYPE_F32:
8423-
{
8424-
ggml_compute_forward_mul_mat_f32(params, src0, src1, dst);
8425-
} break;
8426-
default:
8427-
{
8428-
GGML_ASSERT(false);
8429-
} break;
8430-
}
8431-
}
8432-
84338203
// ggml_compute_forward_scale
84348204

84358205
static void ggml_compute_forward_scale_f32(

0 commit comments

Comments
 (0)