1
+ #include " ggml.h"
1
2
#include " common.cuh"
2
3
#include " mmv.cuh"
3
4
4
5
template <typename T, typename type_acc, int block_size>
5
6
static __global__ void mul_mat_vec (
6
7
const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
7
- const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
8
+ const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
9
+ const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) {
8
10
const int64_t row = blockIdx .x ;
9
- const int64_t channel = blockIdx .z ;
11
+ const int64_t channel = blockIdx .y ;
12
+ const int64_t sample = blockIdx .z ;
10
13
const int tid = threadIdx .x ;
11
14
constexpr int warp_size = ggml_cuda_get_physical_warp_size ();
12
15
13
- x += (channel/channel_ratio)*stride_channel_x + row*stride_row;
14
- y += channel *stride_channel_y;
15
- dst += channel *stride_channel_dst;
16
+ x += (sample/sample_ratio)*stride_sample_x + (channel/channel_ratio)*stride_channel_x + row*stride_row;
17
+ y += sample *stride_sample_y + channel *stride_channel_y;
18
+ dst += sample *stride_sample_dst + channel *stride_channel_dst;
16
19
17
20
const float2 * y2 = (const float2 *) y;
18
21
@@ -91,12 +94,15 @@ template <typename T, typename type_acc>
91
94
static void launch_mul_mat_vec_cuda (
92
95
const T * x, const float * y, float * dst,
93
96
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
94
- const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
97
+ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
98
+ const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
95
99
cudaStream_t stream) {
96
100
GGML_ASSERT (ncols % 2 == 0 );
97
101
GGML_ASSERT (stride_row % 2 == 0 );
98
102
GGML_ASSERT (nchannels_y % nchannels_x == 0 );
103
+ GGML_ASSERT (nsamples_y % nsamples_x == 0 );
99
104
const int64_t channel_ratio = nchannels_y / nchannels_x;
105
+ const int64_t sample_ratio = nsamples_y / nsamples_x;
100
106
int device;
101
107
int warp_size;
102
108
@@ -118,40 +124,48 @@ static void launch_mul_mat_vec_cuda(
118
124
}
119
125
120
126
const int smem = warp_size*sizeof (float );
121
- const dim3 block_nums (nrows, 1 , nchannels_y );
127
+ const dim3 block_nums (nrows, nchannels_y, nsamples_y );
122
128
const dim3 block_dims (block_size_best, 1 , 1 );
123
129
switch (block_size_best) {
124
130
case 32 : {
125
131
mul_mat_vec<T, type_acc, 32 ><<<block_nums, block_dims, smem, stream>>>
126
- (x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
132
+ (x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
133
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
127
134
} break ;
128
135
case 64 : {
129
136
mul_mat_vec<T, type_acc, 64 ><<<block_nums, block_dims, smem, stream>>>
130
- (x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
137
+ (x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
138
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
131
139
} break ;
132
140
case 96 : {
133
141
mul_mat_vec<T, type_acc, 96 ><<<block_nums, block_dims, smem, stream>>>
134
- (x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
142
+ (x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
143
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
135
144
} break ;
136
145
case 128 : {
137
146
mul_mat_vec<T, type_acc, 128 ><<<block_nums, block_dims, smem, stream>>>
138
- (x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
147
+ (x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
148
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
139
149
} break ;
140
150
case 160 : {
141
151
mul_mat_vec<T, type_acc, 160 ><<<block_nums, block_dims, smem, stream>>>
142
- (x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
152
+ (x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
153
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
143
154
} break ;
144
155
case 192 : {
145
156
mul_mat_vec<T, type_acc, 192 ><<<block_nums, block_dims, smem, stream>>>
146
- (x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
157
+ (x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
158
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
147
159
} break ;
148
160
case 224 : {
149
161
mul_mat_vec<T, type_acc, 224 ><<<block_nums, block_dims, smem, stream>>>
150
- (x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
162
+ (x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
163
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
151
164
} break ;
152
165
case 256 : {
153
166
mul_mat_vec<T, type_acc, 256 ><<<block_nums, block_dims, smem, stream>>>
154
- (x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
167
+ (x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
168
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
155
169
} break ;
156
170
default : {
157
171
GGML_ABORT (" fatal error" );
@@ -163,16 +177,19 @@ template<typename T>
163
177
// Host-side precision dispatcher for the matrix-vector product.
//
// Selects the accumulator type from `prec` and forwards every other argument,
// in the same order, to launch_mul_mat_vec_cuda<T, type_acc>:
//   GGML_PREC_DEFAULT -> type_acc = half
//   GGML_PREC_F32     -> type_acc = float
// The switch has no default case, so any other `prec` value returns without
// launching anything.
//
// NOTE(review): this span is a diff rendering. Rows prefixed with `-` are the
// previous revision (channel arguments only); rows prefixed with `+` are the
// new revision, which additionally threads nsamples_x, nsamples_y and the three
// stride_sample_* values through to the launcher. The bare numbers are
// old/new line-number residue from the diff view, not code.
static void mul_mat_vec_cuda (
164
178
const T * x, const float * y, float * dst,
165
179
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
166
- const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
180
+ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
181
+ const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
167
182
enum ggml_prec prec, cudaStream_t stream) {
168
183
switch (prec) {
169
184
case GGML_PREC_DEFAULT: {
170
- launch_mul_mat_vec_cuda<T, half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
171
- stride_channel_x, stride_channel_y, stride_channel_dst, stream);
185
+ launch_mul_mat_vec_cuda<T, half>
186
+ (x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
187
+ nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
172
188
} break ;
173
189
case GGML_PREC_F32: {
174
- launch_mul_mat_vec_cuda<T, float >(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
175
- stride_channel_x, stride_channel_y, stride_channel_dst, stream);
190
+ launch_mul_mat_vec_cuda<T, float >
191
+ (x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
192
+ nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
176
193
} break ;
177
194
}
178
195
}
@@ -181,40 +198,42 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
181
198
GGML_ASSERT (src1->type == GGML_TYPE_F32);
182
199
GGML_ASSERT (dst->type == GGML_TYPE_F32);
183
200
184
- const int64_t ne00 = src0->ne [0 ];
185
- const int64_t ne01 = src0->ne [1 ];
201
+ GGML_TENSOR_BINARY_OP_LOCALS;
202
+
203
+ const size_t ts_src0 = ggml_type_size (src0->type );
204
+ const size_t ts_src1 = ggml_type_size (src1->type );
205
+ const size_t ts_dst = ggml_type_size (dst->type );
206
+
207
+ GGML_ASSERT (ne11 == 1 );
208
+ GGML_ASSERT (ne12 == ne2);
209
+ GGML_ASSERT (ne13 == ne3);
186
210
187
- GGML_ASSERT (src1->ne [1 ] == 1 );
211
+ GGML_ASSERT (nb00 == ts_src0);
212
+ GGML_ASSERT (nb10 == ts_src1);
213
+ GGML_ASSERT (nb0 == ts_dst);
188
214
189
215
const int cc = ggml_cuda_info ().devices [ggml_cuda_get_device ()].cc ;
190
216
const enum ggml_prec prec = fast_fp16_available (cc) ? ggml_prec (dst->op_params [0 ]) : GGML_PREC_F32;
191
217
192
218
const float * src1_d = (const float *) src1->data ;
193
219
float * dst_d = (float *) dst->data ;
194
220
195
- const int64_t ne02 = src0->ne [2 ];
196
- const int64_t ne12 = src1->ne [2 ];
197
- GGML_ASSERT (dst->ne [2 ] == ne12);
198
-
199
- GGML_ASSERT (src0->ne [3 ] == 1 );
200
- GGML_ASSERT (src1->ne [3 ] == 1 );
201
- GGML_ASSERT ( dst->ne [3 ] == 1 );
202
-
203
- const int64_t stride_row = src0->nb [1 ] / ggml_type_size (src0->type );
204
- const int64_t channel_stride_x = src0->nb [2 ] / ggml_type_size (src0->type );
205
- const int64_t channel_stride_y = src1->nb [2 ] / ggml_type_size (src1->type );
206
- const int64_t channel_stride_dst = dst->nb [2 ] / ggml_type_size ( dst->type );
221
+ const int64_t s01 = src0->nb [1 ] / ts_src0;
222
+ const int64_t s02 = src0->nb [2 ] / ts_src0;
223
+ const int64_t s12 = src1->nb [2 ] / ts_src1;
224
+ const int64_t s2 = dst->nb [2 ] / ts_dst;
225
+ const int64_t s03 = src0->nb [3 ] / ts_src0;
226
+ const int64_t s13 = src1->nb [3 ] / ts_src1;
227
+ const int64_t s3 = dst->nb [3 ] / ts_dst;
207
228
208
229
switch (src0->type ) {
209
230
case GGML_TYPE_F16: {
210
231
const half * src0_d = (const half *) src0->data ;
211
- mul_mat_vec_cuda (src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
212
- channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream ());
232
+ mul_mat_vec_cuda (src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream ());
213
233
} break ;
214
234
case GGML_TYPE_BF16: {
215
235
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data ;
216
- mul_mat_vec_cuda (src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
217
- channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream ());
236
+ mul_mat_vec_cuda (src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream ());
218
237
} break ;
219
238
default :
220
239
GGML_ABORT (" unsupported type: %s" , ggml_type_name (src0->type ));
@@ -243,20 +262,27 @@ void ggml_cuda_op_mul_mat_vec(
243
262
const int64_t stride_row = ne00;
244
263
const int64_t nchannels_x = 1 ;
245
264
const int64_t nchannels_y = 1 ;
246
- const int64_t channel_stride_x = 0 ;
247
- const int64_t channel_stride_y = 0 ;
248
- const int64_t channel_stride_dst = 0 ;
265
+ const int64_t stride_channel_x = 0 ;
266
+ const int64_t stride_channel_y = 0 ;
267
+ const int64_t stride_channel_dst = 0 ;
268
+ const int64_t nsamples_x = 1 ;
269
+ const int64_t nsamples_y = 1 ;
270
+ const int64_t stride_sample_x = 0 ;
271
+ const int64_t stride_sample_y = 0 ;
272
+ const int64_t stride_sample_dst = 0 ;
249
273
250
274
switch (src0->type ) {
251
275
case GGML_TYPE_F16: {
252
276
const half * src0_d = (const half *) src0_dd_i;
253
277
mul_mat_vec_cuda (src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
254
- nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
278
+ nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
279
+ nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
255
280
} break ;
256
281
case GGML_TYPE_BF16: {
257
282
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
258
283
mul_mat_vec_cuda (src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
259
- nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
284
+ nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
285
+ nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
260
286
} break ;
261
287
default :
262
288
GGML_ABORT (" unsupported type: %s" , ggml_type_name (src0->type ));
0 commit comments