Skip to content

Commit 89d2b3e

Browse files
Performance optimization: 2 byte aligned reads
1 parent 84e05d5 commit 89d2b3e

File tree

1 file changed

+39
-23
lines changed

1 file changed

+39
-23
lines changed

ggml-cuda.cu

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,24 @@ typedef float dfloat; // dequantize float
6060
typedef float2 dfloat2;
6161
#endif //GGML_CUDA_DMMV_F16
6262

63+
static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, const int & i32) {
64+
const uint16_t * x16 = (uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
65+
66+
int x32 = 0;
67+
x32 |= x16[0] << 0;
68+
x32 |= x16[1] << 16;
69+
70+
return x32;
71+
}
72+
73+
static __device__ __forceinline__ int get_int_from_int8_aligned(const int8_t * x8, const int & i32) {
74+
return *((int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
75+
}
76+
77+
static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * x8, const int & i32) {
78+
return *((int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
79+
}
80+
6381
typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
6482
typedef void (*to_fp32_cuda_t)(const void * __restrict__ x, float * __restrict__ y, int k, cudaStream_t stream);
6583
typedef void (*dot_kernel_k_t)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v);
@@ -1315,10 +1333,9 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
13151333

13161334
const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
13171335

1318-
int vi;
1319-
memcpy(&vi, &bq4_0->qs[sizeof(int) * (iqs + 0)], sizeof(int));
1320-
const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
1321-
const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_0)]);
1336+
const int vi = get_int_from_uint8(bq4_0->qs, iqs);
1337+
const int ui0 = get_int_from_int8_aligned(bq8_1->qs, iqs);
1338+
const int ui1 = get_int_from_int8_aligned(bq8_1->qs, iqs + QI4_0);
13221339

13231340
return vec_dot_q4_0_q8_1_impl(vi, ui0, ui1, bq4_0->d, bq8_1->ds);
13241341
}
@@ -1337,11 +1354,11 @@ static __device__ __forceinline__ void load_tiles_q4_0(
13371354
int8_t * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) {
13381355

13391356
const int kbx = k / QI4_0;
1340-
const int kqsx = sizeof(int) * (k % QI4_0);
1357+
const int kqsx = k % QI4_0;
13411358

13421359
const block_q4_0 * bx = ((block_q4_0 *) vx) + i*blocks_per_row + kbx;
13431360

1344-
memcpy(&x_ql[i * WARP_SIZE + i + k], &bx->qs[kqsx], sizeof(int));
1361+
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bx->qs, kqsx);
13451362
x_dm[i * (WARP_SIZE / QI4_0) + kbx].x = bx->d;
13461363
}
13471364

@@ -1388,9 +1405,9 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
13881405

13891406
const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
13901407

1391-
const int vi = *((int *) &bq4_1->qs[sizeof(int) * (iqs + 0)]);
1392-
const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
1393-
const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_1)]);
1408+
const int vi = get_int_from_uint8_aligned(bq4_1->qs, iqs);
1409+
const int ui0 = get_int_from_int8_aligned(bq8_1->qs, iqs);
1410+
const int ui1 = get_int_from_int8_aligned(bq8_1->qs, iqs + QI4_1);
13941411

13951412
return vec_dot_q4_1_q8_1_impl(vi, ui0, ui1, bq4_1->dm, bq8_1->ds);
13961413
}
@@ -1409,11 +1426,11 @@ static __device__ __forceinline__ void load_tiles_q4_1(
14091426
int8_t * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) {
14101427

14111428
const int kbx = k / QI4_1;
1412-
const int kqsx = sizeof(int) * (k % QI4_1);
1429+
const int kqsx = k % QI4_1;
14131430

14141431
const block_q4_1 * bx = ((block_q4_1 *) vx) + i*blocks_per_row + kbx;
14151432

1416-
x_ql[i * WARP_SIZE + i + k] = *((int *) &bx->qs[kqsx]);
1433+
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bx->qs, kqsx);
14171434
x_dm[i * (WARP_SIZE / QI4_1) + kbx] = bx->dm;
14181435
}
14191436

@@ -1433,18 +1450,18 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl(
14331450

14341451
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
14351452
int vi0 = (qs >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh0 as 5th bits
1436-
vi0 |= (qh0 << 4) & 0x00000010; // 1 -> 5
1437-
vi0 |= (qh0 << 11) & 0x00001000; // 2 -> 13
1438-
vi0 |= (qh0 << 18) & 0x00100000; // 3 -> 21
1439-
vi0 |= (qh0 << 25) & 0x10000000; // 4 -> 29
1453+
vi0 |= (qh0 << 4) & 0x00000010; // 0 -> 4
1454+
vi0 |= (qh0 << 11) & 0x00001000; // 1 -> 12
1455+
vi0 |= (qh0 << 18) & 0x00100000; // 2 -> 20
1456+
vi0 |= (qh0 << 25) & 0x10000000; // 3 -> 28
14401457
vi0 = __vsub4(vi0, 0x10101010); // subtract 16 from quantized values
14411458
int sumi = __dp4a(vi0, ui0, 0); // SIMD dot product of quantized values
14421459

14431460
int vi1 = (qs >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh1 as 5th bits
1444-
vi1 |= (qh1 << 4) & 0x00000010; // 1 -> 5
1445-
vi1 |= (qh1 << 11) & 0x00001000; // 2 -> 13
1446-
vi1 |= (qh1 << 18) & 0x00100000; // 3 -> 21
1447-
vi1 |= (qh1 << 25) & 0x10000000; // 4 -> 29
1461+
vi1 |= (qh1 << 4) & 0x00000010; // 0 -> 4
1462+
vi1 |= (qh1 << 11) & 0x00001000; // 1 -> 12
1463+
vi1 |= (qh1 << 18) & 0x00100000; // 2 -> 20
1464+
vi1 |= (qh1 << 25) & 0x10000000; // 3 -> 28
14481465
vi1 = __vsub4(vi1, 0x10101010); // subtract 16 from quantized values
14491466
sumi = __dp4a(vi1, ui1, sumi); // SIMD dot product of quantized values
14501467

@@ -1459,12 +1476,11 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
14591476

14601477
const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;
14611478

1462-
int qs;
1463-
memcpy(&qs, &bq5_0->qs[sizeof(int) * (iqs + 0)], sizeof(int));
1479+
const int qs = get_int_from_uint8(bq5_0->qs, iqs);
14641480
const int qh0 = bq5_0->qh[iqs/2 + 0] >> 4*(iqs%2);
14651481
const int qh1 = bq5_0->qh[iqs/2 + 2] >> 4*(iqs%2);
1466-
const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
1467-
const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI5_0)]);
1482+
const int ui0 = get_int_from_int8_aligned(bq8_1->qs, iqs);
1483+
const int ui1 = get_int_from_int8_aligned(bq8_1->qs, iqs + QI5_0);
14681484

14691485
return vec_dot_q5_0_q8_1_impl(qs, qh0, qh1, ui0, ui1, bq5_0->d, bq8_1->ds);
14701486
}

0 commit comments

Comments
 (0)