1
1
#include " common.cuh"
2
2
#include " mmv.cuh"
3
3
4
- template <typename type_acc, int block_size>
4
+ template <typename T, typename type_acc, int block_size>
5
5
static __global__ void mul_mat_vec (
6
- const half * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
6
+ const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
7
7
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
8
8
const int64_t row = blockIdx .x ;
9
9
const int64_t channel = blockIdx .z ;
@@ -13,7 +13,6 @@ static __global__ void mul_mat_vec(
13
13
y += channel *stride_channel_y;
14
14
dst += channel *stride_channel_dst;
15
15
16
- const half2 * x2 = (const half2 *) x;
17
16
const float2 * y2 = (const float2 *) y;
18
17
19
18
extern __shared__ char data_mmv[];
@@ -28,28 +27,44 @@ static __global__ void mul_mat_vec(
28
27
29
28
float sumf;
30
29
31
- if (std::is_same<type_acc, float >::value) {
32
- sumf = 0 . 0f ;
30
+ if constexpr (std::is_same<T, half >::value) {
31
+ const half2 * x2 = ( const half2 *) x ;
33
32
34
- for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
35
- const float2 tmpx = __half22float2 (x2[col2]);
36
- const float2 tmpy = y2[col2];
37
- sumf += tmpx.x * tmpy.x ;
38
- sumf += tmpx.y * tmpy.y ;
39
- }
40
- } else {
33
+ if (std::is_same<type_acc, float >::value) {
34
+ sumf = 0 .0f ;
35
+
36
+ for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
37
+ const float2 tmpx = __half22float2 (x2[col2]);
38
+ const float2 tmpy = y2[col2];
39
+ sumf += tmpx.x * tmpy.x ;
40
+ sumf += tmpx.y * tmpy.y ;
41
+ }
42
+ } else {
41
43
#ifdef FP16_AVAILABLE
42
- half2 sumh2 = make_half2 (0 .0f , 0 .0f );
44
+ half2 sumh2 = make_half2 (0 .0f , 0 .0f );
43
45
44
- for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
45
- const float2 tmp = y2[col2];
46
- sumh2 += x2[col2] * make_half2 (tmp.x , tmp.y );
47
- }
46
+ for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
47
+ const float2 tmp = y2[col2];
48
+ sumh2 += x2[col2] * make_half2 (tmp.x , tmp.y );
49
+ }
48
50
49
- sumf = __low2float (sumh2) + __high2float (sumh2);
51
+ sumf = __low2float (sumh2) + __high2float (sumh2);
50
52
#else
51
- NO_DEVICE_CODE;
53
+ NO_DEVICE_CODE;
52
54
#endif // FP16_AVAILABLE
55
+ }
56
+ } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
57
+ const int * x2 = (const int *) x;
58
+ sumf = 0 .0f ;
59
+
60
+ for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
61
+ const int tmpx = x2[col2];
62
+ const float2 tmpy = y2[col2];
63
+ sumf += float (reinterpret_cast <const nv_bfloat16 *>(&tmpx)[0 ]) * tmpy.x ;
64
+ sumf += float (reinterpret_cast <const nv_bfloat16 *>(&tmpx)[1 ]) * tmpy.y ;
65
+ }
66
+ } else {
67
+ static_assert (std::is_same<T, void >::value, " unsupported type" );
53
68
}
54
69
55
70
sumf = warp_reduce_sum (sumf);
@@ -71,9 +86,9 @@ static __global__ void mul_mat_vec(
71
86
dst[row] = sumf;
72
87
}
73
88
74
- template <typename type_acc>
89
+ template <typename T, typename type_acc>
75
90
static void launch_mul_mat_vec_cuda (
76
- const half * x, const float * y, float * dst,
91
+ const T * x, const float * y, float * dst,
77
92
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
78
93
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
79
94
cudaStream_t stream) {
@@ -97,35 +112,35 @@ static void launch_mul_mat_vec_cuda(
97
112
const dim3 block_dims (block_size_best, 1 , 1 );
98
113
switch (block_size_best) {
99
114
case 32 : {
100
- mul_mat_vec<type_acc, 32 ><<<block_nums, block_dims, smem, stream>>>
115
+ mul_mat_vec<T, type_acc, 32 ><<<block_nums, block_dims, smem, stream>>>
101
116
(x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
102
117
} break ;
103
118
case 64 : {
104
- mul_mat_vec<type_acc, 64 ><<<block_nums, block_dims, smem, stream>>>
119
+ mul_mat_vec<T, type_acc, 64 ><<<block_nums, block_dims, smem, stream>>>
105
120
(x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
106
121
} break ;
107
122
case 96 : {
108
- mul_mat_vec<type_acc, 96 ><<<block_nums, block_dims, smem, stream>>>
123
+ mul_mat_vec<T, type_acc, 96 ><<<block_nums, block_dims, smem, stream>>>
109
124
(x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
110
125
} break ;
111
126
case 128 : {
112
- mul_mat_vec<type_acc, 128 ><<<block_nums, block_dims, smem, stream>>>
127
+ mul_mat_vec<T, type_acc, 128 ><<<block_nums, block_dims, smem, stream>>>
113
128
(x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
114
129
} break ;
115
130
case 160 : {
116
- mul_mat_vec<type_acc, 160 ><<<block_nums, block_dims, smem, stream>>>
131
+ mul_mat_vec<T, type_acc, 160 ><<<block_nums, block_dims, smem, stream>>>
117
132
(x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
118
133
} break ;
119
134
case 192 : {
120
- mul_mat_vec<type_acc, 192 ><<<block_nums, block_dims, smem, stream>>>
135
+ mul_mat_vec<T, type_acc, 192 ><<<block_nums, block_dims, smem, stream>>>
121
136
(x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
122
137
} break ;
123
138
case 224 : {
124
- mul_mat_vec<type_acc, 224 ><<<block_nums, block_dims, smem, stream>>>
139
+ mul_mat_vec<T, type_acc, 224 ><<<block_nums, block_dims, smem, stream>>>
125
140
(x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
126
141
} break ;
127
142
case 256 : {
128
- mul_mat_vec<type_acc, 256 ><<<block_nums, block_dims, smem, stream>>>
143
+ mul_mat_vec<T, type_acc, 256 ><<<block_nums, block_dims, smem, stream>>>
129
144
(x, y, dst, ncols/2 , stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
130
145
} break ;
131
146
default : {
@@ -134,25 +149,25 @@ static void launch_mul_mat_vec_cuda(
134
149
}
135
150
}
136
151
152
// Host-side dispatcher: selects the accumulator type for the matrix-vector
// product based on the requested precision, then forwards all arguments to
// launch_mul_mat_vec_cuda, which picks the block size and launches the kernel.
//
//   x    - src0 data of element type T (e.g. half or nv_bfloat16)
//   y    - src1 data (f32)
//   dst  - output (f32)
//   prec - GGML_PREC_DEFAULT -> half accumulators (fast path),
//          GGML_PREC_F32     -> float accumulators (full precision)
template <typename T>
static void mul_mat_vec_cuda(
        const T * x, const float * y, float * dst,
        const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
        enum ggml_prec prec, cudaStream_t stream) {
    if (prec == GGML_PREC_F32) {
        launch_mul_mat_vec_cuda<T, float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
            stride_channel_x, stride_channel_y, stride_channel_dst, stream);
    } else if (prec == GGML_PREC_DEFAULT) {
        launch_mul_mat_vec_cuda<T, half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
            stride_channel_x, stride_channel_y, stride_channel_dst, stream);
    }
    // NOTE(review): any other ggml_prec value is silently ignored, matching the
    // original switch, which had no default case.
}
153
169
154
170
void ggml_cuda_mul_mat_vec (ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
155
- GGML_ASSERT (src0->type == GGML_TYPE_F16);
156
171
GGML_ASSERT (src1->type == GGML_TYPE_F32);
157
172
GGML_ASSERT (dst->type == GGML_TYPE_F32);
158
173
@@ -164,7 +179,6 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
164
179
const int cc = ggml_cuda_info ().devices [ggml_cuda_get_device ()].cc ;
165
180
const enum ggml_prec prec = fast_fp16_available (cc) ? ggml_prec (dst->op_params [0 ]) : GGML_PREC_F32;
166
181
167
- const half * src0_d = (const half *) src0->data ;
168
182
const float * src1_d = (const float *) src1->data ;
169
183
float * dst_d = (float *) dst->data ;
170
184
@@ -181,7 +195,20 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
181
195
const int64_t channel_stride_y = src1->nb [2 ] / ggml_type_size (src1->type );
182
196
const int64_t channel_stride_dst = dst->nb [2 ] / ggml_type_size ( dst->type );
183
197
184
- mul_mat_vec_cuda (src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12, channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream ());
198
+ switch (src0->type ) {
199
+ case GGML_TYPE_F16: {
200
+ const half * src0_d = (const half *) src0->data ;
201
+ mul_mat_vec_cuda (src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
202
+ channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream ());
203
+ } break ;
204
+ case GGML_TYPE_BF16: {
205
+ const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data ;
206
+ mul_mat_vec_cuda (src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
207
+ channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream ());
208
+ } break ;
209
+ default :
210
+ GGML_ABORT (" unsupported type: %s" , ggml_type_name (src0->type ));
211
+ }
185
212
}
186
213
187
214
void ggml_cuda_op_mul_mat_vec (
@@ -190,7 +217,6 @@ void ggml_cuda_op_mul_mat_vec(
190
217
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
191
218
const int64_t src1_padded_row_size, cudaStream_t stream) {
192
219
193
- GGML_ASSERT (src0->type == GGML_TYPE_F16);
194
220
GGML_ASSERT (src1->type == GGML_TYPE_F32);
195
221
GGML_ASSERT (dst->type == GGML_TYPE_F32);
196
222
@@ -211,8 +237,20 @@ void ggml_cuda_op_mul_mat_vec(
211
237
const int64_t channel_stride_y = 0 ;
212
238
const int64_t channel_stride_dst = 0 ;
213
239
214
- mul_mat_vec_cuda ((const half *) src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
215
- nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
240
+ switch (src0->type ) {
241
+ case GGML_TYPE_F16: {
242
+ const half * src0_d = (const half *) src0_dd_i;
243
+ mul_mat_vec_cuda (src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
244
+ nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
245
+ } break ;
246
+ case GGML_TYPE_BF16: {
247
+ const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
248
+ mul_mat_vec_cuda (src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
249
+ nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
250
+ } break ;
251
+ default :
252
+ GGML_ABORT (" unsupported type: %s" , ggml_type_name (src0->type ));
253
+ }
216
254
217
255
GGML_UNUSED (ctx);
218
256
GGML_UNUSED (src1);
0 commit comments