@@ -33,8 +33,8 @@ static __global__ void k_get_rows(
    dfloat2 v;
    dequantize_kernel(src0_row, ib, iqs, v);

-    dst_row[iybs + iqs + 0]        = float(v.x);
-    dst_row[iybs + iqs + y_offset] = float(v.y);
+    dst_row[iybs + iqs + 0]        = v.x;
+    dst_row[iybs + iqs + y_offset] = v.y;
}

template<typename src0_t, typename dst_t>
@@ -60,7 +60,7 @@ static __global__ void k_get_rows_float(
    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
    const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);

-    dst_row[i00] = float(src0_row[i00]);
+    dst_row[i00] = src0_row[i00];
}

template<typename grad_t, typename dst_t>
@@ -86,161 +86,122 @@ static __global__ void k_get_rows_back_float(
    dst[dst_row*ncols + col] = sum;
}

-template<int qk, int qr, dequantize_kernel_t dq, typename dst_t>
-static void get_rows_cuda_q(
-        const void * src0_d, const int32_t * src1_d, dst_t * dst_d,
-        const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
-        const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
-        const size_t nb1, const size_t nb2, const size_t nb3,
-        cudaStream_t stream) {
+template<int qk, int qr, dequantize_kernel_t dq>
+static void get_rows_cuda(
+        const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+        const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
    const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
    const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
    const dim3 block_nums(block_num_x, ne10, ne11*ne12);

    // strides in elements
-    // const size_t s0 = nb0 / sizeof(dst_t);
-    const size_t s1 = nb1 / sizeof(dst_t);
-    const size_t s2 = nb2 / sizeof(dst_t);
-    const size_t s3 = nb3 / sizeof(dst_t);
+    // const size_t s0 = nb0 / ggml_element_size(dst);
+    const size_t s1 = nb1 / ggml_element_size(dst);
+    const size_t s2 = nb2 / ggml_element_size(dst);
+    const size_t s3 = nb3 / ggml_element_size(dst);

-    const size_t s10 = nb10 / sizeof(int32_t);
-    const size_t s11 = nb11 / sizeof(int32_t);
-    const size_t s12 = nb12 / sizeof(int32_t);
-    // const size_t s13 = nb13 / sizeof(int32_t);
+    const size_t s10 = nb10 / ggml_element_size(src1);
+    const size_t s11 = nb11 / ggml_element_size(src1);
+    const size_t s12 = nb12 / ggml_element_size(src1);
+    // const size_t s13 = nb13 / ggml_element_size(src1);

    GGML_ASSERT(ne00 % 2 == 0);

    k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
-            src0_d, src1_d, dst_d,
+            src0_dd, src1_dd, dst_dd,
            ne00, /*ne01, ne02, ne03,*/
            /*ne10, ne11,*/ ne12, /*ne13,*/
            /*s0,*/ s1, s2, s3,
            /*nb00,*/ nb01, nb02, nb03,
            s10, s11, s12/*, s13*/);
+
+    GGML_UNUSED(dst);
}

-template<typename src0_t, typename dst_t>
+template<typename src0_t>
static void get_rows_cuda_float(
-        const src0_t * src0_d, const int32_t * src1_d, dst_t * dst_d,
-        const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
-        const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
-        const size_t nb1, const size_t nb2, const size_t nb3,
-        cudaStream_t stream) {
+        const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+        const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(ne13 == 1);
+
    const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
    const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
    const dim3 block_nums(block_num_x, ne10, ne11*ne12);

    // strides in elements
-    // const size_t s0 = nb0 / sizeof(dst_t);
-    const size_t s1 = nb1 / sizeof(dst_t);
-    const size_t s2 = nb2 / sizeof(dst_t);
-    const size_t s3 = nb3 / sizeof(dst_t);
+    // const size_t s0 = nb0 / ggml_element_size(dst);
+    const size_t s1 = nb1 / ggml_element_size(dst);
+    const size_t s2 = nb2 / ggml_element_size(dst);
+    const size_t s3 = nb3 / ggml_element_size(dst);

-    const size_t s10 = nb10 / sizeof(int32_t);
-    const size_t s11 = nb11 / sizeof(int32_t);
-    const size_t s12 = nb12 / sizeof(int32_t);
-    // const size_t s13 = nb13 / sizeof(int32_t);
+    const size_t s10 = nb10 / ggml_element_size(src1);
+    const size_t s11 = nb11 / ggml_element_size(src1);
+    const size_t s12 = nb12 / ggml_element_size(src1);
+    // const size_t s13 = nb13 / ggml_element_size(src1);

    k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
-            src0_d, src1_d, dst_d,
+            src0_dd, src1_dd, dst_dd,
            ne00, /*ne01, ne02, ne03,*/
            /*ne10, ne11,*/ ne12, /*ne13,*/
            /*s0,*/ s1, s2, s3,
            /*nb00,*/ nb01, nb02, nb03,
            s10, s11, s12/*, s13*/);
+
+    GGML_UNUSED(dst);
}

-template<typename dst_t>
-static void ggml_cuda_get_rows_switch_src0_type(
-        const void * src0_d, const ggml_type src0_type, const int32_t * src1_d, dst_t * dst_d,
-        const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
-        const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
-        const size_t nb1, const size_t nb2, const size_t nb3,
-        cudaStream_t stream) {
-    switch (src0_type) {
+void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    const void * src0_d = (const void *) src0->data;
+    const int32_t * src1_d = (const int32_t *) src1->data;
+    float * dst_d = (float *) dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
+    GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
+
+    switch (src0->type) {
        case GGML_TYPE_F16:
-            get_rows_cuda_float((const half *) src0_d, src1_d, dst_d,
-                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+            get_rows_cuda_float(src0, src1, dst, (const half *) src0_d, src1_d, dst_d, stream);
            break;
        case GGML_TYPE_F32:
-            get_rows_cuda_float((const float *) src0_d, src1_d, dst_d,
-                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
-            break;
-        case GGML_TYPE_BF16:
-            get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d,
-                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+            get_rows_cuda_float(src0, src1, dst, (const float *) src0_d, src1_d, dst_d, stream);
            break;
        case GGML_TYPE_Q4_0:
-            get_rows_cuda_q<QK4_0, QR4_0, dequantize_q4_0>(src0_d, src1_d, dst_d,
-                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+            get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
            break;
        case GGML_TYPE_Q4_1:
-            get_rows_cuda_q<QK4_1, QR4_1, dequantize_q4_1>(src0_d, src1_d, dst_d,
-                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+            get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
            break;
        case GGML_TYPE_Q5_0:
-            get_rows_cuda_q<QK5_0, QR5_0, dequantize_q5_0>(src0_d, src1_d, dst_d,
-                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+            get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
            break;
        case GGML_TYPE_Q5_1:
-            get_rows_cuda_q<QK5_1, QR5_1, dequantize_q5_1>(src0_d, src1_d, dst_d,
-                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+            get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
            break;
        case GGML_TYPE_Q8_0:
-            get_rows_cuda_q<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_d, dst_d,
-                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+            get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
            break;
        default:
            // TODO: k-quants
-            GGML_ABORT("%s: unsupported src0 type: %s\n", __func__, ggml_type_name(src0_type));
+            GGML_ABORT("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
            break;
    }
}

-void get_rows_cuda(
-        const void * src0_d, ggml_type src0_type, const int32_t * src1_d, void * dst_d, ggml_type dst_type,
-        int64_t ne00, size_t nb01, size_t nb02, size_t nb03,
-        int64_t ne10, int64_t ne11, int64_t ne12, size_t nb10, size_t nb11, size_t nb12,
-        size_t nb1, size_t nb2, size_t nb3,
-        cudaStream_t stream) {
-    switch (dst_type) {
-        case GGML_TYPE_F32:
-            ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d,
-                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
-            break;
-        case GGML_TYPE_F16:
-            ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d,
-                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
-            break;
-        case GGML_TYPE_BF16:
-            ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (nv_bfloat16 *) dst_d,
-                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
-            break;
-        default:
-            GGML_ABORT("%s: unsupported dst type: %s\n", __func__, ggml_type_name(dst_type));
-            break;
-    }
-}
-
-void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const ggml_tensor * src1 = dst->src[1];
-
-    cudaStream_t stream = ctx.stream();
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(ne13 == 1);
-
-    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
-    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
-    GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
-
-    get_rows_cuda(src0->data, src0->type, (const int32_t *) src1->data, dst->data, dst->type,
-        ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
-}
-
void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
    const ggml_tensor * src1 = dst->src[1]; // src1 in forward pass
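For context, the sketch below illustrates the operation these kernels and launchers implement: a row gather, dst[r, :] = src0[src1[r], :], with byte strides (nb*) converted to element strides (s*) by dividing by the element size, exactly as in the launchers above. It is a simplified, standalone host-side reference written for this note, not part of the ggml API or of this change; the names mirror the diff but the tensors are plain vectors.

```cpp
// Host-side reference of the get_rows gather (assumed, simplified stand-in for the CUDA path).
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int64_t ne00 = 4;                    // elements per row
    const size_t  nb01 = ne00 * sizeof(float); // byte stride between src0 rows

    std::vector<float>   src0 = {0, 1, 2, 3,  10, 11, 12, 13,  20, 21, 22, 23};
    std::vector<int32_t> src1 = {2, 0, 2};     // row ids to gather
    std::vector<float>   dst(src1.size() * ne00);

    // byte stride -> element stride, cf. nb1 / ggml_element_size(dst) in the launcher
    const size_t s1 = nb01 / sizeof(float);

    for (size_t i10 = 0; i10 < src1.size(); ++i10) {
        // select the source row through the index tensor, then copy it element by element
        const float * src0_row = (const float *)((const char *) src0.data() + src1[i10]*nb01);
        float       * dst_row  = dst.data() + i10*s1;
        for (int64_t i00 = 0; i00 < ne00; ++i00) {
            dst_row[i00] = src0_row[i00];      // dst[i10, :] = src0[src1[i10], :]
        }
    }

    for (float x : dst) {
        std::printf("%g ", x);                 // prints: 20 21 22 23 0 1 2 3 20 21 22 23
    }
    std::printf("\n");
    return 0;
}
```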