@@ -33,8 +33,8 @@ static __global__ void k_get_rows(
     dfloat2 v;
     dequantize_kernel(src0_row, ib, iqs, v);

-    dst_row[iybs + iqs + 0]        = v.x;
-    dst_row[iybs + iqs + y_offset] = v.y;
+    dst_row[iybs + iqs + 0]        = float(v.x);
+    dst_row[iybs + iqs + y_offset] = float(v.y);
 }

 template<typename src0_t, typename dst_t>
@@ -60,7 +60,7 @@ static __global__ void k_get_rows_float(
     dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
     const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);

-    dst_row[i00] = src0_row[i00];
+    dst_row[i00] = float(src0_row[i00]);
 }

 template<typename grad_t, typename dst_t>
@@ -86,122 +86,161 @@ static __global__ void k_get_rows_back_float(
     dst[dst_row*ncols + col] = sum;
 }

-template<int qk, int qr, dequantize_kernel_t dq>
-static void get_rows_cuda(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
+template<int qk, int qr, dequantize_kernel_t dq, typename dst_t>
+static void get_rows_cuda_q(
+    const void * src0_d, const int32_t * src1_d, dst_t * dst_d,
+    const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
+    const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
+    const size_t nb1, const size_t nb2, const size_t nb3,
+    cudaStream_t stream) {
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
     const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
     const dim3 block_nums(block_num_x, ne10, ne11*ne12);

     // strides in elements
-    //const size_t s0 = nb0 / ggml_element_size(dst);
-    const size_t s1 = nb1 / ggml_element_size(dst);
-    const size_t s2 = nb2 / ggml_element_size(dst);
-    const size_t s3 = nb3 / ggml_element_size(dst);
+    //const size_t s0 = nb0 / sizeof(dst_t);
+    const size_t s1 = nb1 / sizeof(dst_t);
+    const size_t s2 = nb2 / sizeof(dst_t);
+    const size_t s3 = nb3 / sizeof(dst_t);

-    const size_t s10 = nb10 / ggml_element_size(src1);
-    const size_t s11 = nb11 / ggml_element_size(src1);
-    const size_t s12 = nb12 / ggml_element_size(src1);
-    //const size_t s13 = nb13 / ggml_element_size(src1);
+    const size_t s10 = nb10 / sizeof(int32_t);
+    const size_t s11 = nb11 / sizeof(int32_t);
+    const size_t s12 = nb12 / sizeof(int32_t);
+    //const size_t s13 = nb13 / sizeof(int32_t);

     GGML_ASSERT(ne00 % 2 == 0);

     k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
-        src0_dd, src1_dd, dst_dd,
+        src0_d, src1_d, dst_d,
         ne00, /*ne01, ne02, ne03,*/
         /*ne10, ne11,*/ ne12, /*ne13,*/
         /* s0,*/ s1, s2, s3,
         /* nb00,*/ nb01, nb02, nb03,
         s10, s11, s12/*, s13*/);
-
-    GGML_UNUSED(dst);
 }

-template<typename src0_t>
+template<typename src0_t, typename dst_t>
 static void get_rows_cuda_float(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    GGML_ASSERT(ne13 == 1);
-
+    const src0_t * src0_d, const int32_t * src1_d, dst_t * dst_d,
+    const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
+    const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
+    const size_t nb1, const size_t nb2, const size_t nb3,
+    cudaStream_t stream) {
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
     const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
     const dim3 block_nums(block_num_x, ne10, ne11*ne12);

     // strides in elements
-    //const size_t s0 = nb0 / ggml_element_size(dst);
-    const size_t s1 = nb1 / ggml_element_size(dst);
-    const size_t s2 = nb2 / ggml_element_size(dst);
-    const size_t s3 = nb3 / ggml_element_size(dst);
+    //const size_t s0 = nb0 / sizeof(dst_t);
+    const size_t s1 = nb1 / sizeof(dst_t);
+    const size_t s2 = nb2 / sizeof(dst_t);
+    const size_t s3 = nb3 / sizeof(dst_t);

-    const size_t s10 = nb10 / ggml_element_size(src1);
-    const size_t s11 = nb11 / ggml_element_size(src1);
-    const size_t s12 = nb12 / ggml_element_size(src1);
-    //const size_t s13 = nb13 / ggml_element_size(src1);
+    const size_t s10 = nb10 / sizeof(int32_t);
+    const size_t s11 = nb11 / sizeof(int32_t);
+    const size_t s12 = nb12 / sizeof(int32_t);
+    //const size_t s13 = nb13 / sizeof(int32_t);

     k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
-        src0_dd, src1_dd, dst_dd,
+        src0_d, src1_d, dst_d,
         ne00, /*ne01, ne02, ne03,*/
         /*ne10, ne11,*/ ne12, /*ne13,*/
         /* s0,*/ s1, s2, s3,
         /* nb00,*/ nb01, nb02, nb03,
         s10, s11, s12/*, s13*/);
-
-    GGML_UNUSED(dst);
 }

-void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const ggml_tensor * src1 = dst->src[1];
-
-    const void * src0_d = (const void *) src0->data;
-    const int32_t * src1_d = (const int32_t *) src1->data;
-    float * dst_d = (float *) dst->data;
-
-    cudaStream_t stream = ctx.stream();
-
-    GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-
-    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
-    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
-    GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
-
-    switch (src0->type) {
+template<typename dst_t>
+static void ggml_cuda_get_rows_switch_src0_type(
+    const void * src0_d, const ggml_type src0_type, const int32_t * src1_d, dst_t * dst_d,
+    const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
+    const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
+    const size_t nb1, const size_t nb2, const size_t nb3,
+    cudaStream_t stream) {
+    switch (src0_type) {
         case GGML_TYPE_F16:
-            get_rows_cuda_float(src0, src1, dst, (const half *) src0_d, src1_d, dst_d, stream);
+            get_rows_cuda_float((const half *) src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
             break;
         case GGML_TYPE_F32:
-            get_rows_cuda_float(src0, src1, dst, (const float *) src0_d, src1_d, dst_d, stream);
+            get_rows_cuda_float((const float *) src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+            break;
+        case GGML_TYPE_BF16:
+            get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
             break;
         case GGML_TYPE_Q4_0:
-            get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
+            get_rows_cuda_q<QK4_0, QR4_0, dequantize_q4_0>(src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
             break;
         case GGML_TYPE_Q4_1:
-            get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
+            get_rows_cuda_q<QK4_1, QR4_1, dequantize_q4_1>(src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
             break;
         case GGML_TYPE_Q5_0:
-            get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
+            get_rows_cuda_q<QK5_0, QR5_0, dequantize_q5_0>(src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
             break;
         case GGML_TYPE_Q5_1:
-            get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
+            get_rows_cuda_q<QK5_1, QR5_1, dequantize_q5_1>(src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
             break;
         case GGML_TYPE_Q8_0:
-            get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
+            get_rows_cuda_q<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
             break;
         default:
             // TODO: k-quants
-            GGML_ABORT("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
+            GGML_ABORT("%s: unsupported src0 type: %s\n", __func__, ggml_type_name(src0_type));
             break;
     }
 }

+void get_rows_cuda(
+    const void * src0_d, ggml_type src0_type, const int32_t * src1_d, void * dst_d, ggml_type dst_type,
+    int64_t ne00, size_t nb01, size_t nb02, size_t nb03,
+    int64_t ne10, int64_t ne11, int64_t ne12, size_t nb10, size_t nb11, size_t nb12,
+    size_t nb1, size_t nb2, size_t nb3,
+    cudaStream_t stream) {
+    switch (dst_type) {
+        case GGML_TYPE_F32:
+            ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+            break;
+        case GGML_TYPE_F16:
+            ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+            break;
+        case GGML_TYPE_BF16:
+            ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (nv_bfloat16 *) dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+            break;
+        default:
+            GGML_ABORT("%s: unsupported dst type: %s\n", __func__, ggml_type_name(dst_type));
+            break;
+    }
+}
+
+void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(ne13 == 1);
+
+    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
+    GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
+
+    get_rows_cuda(src0->data, src0->type, (const int32_t *) src1->data, dst->data, dst->type,
+        ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+}
+
 void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
     const ggml_tensor * src1 = dst->src[1]; // src1 in forward pass
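For orientation, here is a minimal host-side sketch of how the new type-erased `get_rows_cuda` entry point introduced above could be driven directly, outside of `ggml_cuda_op_get_rows`. Only the `get_rows_cuda` signature and the `GGML_TYPE_*` constants come from the diff; the wrapper name `gather_rows_f16_to_f32`, the contiguous-layout stride math, and the assumption that the ggml CUDA header declaring `get_rows_cuda` is on the include path are illustrative, not part of the patch.

```cpp
// Hypothetical usage sketch (not part of the patch): gather selected rows of a
// contiguous F16 matrix into a contiguous F32 destination.
// Assumes the ggml CUDA header that declares get_rows_cuda (and ggml_type) is included.
#include <cuda_fp16.h>
#include <cstdint>

static void gather_rows_f16_to_f32(
        const half    * src0_d,  // device pointer: F16 source matrix, n_cols per row
        const int32_t * rows_d,  // device pointer: row indices to gather
        float         * dst_d,   // device pointer: F32 destination, n_rows_out x n_cols
        int64_t n_cols, int64_t n_rows_out, cudaStream_t stream) {
    const size_t ts_src = sizeof(half);
    const size_t ts_idx = sizeof(int32_t);
    const size_t ts_dst = sizeof(float);

    // Higher dimensions are size 1 here, so the nb02/nb03, nb11/nb12 and nb2/nb3
    // byte strides are never multiplied by a non-zero index; contiguous values are
    // passed anyway for clarity. All strides are in bytes, as in the diff.
    get_rows_cuda(
        src0_d, GGML_TYPE_F16, rows_d, dst_d, GGML_TYPE_F32,
        /*ne00=*/n_cols,
        /*nb01=*/n_cols*ts_src, /*nb02=*/n_cols*ts_src, /*nb03=*/n_cols*ts_src,
        /*ne10=*/n_rows_out, /*ne11=*/1, /*ne12=*/1,
        /*nb10=*/ts_idx, /*nb11=*/n_rows_out*ts_idx, /*nb12=*/n_rows_out*ts_idx,
        /*nb1=*/n_cols*ts_dst, /*nb2=*/n_cols*n_rows_out*ts_dst, /*nb3=*/n_cols*n_rows_out*ts_dst,
        stream);
}
```

In the patched `ggml_cuda_op_get_rows`, the same extents and byte strides are instead pulled from the tensor metadata via `GGML_TENSOR_BINARY_OP_LOCALS` before forwarding to `get_rows_cuda`.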