
Commit e9a1b01

JohannesGaessler authored and mglambda committed
CUDA: support for mat. mul. with ne03 != ne13 (ggml-org#11656)
1 parent 8af1949 commit e9a1b01
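This lets src0 be broadcast along dimension 3 when src1 has more samples, mirroring what ggml_mul_mat already permits along dimension 2. A minimal host-side sketch of a case that previously fell back to the CPU (shapes and sizes are hypothetical):

// Broadcast src0's single sample (ne03 == 1) over src1's four samples (ne13 == 4).
struct ggml_init_params params = {/*.mem_size =*/ 256*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false};
struct ggml_context * ctx = ggml_init(params);

struct ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 64, 32, 8, 1); // src0: ne03 == 1
struct ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64,  1, 8, 4); // src1: ne13 == 4
struct ggml_tensor * c = ggml_mul_mat(ctx, a, b);                              // c->ne == {32, 1, 8, 4}

ggml_free(ctx);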

File tree

ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/mmv.cu

2 files changed: +81 −60 lines

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 11 additions & 16 deletions
@@ -1366,8 +1366,6 @@ static void ggml_cuda_op_mul_mat(
     const int64_t ne13 = src1->ne[3];
     const int64_t nrows1 = ggml_nrows(src1);
 
-    GGML_ASSERT(ne03 == ne13);
-
     const int64_t ne0 = dst->ne[0];
     const int64_t ne1 = dst->ne[1];
 
@@ -1381,9 +1379,11 @@ static void ggml_cuda_op_mul_mat(
 
     GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
 
-    GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
+    GGML_ASSERT(ne12 % ne02 == 0);
+    GGML_ASSERT(ne13 % ne03 == 0);
 
     const int64_t i02_divisor = ne12 / ne02;
+    const int64_t i03_divisor = ne13 / ne03;
 
     const size_t src0_ts = ggml_type_size(src0->type);
     const size_t src0_bs = ggml_blck_size(src0->type);
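The two divisors encode how many src1 slices share a single src0 slice. A hedged illustration of the mapping (not the literal loop structure of this function):

// i03/i02 index src1 slices; integer division by the ratios picks the src0
// slice that is broadcast onto them. Example: ne03 = 1, ne13 = 4 gives
// i03_divisor = 4, so src1 samples 0..3 all read src0 sample 0.
for (int64_t i03 = 0; i03 < ne13; ++i03) {
    for (int64_t i02 = 0; i02 < ne12; ++i02) {
        const int64_t i03_src0 = i03 / i03_divisor;
        const int64_t i02_src0 = i02 / i02_divisor;
        // dst slice (i03, i02) = src0 slice (i03_src0, i02_src0) * src1 slice (i03, i02)
    }
}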
@@ -1399,6 +1399,7 @@ static void ggml_cuda_op_mul_mat(
     GGML_ASSERT(!(split && ne02 > 1));
     GGML_ASSERT(!(split && ne03 > 1));
     GGML_ASSERT(!(split && ne02 < ne12));
+    GGML_ASSERT(!(split && ne03 < ne13));
 
     ggml_tensor_extra_gpu * src0_extra = split ? (ggml_tensor_extra_gpu *) src0->extra : nullptr;
 
@@ -1562,7 +1563,8 @@ static void ggml_cuda_op_mul_mat(
             }
 
             // for split tensors the data begins at i0 == i0_offset_low
-            char * src0_dd_i = dev[id].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
+            const size_t nbytes_src0_matrix = ne01*ne00*src0_ts / src0_bs;
+            char * src0_dd_i = dev[id].src0_dd + ((i03/i03_divisor)*ne02 + (i02/i02_divisor)) * nbytes_src0_matrix;
             float * src1_ddf_i = dev[id].src1_ddf + (i0*ne11 + src1_col_0) * ne10;
             char  * src1_ddq_i = dev[id].src1_ddq +  src1_ddq_i_offset;
             float * dst_dd_i   = dev[id].dst_dd  + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff);
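The offset now flattens the (sample, channel) pair of src0 in row-major order. A quick worked check with hypothetical sizes ne02 = 8, i02_divisor = 1, i03_divisor = 4:

// src1 slice (i03 = 5, i02 = 3) maps to src0 matrix (5/4)*8 + 3/1 = 11,
// i.e. sample 1, channel 3 of src0, each matrix being nbytes_src0_matrix long.
const int64_t matrix_idx = (i03/i03_divisor)*ne02 + i02/i02_divisor;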
@@ -1606,8 +1608,9 @@ static void ggml_cuda_op_mul_mat(
                 CUDA_CHECK(cudaGetLastError());
             }
 
-            if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) {
-                CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream));
+            if (src1_col_0 == 0 && !src0_is_contiguous && i03 % i03_divisor == 0 && i02 % i02_divisor == 0) {
+                CUDA_CHECK(ggml_cuda_cpy_tensor_2d(
+                    src0_dd_i, src0, i03/i03_divisor, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream));
             }
 
             // do the computation
@@ -1882,7 +1885,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
-    if (!split && use_mul_mat_vec && dst->ne[3] == 1 && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
+    if (!split && use_mul_mat_vec && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
         // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
         ggml_cuda_mul_mat_vec(ctx, src0, src1, dst);
@@ -2216,12 +2219,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_op_rms_norm_back(ctx, dst);
             break;
         case GGML_OP_MUL_MAT:
-            if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
-                GGML_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
-                return false;
-            } else {
-                ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst);
-            }
+            ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst);
             break;
         case GGML_OP_MUL_MAT_ID:
             ggml_cuda_mul_mat_id(ctx, dst);
@@ -2998,9 +2996,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
                 return false;
             }
-            if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
-                return false;
-            }
 #ifdef GGML_USE_MUSA
             if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
                 !ggml_is_transposed(a) && !ggml_is_transposed(b)) {

ggml/src/ggml-cuda/mmv.cu

Lines changed: 70 additions & 44 deletions
@@ -1,18 +1,21 @@
+#include "ggml.h"
 #include "common.cuh"
 #include "mmv.cuh"
 
 template <typename T, typename type_acc, int block_size>
 static __global__ void mul_mat_vec(
         const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
-        const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
+        const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+        const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) {
     const int64_t row     = blockIdx.x;
-    const int64_t channel = blockIdx.z;
+    const int64_t channel = blockIdx.y;
+    const int64_t sample  = blockIdx.z;
     const int tid = threadIdx.x;
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-    x   += (channel/channel_ratio)*stride_channel_x + row*stride_row;
-    y   +=  channel               *stride_channel_y;
-    dst +=  channel               *stride_channel_dst;
+    x   += (sample/sample_ratio)*stride_sample_x   + (channel/channel_ratio)*stride_channel_x + row*stride_row;
+    y   +=  sample              *stride_sample_y   +  channel               *stride_channel_y;
+    dst +=  sample              *stride_sample_dst +  channel               *stride_channel_dst;
 
     const float2 * y2 = (const float2 *) y;
 
@@ -91,12 +94,15 @@ template <typename T, typename type_acc>
 static void launch_mul_mat_vec_cuda(
         const T * x, const float * y, float * dst,
         const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
-        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         cudaStream_t stream) {
     GGML_ASSERT(ncols      % 2 == 0);
     GGML_ASSERT(stride_row % 2 == 0);
     GGML_ASSERT(nchannels_y % nchannels_x == 0);
+    GGML_ASSERT(nsamples_y  % nsamples_x  == 0);
     const int64_t channel_ratio = nchannels_y / nchannels_x;
+    const int64_t sample_ratio  = nsamples_y  / nsamples_x;
     int device;
     int warp_size;
 
@@ -118,40 +124,48 @@ static void launch_mul_mat_vec_cuda(
     }
 
     const int smem = warp_size*sizeof(float);
-    const dim3 block_nums(nrows, 1, nchannels_y);
+    const dim3 block_nums(nrows, nchannels_y, nsamples_y);
     const dim3 block_dims(block_size_best, 1, 1);
     switch (block_size_best) {
         case  32: {
             mul_mat_vec<T, type_acc,  32><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  64: {
             mul_mat_vec<T, type_acc,  64><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  96: {
             mul_mat_vec<T, type_acc,  96><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 128: {
             mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 160: {
             mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 192: {
             mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 224: {
             mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 256: {
             mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         default: {
             GGML_ABORT("fatal error");
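For a feel for the launch shape, with hypothetical sizes nrows = 4096, nchannels_y = 32, nsamples_y = 4:

// 4096 * 32 * 4 = 524288 blocks, one per dst row per (channel, sample) slice.
// CUDA caps gridDim.y and gridDim.z at 65535, far above typical ne2/ne3 here.
const dim3 block_nums(4096, 32, 4);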
@@ -163,16 +177,19 @@ template<typename T>
 static void mul_mat_vec_cuda(
         const T * x, const float * y, float * dst,
         const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
-        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         enum ggml_prec prec, cudaStream_t stream) {
     switch (prec) {
         case GGML_PREC_DEFAULT: {
-            launch_mul_mat_vec_cuda<T, half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
-                stride_channel_x, stride_channel_y, stride_channel_dst, stream);
+            launch_mul_mat_vec_cuda<T, half>
+                (x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case GGML_PREC_F32: {
-            launch_mul_mat_vec_cuda<T, float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
-                stride_channel_x, stride_channel_y, stride_channel_dst, stream);
+            launch_mul_mat_vec_cuda<T, float>
+                (x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
     }
 }
@@ -181,40 +198,42 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type  == GGML_TYPE_F32);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const size_t ts_src0 = ggml_type_size(src0->type);
+    const size_t ts_src1 = ggml_type_size(src1->type);
+    const size_t ts_dst  = ggml_type_size(dst->type);
+
+    GGML_ASSERT(ne11 == 1);
+    GGML_ASSERT(ne12 == ne2);
+    GGML_ASSERT(ne13 == ne3);
 
-    GGML_ASSERT(src1->ne[1] == 1);
+    GGML_ASSERT(nb00 == ts_src0);
+    GGML_ASSERT(nb10 == ts_src1);
+    GGML_ASSERT(nb0  == ts_dst);
 
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
 
     const float * src1_d = (const float *) src1->data;
     float       *  dst_d = (float       *)  dst->data;
 
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne12 = src1->ne[2];
-    GGML_ASSERT(dst->ne[2] == ne12);
-
-    GGML_ASSERT(src0->ne[3] == 1);
-    GGML_ASSERT(src1->ne[3] == 1);
-    GGML_ASSERT( dst->ne[3] == 1);
-
-    const int64_t stride_row         = src0->nb[1] / ggml_type_size(src0->type);
-    const int64_t channel_stride_x   = src0->nb[2] / ggml_type_size(src0->type);
-    const int64_t channel_stride_y   = src1->nb[2] / ggml_type_size(src1->type);
-    const int64_t channel_stride_dst =  dst->nb[2] / ggml_type_size( dst->type);
+    const int64_t s01 = src0->nb[1] / ts_src0;
+    const int64_t s02 = src0->nb[2] / ts_src0;
+    const int64_t s12 = src1->nb[2] / ts_src1;
+    const int64_t s2  =  dst->nb[2] / ts_dst;
+    const int64_t s03 = src0->nb[3] / ts_src0;
+    const int64_t s13 = src1->nb[3] / ts_src1;
+    const int64_t s3  =  dst->nb[3] / ts_dst;
 
     switch (src0->type) {
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
-                channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
+            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream());
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
-                channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
+            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream());
         } break;
         default:
             GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
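GGML_TENSOR_BINARY_OP_LOCALS expands to the usual ne0*/nb0*, ne1*/nb1*, and ne*/nb* locals for src0, src1, and dst, and the s* values are strides in elements rather than bytes. A worked example with a hypothetical contiguous F16 src0 of ne = {4096, 4096, 8, 1}:

// ts_src0 = 2 bytes for F16:
//   nb[1] = 4096*2 bytes        -> s01 = 4096 elements
//   nb[2] = 4096*4096*2 bytes   -> s02 = 16777216 elements
//   nb[3] = 4096*4096*8*2 bytes -> s03 = 134217728 elements
// With ne03 = 1 and ne13 = 4, the kernel reuses the single src0 sample for
// all four src1 samples via sample/sample_ratio.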
@@ -243,20 +262,27 @@ void ggml_cuda_op_mul_mat_vec(
     const int64_t stride_row         = ne00;
     const int64_t nchannels_x        = 1;
     const int64_t nchannels_y        = 1;
-    const int64_t channel_stride_x   = 0;
-    const int64_t channel_stride_y   = 0;
-    const int64_t channel_stride_dst = 0;
+    const int64_t stride_channel_x   = 0;
+    const int64_t stride_channel_y   = 0;
+    const int64_t stride_channel_dst = 0;
+    const int64_t nsamples_x         = 1;
+    const int64_t nsamples_y         = 1;
+    const int64_t stride_sample_x    = 0;
+    const int64_t stride_sample_y    = 0;
+    const int64_t stride_sample_dst  = 0;
 
     switch (src0->type) {
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0_dd_i;
             mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
-                nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
+                nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
             mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
-                nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
+                nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
         default:
             GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
