#include "ggml-blas.h"
#include "ggml-backend-impl.h"

+#include <atomic>
+#include <cassert>
#include <future>
#include <vector>

@@ -22,6 +24,7 @@ struct ggml_backend_blas_context {
#ifndef GGML_USE_OPENMP
    std::vector<std::future<void>> tasks;
#endif
+    std::atomic<int> current_chunk;
};

// helper function to determine if it is better to use BLAS or not
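
For context, the current_chunk atomic added above acts as a shared work counter in the new mul_mat path further down in this patch: every OpenMP thread starts on the chunk equal to its thread id and then claims the next unprocessed chunk with fetch_add. Below is a minimal standalone sketch of that claim-then-process pattern; the sizes are hypothetical and it uses no ggml types, so it is only an illustration, not part of the patch.

// Standalone sketch of the chunk-claiming pattern (hypothetical sizes).
// Build with e.g. g++ -fopenmp.
#include <atomic>
#include <cstdio>
#include <omp.h>

int main() {
    const int n_threads = 4;
    const int n_chunks  = 37;                   // total units of work
    std::atomic<int> current_chunk(n_threads);  // chunks 0..n_threads-1 are pre-assigned

    #pragma omp parallel num_threads(n_threads)
    {
        int chunk = omp_get_thread_num();       // the first chunk is the thread id
        while (chunk < n_chunks) {
            printf("thread %d -> chunk %d\n", omp_get_thread_num(), chunk);
            chunk = current_chunk.fetch_add(1); // claim the next unprocessed chunk
        }
    }
    return 0;
}
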
@@ -48,6 +51,265 @@ static bool ggml_backend_blas_use_blas(const struct ggml_tensor * dst) {
    return false;
}

+static void ggml_compute_forward_mul_mat_one_chunk(
+        ggml_backend_blas_context * ctx,
+        struct ggml_tensor * dst,
+        const int64_t num_rows_per_vec_dot,
+        const int64_t ir0_start,
+        const int64_t ir0_end,
+        const int64_t ir1_start,
+        const int64_t ir1_end) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const enum ggml_type type = src0->type;
+
+    const bool src1_cont = ggml_is_contiguous(src1);
+
+    const ggml_type_traits_t * type_traits = ggml_internal_get_type_traits_ptr(type);
+
+    ggml_vec_dot_t const vec_dot = type_traits->vec_dot;
+    enum ggml_type const vec_dot_type = type_traits->vec_dot_type;
+
+    // broadcast factors
+    const int64_t r2 = ne12 / ne02;
+    const int64_t r3 = ne13 / ne03;
+
+    // printf("ir0_start = %6lld, ir0_end = %6lld, ir1_start = %6lld, ir1_end = %6lld\n", ir0_start, ir0_end, ir1_start, ir1_end);
+
+    // threads with no work simply yield (not sure if it helps)
+    if (ir0_start >= ir0_end || ir1_start >= ir1_end) {
+        return;
+    }
+
+    const void * wdata = (src1->type == vec_dot_type) ? src1->data : ctx->work_data.get();
+    const size_t row_size = ggml_row_size(vec_dot_type, ne10);
+
+    assert(ne12 % ne02 == 0);
+    assert(ne13 % ne03 == 0);
+
+    // block-tiling attempt
+    const int64_t blck_0 = 16;
+    const int64_t blck_1 = 16;
+
+    const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
+
+    // attempt to reduce false-sharing (does not seem to make a difference)
+    // 16 * 2, accounting for mmla kernels
+    float tmp[32];
+
+    for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
+        for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
+            for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) {
+                const int64_t i13 = (ir1 / (ne12 * ne1));
+                const int64_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;
+                const int64_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
+
+                // broadcast src0 into src1
+                const int64_t i03 = i13 / r3;
+                const int64_t i02 = i12 / r2;
+
+                const int64_t i1 = i11;
+                const int64_t i2 = i12;
+                const int64_t i3 = i13;
+
+                const char * src0_row = (const char *) src0->data + (0 + i02 * nb02 + i03 * nb03);
+
+                // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
+                // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
+                // the original src1 data pointer, so we should index using the indices directly
+                // TODO: this is a bit of a hack, we should probably have a better way to handle this
+                const char * src1_col = (const char *) wdata +
+                    (src1_cont || src1->type != vec_dot_type
+                        ? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size
+                        : (i11 * nb11 + i12 * nb12 + i13 * nb13));
+                float * dst_col = (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
+
+                // for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
+                //     vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
+                // }
+
+                for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
+                    vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
+                }
+
+                for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
+                    memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (std::min(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));
+                }
+            }
+        }
+    }
+}
+
147
+ static void ggml_compute_forward_mul_mat (
148
+ ggml_backend_blas_context * ctx,
149
+ struct ggml_tensor * dst) {
150
+
151
+ const struct ggml_tensor * src0 = dst->src [0 ];
152
+ const struct ggml_tensor * src1 = dst->src [1 ];
153
+
154
+
155
+ GGML_TENSOR_BINARY_OP_LOCALS
156
+
157
+ const enum ggml_type type = src0->type ;
158
+
159
+ const ggml_type_traits_t * type_traits = ggml_internal_get_type_traits_ptr (type);
160
+ const ggml_type_traits_t * type_traits_vec_dot = ggml_internal_get_type_traits_ptr (type_traits->vec_dot_type );
161
+ enum ggml_type const vec_dot_type = type_traits->vec_dot_type ;
162
+ ggml_from_float_t const from_float_to_vec_dot = type_traits_vec_dot->from_float ;
163
+ int64_t const vec_dot_num_rows = type_traits->nrows ;
164
+
165
+ GGML_ASSERT (ne0 == ne01);
166
+ GGML_ASSERT (ne1 == ne11);
167
+ GGML_ASSERT (ne2 == ne12);
168
+ GGML_ASSERT (ne3 == ne13);
169
+
170
+ // we don't support permuted src0 or src1
171
+ GGML_ASSERT (nb00 == ggml_type_size (type));
172
+ GGML_ASSERT (nb10 == ggml_type_size (src1->type ));
173
+
174
+ // dst cannot be transposed or permuted
175
+ GGML_ASSERT (nb0 == sizeof (float ));
176
+ GGML_ASSERT (nb0 <= nb1);
177
+ GGML_ASSERT (nb1 <= nb2);
178
+ GGML_ASSERT (nb2 <= nb3);
179
+
180
+ // broadcast factors
181
+ const int64_t r2 = ne12 / ne02;
182
+ const int64_t r3 = ne13 / ne03;
183
+ GGML_UNUSED (r2);
184
+ GGML_UNUSED (r3);
185
+
186
+ // nb01 >= nb00 - src0 is not transposed
187
+ // compute by src0 rows
188
+
189
+ if (src1->type != vec_dot_type) {
190
+ const size_t row_size = ggml_row_size (vec_dot_type, ne10);
191
+ if (ctx->work_size < ne13*ne12*ne11*row_size) {
192
+ ctx->work_data .reset (new char [ne13*ne12*ne11*row_size]);
193
+ ctx->work_size = ne13*ne12*ne11*row_size;
194
+ }
195
+ char * wdata = ctx->work_data .get ();
196
+
197
+ GGML_ASSERT (src1->type == GGML_TYPE_F32);
198
+ int block_size = ggml_blck_size (vec_dot_type);
199
+ int type_size = ggml_type_size (vec_dot_type);
200
+
201
+ for (int64_t i13 = 0 ; i13 < ne13; ++i13) {
202
+ for (int64_t i12 = 0 ; i12 < ne12; ++i12) {
203
+ for (int64_t i11 = 0 ; i11 < ne11; ++i11) {
204
+ // from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
205
+ // #pragma omp parallel num_threads(ctx->n_threads)
206
+ {
207
+ int nth = omp_get_num_threads ();
208
+ int ith = omp_get_thread_num ();
209
+ int blocks_per_thread = (ne10 + block_size - 1 ) / block_size / nth;
210
+ int i10_start = ith * blocks_per_thread * block_size;
211
+ int i10_end = std::min (i10_start + blocks_per_thread * block_size, (int )ne10);
212
+ // printf("thread %d/%d: i10_start = %d, i10_end = %d\n", ith, nth, i10_start, i10_end);
213
+ from_float_to_vec_dot ((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10_start*nb10),
214
+ (void *) ((char *) wdata + (type_size*i10_start/block_size)),
215
+ i10_end - i10_start);
216
+
217
+ }
218
+
219
+ wdata += row_size;
220
+ }
221
+ }
222
+ }
223
+ }
224
+
225
+ // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers)
226
+ const int64_t nr0 = ne0;
227
+
228
+ // This is the size of the rest of the dimensions of the result
229
+ const int64_t nr1 = ne1 * ne2 * ne3;
230
+
231
+ // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
232
+ int64_t num_rows_per_vec_dot = vec_dot_num_rows;
233
+ // TODO: currently the mmla kernels support only even numbered rows/cols.
234
+ // this check can be removed once they are extended to support odd numbered rows/cols too
235
+ if ((nr0 % 2 != 0 ) || (ne11 % 2 != 0 )) {
236
+ num_rows_per_vec_dot = 1 ;
237
+ }
238
+
239
+ // Now select a reasonable chunk size.
240
+ int chunk_size = 16 ;
241
+
242
+ // We need to step up the size if it's small
243
+ if (nr0 == 1 || nr1 == 1 ) {
244
+ chunk_size = 64 ;
245
+ }
246
+
247
+ // distribute the work across the inner or outer loop based on which one is larger
248
+ // The number of chunks in the 0/1 dim.
249
+ // CEIL(nr0/chunk_size)
250
+ int64_t nchunk0 = (nr0 + chunk_size - 1 ) / chunk_size;
251
+ int64_t nchunk1 = (nr1 + chunk_size - 1 ) / chunk_size;
252
+
+    // If the chunking is poor for the number of threads on this setup, scrap the whole plan. Re-chunk it by thread.
+    // Also, chunking by thread was measured to perform better on NUMA systems. See https://github.com/ggerganov/llama.cpp/pull/6915
+    // In theory, chunking should be just as useful on NUMA and non-NUMA systems, but testing disagreed with that.
+
+    // const int ith = 0; // params->ith;
+    const int nth = ctx->n_threads;
+
+    // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
+    ctx->current_chunk.store(nth);
+
+    if (nchunk0 * nchunk1 < nth * 4 || ggml_is_numa()) {
+        // distribute the thread work across the inner or outer loop based on which one is larger
+        nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
+        nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
+    }
+
+    // The number of elements in each chunk
+    const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
+    const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
+
+    // if (ith == 0)
+    //     printf("MUL_MAT = [%d, %d, %d, %d] x [%d, %d, %d, %d] = %d x %d = %d. Fp Ops/Ch %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nchunk0, nchunk1, nchunk0 * nchunk1, ne00 * nr0 * nr1 / nchunk0 / nchunk1);
+
+    // The first chunk comes from our thread_id, the rest will get auto-assigned.
+    if (nth > 1) {
+        #pragma omp parallel num_threads(nth)
+        {
+            int current_chunk = omp_get_thread_num();
+
+            while (current_chunk < nchunk0 * nchunk1) {
+                const int64_t ith0 = current_chunk % nchunk0;
+                const int64_t ith1 = current_chunk / nchunk0;
+
+                const int64_t ir0_start = dr0 * ith0;
+                const int64_t ir0_end = std::min(ir0_start + dr0, nr0);
+
+                const int64_t ir1_start = dr1 * ith1;
+                const int64_t ir1_end = std::min(ir1_start + dr1, nr1);
+
+                ggml_compute_forward_mul_mat_one_chunk(ctx, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
+
+                if (nth >= nchunk0 * nchunk1) {
+                    break;
+                }
+
+                current_chunk = ctx->current_chunk.fetch_add(1);
+            }
+        }
+    } else {
+        ggml_compute_forward_mul_mat_one_chunk(ctx, dst, num_rows_per_vec_dot, 0, nr0, 0, nr1);
+    }
+
+#ifdef GGML_PERF
+    // These numbers are useful when trying to measure how well the threading scheduling works.
+    // int64_t workSize = (ne01 * ne11 * ne12 * ne13 * ne00) / nchunk0 / nchunk1;
+    // float time = (ggml_perf_time_us() - t0);
+    // printf("MUL_MAT = %f ms, [%d, %d, %d, %d] x [%d, %d, %d, %d] = %I64u, %f ops/usec in %d chunks.\n", time / 1000.0, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, workSize, (float)workSize/time, chunks_executed);
+#endif
+}
+
static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
    const struct ggml_tensor * src0 = dst->src[0];
    const struct ggml_tensor * src1 = dst->src[1];
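
To make the chunking arithmetic in ggml_compute_forward_mul_mat above concrete, here is a small numeric sketch of how a linear chunk index maps to row ranges in the two result dimensions. The matrix sizes are made up and the snippet is not part of the patch; it only reproduces the index math.

// Sketch of the chunk -> row-range mapping (made-up sizes, not part of the patch).
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t nr0 = 4096, nr1 = 96;  // hypothetical: dst rows (ne0) and ne1*ne2*ne3
    const int64_t chunk_size = 16;

    const int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;  // 256 chunks along dim 0
    const int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;  //   6 chunks along dim 1
    const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;            // rows per chunk, dim 0
    const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;            // rows per chunk, dim 1

    for (int64_t chunk : {0, 1, 255, 256, 1535}) {  // 1536 chunks in total
        const int64_t ith0 = chunk % nchunk0;
        const int64_t ith1 = chunk / nchunk0;
        printf("chunk %4lld -> ir0 [%lld, %lld), ir1 [%lld, %lld)\n",
               (long long) chunk,
               (long long) (dr0 * ith0), (long long) std::min(dr0 * ith0 + dr0, nr0),
               (long long) (dr1 * ith1), (long long) std::min(dr1 * ith1 + dr1, nr1));
    }
    return 0;
}
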
@@ -255,7 +517,8 @@ GGML_CALL static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t

        switch (node->op) {
            case GGML_OP_MUL_MAT:
-               ggml_backend_blas_mul_mat(ctx, node);
+               // ggml_backend_blas_mul_mat(ctx, node);
+               ggml_compute_forward_mul_mat(ctx, node);
                break;

            case GGML_OP_OUT_PROD:
@@ -281,6 +544,10 @@ GGML_CALL static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t
}

GGML_CALL static bool ggml_backend_blas_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
+
+    return op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_OUT_PROD;
+
+    /*
    const struct ggml_tensor * src0 = op->src[0];
    const struct ggml_tensor * src1 = op->src[1];

@@ -291,6 +558,7 @@ GGML_CALL static bool ggml_backend_blas_supports_op(ggml_backend_t backend, cons
           ggml_is_matrix(src1) &&
           ggml_is_contiguous(src0) &&
           (ggml_is_contiguous(src1) || ggml_is_transposed(src1)));
+    */

    GGML_UNUSED(backend);
}
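
One more detail worth illustrating is the conversion of src1 to vec_dot_type near the top of ggml_compute_forward_mul_mat in this patch: the loop computes a block-aligned slice of each row per thread (the enclosing omp parallel pragma is commented out there, so it currently runs single-threaded). The sketch below reproduces only that index arithmetic with hypothetical numbers (block_size 32, as for Q8_0); it is not part of the patch.

// Sketch of the block-aligned per-thread split used for the src1 conversion
// (hypothetical sizes, not part of the patch).
#include <algorithm>
#include <cstdio>

int main() {
    const int ne10       = 4096;  // row length in floats (hypothetical)
    const int block_size = 32;    // quantization block size, e.g. Q8_0 (hypothetical)
    const int nth        = 4;     // number of threads

    const int blocks_per_thread = (ne10 + block_size - 1) / block_size / nth;
    for (int ith = 0; ith < nth; ++ith) {
        const int i10_start = ith * blocks_per_thread * block_size;
        const int i10_end   = std::min(i10_start + blocks_per_thread * block_size, ne10);
        // each range starts and ends on a block boundary, so one thread never
        // writes into another thread's quantized block
        printf("thread %d converts elements [%d, %d)\n", ith, i10_start, i10_end);
    }
    return 0;
}
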