5
5
6
6
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t >
7
7
static __global__ void dequantize_block (const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
8
- const int64_t i = ( int64_t ) 2 *(blockDim .x *blockIdx .x + threadIdx .x );
8
+ const int i = 2 *(blockDim .x *blockIdx .x + threadIdx .x );
9
9
10
10
if (i >= k) {
11
11
return ;
@@ -71,9 +71,9 @@ static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t
71
71
const int64_t i = blockIdx .x ;
72
72
73
73
// assume 32 threads
74
- const int64_t tid = threadIdx .x ;
75
- const int64_t il = tid/8 ;
76
- const int64_t ir = tid%8 ;
74
+ const int tid = threadIdx .x ;
75
+ const int il = tid/8 ;
76
+ const int ir = tid%8 ;
77
77
const int64_t ib = 8 *i + ir;
78
78
if (ib >= nb32) {
79
79
return ;
@@ -99,9 +99,9 @@ static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t
99
99
const int64_t i = blockIdx .x ;
100
100
101
101
// assume 32 threads
102
- const int64_t tid = threadIdx .x ;
103
- const int64_t il = tid/8 ;
104
- const int64_t ir = tid%8 ;
102
+ const int tid = threadIdx .x ;
103
+ const int il = tid/8 ;
104
+ const int ir = tid%8 ;
105
105
const int64_t ib = 8 *i + ir;
106
106
if (ib >= nb32) {
107
107
return ;
@@ -128,10 +128,10 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
128
128
const int64_t i = blockIdx .x ;
129
129
const block_q2_K * x = (const block_q2_K *) vx;
130
130
131
- const int64_t tid = threadIdx .x ;
131
+ const int tid = threadIdx .x ;
132
132
#if QK_K == 256
133
- const int64_t n = tid/32 ;
134
- const int64_t l = tid - 32 *n;
133
+ const int n = tid/32 ;
134
+ const int l = tid - 32 *n;
135
135
const int64_t is = 8 *n + l/16 ;
136
136
137
137
const uint8_t q = x[i].qs [32 *n + l];
@@ -159,16 +159,16 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
159
159
template <typename dst_t >
160
160
static __global__ void dequantize_block_q3_K (const void * __restrict__ vx, dst_t * __restrict__ yy) {
161
161
162
- const int64_t i = blockIdx .x ;
162
+ const int i = blockIdx .x ;
163
163
const block_q3_K * x = (const block_q3_K *) vx;
164
164
165
165
#if QK_K == 256
166
- const int64_t r = threadIdx .x /4 ;
167
- const int64_t tid = r/2 ;
168
- const int64_t is0 = r%2 ;
169
- const int64_t l0 = 16 *is0 + 4 *(threadIdx .x %4 );
170
- const int64_t n = tid / 4 ;
171
- const int64_t j = tid - 4 *n;
166
+ const int r = threadIdx .x /4 ;
167
+ const int tid = r/2 ;
168
+ const int is0 = r%2 ;
169
+ const int l0 = 16 *is0 + 4 *(threadIdx .x %4 );
170
+ const int n = tid / 4 ;
171
+ const int j = tid - 4 *n;
172
172
173
173
uint8_t m = 1 << (4 *n + j);
174
174
int64_t is = 8 *n + 2 *j + is0;
@@ -187,11 +187,11 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t
187
187
188
188
for (int l = l0; l < l0+4 ; ++l) y[l] = dl * ((int8_t )((q[l] >> shift) & 3 ) - ((hm[l] & m) ? 0 : 4 ));
189
189
#else
190
- const int64_t tid = threadIdx .x ;
191
- const int64_t is = tid/16 ; // 0 or 1
192
- const int64_t il = tid%16 ; // 0...15
193
- const int64_t im = il/8 ; // 0...1
194
- const int64_t in = il%8 ; // 0...7
190
+ const int tid = threadIdx .x ;
191
+ const int is = tid/16 ; // 0 or 1
192
+ const int il = tid%16 ; // 0...15
193
+ const int im = il/8 ; // 0...1
194
+ const int in = il%8 ; // 0...7
195
195
196
196
dst_t * y = yy + i*QK_K + 16 *is + il;
197
197
@@ -229,11 +229,11 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t
229
229
230
230
#if QK_K == 256
231
231
// assume 32 threads
232
- const int64_t tid = threadIdx .x ;
233
- const int64_t il = tid/8 ;
234
- const int64_t ir = tid%8 ;
235
- const int64_t is = 2 *il;
236
- const int64_t n = 4 ;
232
+ const int tid = threadIdx .x ;
233
+ const int il = tid/8 ;
234
+ const int ir = tid%8 ;
235
+ const int is = 2 *il;
236
+ const int n = 4 ;
237
237
238
238
dst_t * y = yy + i*QK_K + 64 *il + n*ir;
239
239
@@ -252,7 +252,7 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t
252
252
y[l +32 ] = d2 * (q[l] >> 4 ) - m2;
253
253
}
254
254
#else
255
- const int64_t tid = threadIdx .x ;
255
+ const int tid = threadIdx .x ;
256
256
const uint8_t * q = x[i].qs ;
257
257
dst_t * y = yy + i*QK_K;
258
258
const float d = (float )x[i].dm [0 ];
@@ -270,10 +270,10 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t
270
270
271
271
#if QK_K == 256
272
272
// assume 64 threads - this is very slightly better than the one below
273
- const int64_t tid = threadIdx .x ;
274
- const int64_t il = tid/16 ; // il is in 0...3
275
- const int64_t ir = tid%16 ; // ir is in 0...15
276
- const int64_t is = 2 *il; // is is in 0...6
273
+ const int tid = threadIdx .x ;
274
+ const int il = tid/16 ; // il is in 0...3
275
+ const int ir = tid%16 ; // ir is in 0...15
276
+ const int is = 2 *il; // is is in 0...6
277
277
278
278
dst_t * y = yy + i*QK_K + 64 *il + 2 *ir;
279
279
@@ -296,11 +296,11 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t
296
296
y[32 ] = d2 * ((ql[ 0 ] >> 4 ) + (qh[ 0 ] & hm ? 16 : 0 )) - m2;
297
297
y[33 ] = d2 * ((ql[ 1 ] >> 4 ) + (qh[ 1 ] & hm ? 16 : 0 )) - m2;
298
298
#else
299
- const int64_t tid = threadIdx .x ;
299
+ const int tid = threadIdx .x ;
300
300
const uint8_t q = x[i].qs [tid];
301
- const int64_t im = tid/8 ; // 0...3
302
- const int64_t in = tid%8 ; // 0...7
303
- const int64_t is = tid/16 ; // 0 or 1
301
+ const int im = tid/8 ; // 0...3
302
+ const int in = tid%8 ; // 0...7
303
+ const int is = tid/16 ; // 0 or 1
304
304
const uint8_t h = x[i].qh [in] >> im;
305
305
const float d = x[i].d ;
306
306
dst_t * y = yy + i*QK_K + tid;
@@ -317,10 +317,10 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
317
317
#if QK_K == 256
318
318
319
319
// assume 64 threads - this is very slightly better than the one below
320
- const int64_t tid = threadIdx .x ;
321
- const int64_t ip = tid/32 ; // ip is 0 or 1
322
- const int64_t il = tid - 32 *ip; // 0...32
323
- const int64_t is = 8 *ip + il/16 ;
320
+ const int tid = threadIdx .x ;
321
+ const int ip = tid/32 ; // ip is 0 or 1
322
+ const int il = tid - 32 *ip; // 0...32
323
+ const int is = 8 *ip + il/16 ;
324
324
325
325
dst_t * y = yy + i*QK_K + 128 *ip + il;
326
326
@@ -337,9 +337,9 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
337
337
#else
338
338
339
339
// assume 32 threads
340
- const int64_t tid = threadIdx .x ;
341
- const int64_t ip = tid/16 ; // 0 or 1
342
- const int64_t il = tid - 16 *ip; // 0...15
340
+ const int tid = threadIdx .x ;
341
+ const int ip = tid/16 ; // 0 or 1
342
+ const int il = tid - 16 *ip; // 0...15
343
343
344
344
dst_t * y = yy + i*QK_K + 16 *ip + il;
345
345
@@ -360,10 +360,10 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds
360
360
const int64_t i = blockIdx .x ;
361
361
const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
362
362
363
- const int64_t tid = threadIdx .x ;
363
+ const int tid = threadIdx .x ;
364
364
#if QK_K == 256
365
- const int64_t il = tid/8 ; // 0...3
366
- const int64_t ib = tid%8 ; // 0...7
365
+ const int il = tid/8 ; // 0...3
366
+ const int ib = tid%8 ; // 0...7
367
367
dst_t * y = yy + i*QK_K + 32 *ib + 8 *il;
368
368
const uint16_t * q2 = x[i].qs + 4 *ib;
369
369
const uint8_t * aux8 = (const uint8_t *)q2;
@@ -384,10 +384,10 @@ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst
384
384
const int64_t i = blockIdx .x ;
385
385
const block_iq2_xs * x = (const block_iq2_xs *) vx;
386
386
387
- const int64_t tid = threadIdx .x ;
387
+ const int tid = threadIdx .x ;
388
388
#if QK_K == 256
389
- const int64_t il = tid/8 ; // 0...3
390
- const int64_t ib = tid%8 ; // 0...7
389
+ const int il = tid/8 ; // 0...3
390
+ const int ib = tid%8 ; // 0...7
391
391
dst_t * y = yy + i*QK_K + 32 *ib + 8 *il;
392
392
const uint16_t * q2 = x[i].qs + 4 *ib;
393
393
const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511 ));
@@ -406,10 +406,10 @@ static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_
406
406
const int64_t i = blockIdx .x ;
407
407
const block_iq2_s * x = (const block_iq2_s *) vx;
408
408
409
- const int64_t tid = threadIdx .x ;
409
+ const int tid = threadIdx .x ;
410
410
#if QK_K == 256
411
- const int64_t il = tid/8 ; // 0...3
412
- const int64_t ib = tid%8 ; // 0...7
411
+ const int il = tid/8 ; // 0...3
412
+ const int ib = tid%8 ; // 0...7
413
413
dst_t * y = yy + i*QK_K + 32 *ib + 8 *il;
414
414
const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs [4 *ib+il] | ((x[i].qh [ib] << (8 -2 *il)) & 0x300 )));
415
415
const float d = (float )x[i].d * (0 .5f + ((x[i].scales [ib] >> 4 *(il/2 )) & 0xf )) * 0 .25f ;
@@ -427,10 +427,10 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
427
427
const int64_t i = blockIdx .x ;
428
428
const block_iq3_xxs * x = (const block_iq3_xxs *) vx;
429
429
430
- const int64_t tid = threadIdx .x ;
430
+ const int tid = threadIdx .x ;
431
431
#if QK_K == 256
432
- const int64_t il = tid/8 ; // 0...3
433
- const int64_t ib = tid%8 ; // 0...7
432
+ const int il = tid/8 ; // 0...3
433
+ const int ib = tid%8 ; // 0...7
434
434
dst_t * y = yy + i*QK_K + 32 *ib + 8 *il;
435
435
const uint8_t * q3 = x[i].qs + 8 *ib;
436
436
const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4 ) + 2 *ib;
@@ -455,10 +455,10 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_
455
455
const int64_t i = blockIdx .x ;
456
456
const block_iq3_s * x = (const block_iq3_s *) vx;
457
457
458
- const int64_t tid = threadIdx .x ;
458
+ const int tid = threadIdx .x ;
459
459
#if QK_K == 256
460
- const int64_t il = tid/8 ; // 0...3
461
- const int64_t ib = tid%8 ; // 0...7
460
+ const int il = tid/8 ; // 0...3
461
+ const int ib = tid%8 ; // 0...7
462
462
dst_t * y = yy + i*QK_K + 32 *ib + 8 *il;
463
463
const uint8_t * qs = x[i].qs + 8 *ib;
464
464
const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2 *il+0 ] | ((x[i].qh [ib] << (8 -2 *il)) & 256 )));
@@ -481,10 +481,10 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_
481
481
const int64_t i = blockIdx .x ;
482
482
const block_iq1_s * x = (const block_iq1_s *) vx;
483
483
484
- const int64_t tid = threadIdx .x ;
484
+ const int tid = threadIdx .x ;
485
485
#if QK_K == 256
486
- const int64_t il = tid/8 ; // 0...3
487
- const int64_t ib = tid%8 ; // 0...7
486
+ const int il = tid/8 ; // 0...3
487
+ const int ib = tid%8 ; // 0...7
488
488
dst_t * y = yy + i*QK_K + 32 *ib + 8 *il;
489
489
const float delta = x[i].qh [ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
490
490
const float d = (float )x[i].d * (2 *((x[i].qh [ib] >> 12 ) & 7 ) + 1 );
@@ -507,10 +507,10 @@ static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_
507
507
const int64_t i = blockIdx .x ;
508
508
const block_iq1_m * x = (const block_iq1_m *) vx;
509
509
510
- const int64_t tid = threadIdx .x ;
510
+ const int tid = threadIdx .x ;
511
511
#if QK_K == 256
512
- const int64_t il = tid/8 ; // 0...3
513
- const int64_t ib = tid%8 ; // 0...7
512
+ const int il = tid/8 ; // 0...3
513
+ const int ib = tid%8 ; // 0...7
514
514
dst_t * y = yy + i*QK_K + 32 *ib + 8 *il;
515
515
const uint16_t * sc = (const uint16_t *)x[i].scales ;
516
516
iq1m_scale_t scale;
@@ -538,9 +538,9 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst
538
538
const int64_t i = blockIdx .x ;
539
539
const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
540
540
541
- const int64_t tid = threadIdx .x ;
542
- const int64_t il = tid/8 ; // 0...3
543
- const int64_t ib = tid%8 ; // 0...7
541
+ const int tid = threadIdx .x ;
542
+ const int il = tid/8 ; // 0...3
543
+ const int ib = tid%8 ; // 0...7
544
544
dst_t * y = yy + i*QK_K + 32 *ib + 4 *il;
545
545
const uint8_t * q4 = x[ib].qs + 4 *il;
546
546
const float d = (float )x[ib].d ;
@@ -557,9 +557,9 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
557
557
const int64_t i = blockIdx .x ;
558
558
const block_iq4_xs * x = (const block_iq4_xs *)vx;
559
559
560
- const int64_t tid = threadIdx .x ;
561
- const int64_t il = tid/8 ; // 0...3
562
- const int64_t ib = tid%8 ; // 0...7
560
+ const int tid = threadIdx .x ;
561
+ const int il = tid/8 ; // 0...3
562
+ const int ib = tid%8 ; // 0...7
563
563
dst_t * y = yy + i*QK_K + 32 *ib + 4 *il;
564
564
const uint8_t * q4 = x[i].qs + 16 *ib + 4 *il;
565
565
const float d = (float )x[i].d * ((((x[i].scales_l [ib/2 ] >> 4 *(ib%2 )) & 0xf ) | (((x[i].scales_h >> 2 *ib) & 3 ) << 4 )) - 32 );
@@ -707,7 +707,7 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t
707
707
708
708
template <typename src_t , typename dst_t >
709
709
static __global__ void convert_unary (const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
710
- const int64_t i = ( int64_t ) blockDim .x *blockIdx .x + threadIdx .x ;
710
+ const int i = blockDim .x *blockIdx .x + threadIdx .x ;
711
711
712
712
if (i >= k) {
713
713
return ;
0 commit comments