Skip to content

Commit 8c2b881

Browse files
committed
cuda : poc for norm quants (only -b 1 works)
1 parent df54d2f commit 8c2b881

File tree

1 file changed

+32
-20
lines changed

1 file changed

+32
-20
lines changed

ggml-cuda.cu

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -163,21 +163,31 @@ typedef float2 dfloat2;
163163
#endif //GGML_CUDA_F16
164164

165165
static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) {
166-
const uint16_t * x16 = (uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
166+
//const uint16_t * x16 = (uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
167+
x8 += sizeof(int) * i32;
167168

168169
int x32 = 0;
169-
x32 |= x16[0] << 0;
170-
x32 |= x16[1] << 16;
170+
//x32 |= x16[0] << 0;
171+
//x32 |= x16[1] << 16;
172+
x32 |= ((uint32_t)(x8[0])) << 0;
173+
x32 |= ((uint32_t)(x8[1])) << 8;
174+
x32 |= ((uint32_t)(x8[2])) << 16;
175+
x32 |= ((uint32_t)(x8[3])) << 24;
171176

172177
return x32;
173178
}
174179

175180
static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, const int & i32) {
176-
const uint16_t * x16 = (uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
181+
//const uint16_t * x16 = (uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
182+
x8 += sizeof(int) * i32;
177183

178184
int x32 = 0;
179-
x32 |= x16[0] << 0;
180-
x32 |= x16[1] << 16;
185+
//x32 |= x16[0] << 0;
186+
//x32 |= x16[1] << 16;
187+
x32 |= ((uint32_t)(x8[0])) << 0;
188+
x32 |= ((uint32_t)(x8[1])) << 8;
189+
x32 |= ((uint32_t)(x8[2])) << 16;
190+
x32 |= ((uint32_t)(x8[3])) << 24;
181191

182192
return x32;
183193
}
@@ -2093,7 +2103,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
20932103
const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx;
20942104

20952105
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
2096-
// x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d;
2106+
//x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = Q4_0D(bxi->d);
20972107
}
20982108

20992109
const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
@@ -2109,7 +2119,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
21092119

21102120
const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd;
21112121

2112-
x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d;
2122+
x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = Q4_0D(bxi->d);
21132123
}
21142124
}
21152125

@@ -2143,15 +2153,15 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
21432153

21442154
#pragma unroll
21452155
for (int i = 0; i < VDR_Q4_1_Q8_1_MMVQ; ++i) {
2146-
v[i] = get_int_from_uint8_aligned(bq4_1->qs, iqs + i);
2156+
v[i] = get_int_from_uint8(bq4_1->qs, iqs + i);
21472157
u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
21482158
u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_1);
21492159
}
21502160

21512161
const float d = Q4_1D(bq4_1->dm);
21522162
const float m = Q4_1M(bq4_1->dm);
21532163

2154-
const float2 dm = {d, m};
2164+
const half2 dm = {d, m};
21552165

21562166
return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMVQ>(v, u, dm, bq8_1->ds);
21572167
}
@@ -2189,7 +2199,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
21892199

21902200
const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx;
21912201

2192-
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
2202+
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
21932203
}
21942204

21952205
const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
@@ -2205,7 +2215,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
22052215

22062216
const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd;
22072217

2208-
x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm;
2218+
x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd].x = Q4_1D(bxi->dm);
2219+
x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd].y = Q4_1M(bxi->dm);
22092220
}
22102221
}
22112222

@@ -2353,16 +2364,16 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
23532364

23542365
#pragma unroll
23552366
for (int i = 0; i < VDR_Q5_1_Q8_1_MMVQ; ++i) {
2356-
vl[i] = get_int_from_uint8_aligned(bq5_1->qs, iqs + i);
2357-
vh[i] = get_int_from_uint8_aligned(bq5_1->qh, 0) >> (4 * (iqs + i));
2367+
vl[i] = get_int_from_uint8(bq5_1->qs, iqs + i);
2368+
vh[i] = get_int_from_uint8(bq5_1->qh, 0) >> (4 * (iqs + i));
23582369
u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
23592370
u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_1);
23602371
}
23612372

2362-
const float d = Q5_1D(bq4_1->dm);
2363-
const float m = Q5_1M(bq4_1->dm);
2373+
const half d = Q5_1D(bq5_1->dm);
2374+
const half m = Q5_1M(bq5_1->dm);
23642375

2365-
const float2 dm = {d, m};
2376+
const half2 dm = {d, m};
23662377

23672378
return vec_dot_q5_1_q8_1_impl<VDR_Q5_1_Q8_1_MMVQ>(vl, vh, u, dm, bq8_1->ds);
23682379
}
@@ -2400,8 +2411,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
24002411

24012412
const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbx;
24022413

2403-
const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
2404-
const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (k % QI5_1));
2414+
const int ql = get_int_from_uint8(bxi->qs, kqsx);
2415+
const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (k % QI5_1));
24052416

24062417
int qs0 = (ql >> 0) & 0x0F0F0F0F;
24072418
qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
@@ -2433,7 +2444,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
24332444

24342445
const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbxd;
24352446

2436-
x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm;
2447+
x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd].x = Q5_1D(bxi->dm);
2448+
x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd].y = Q5_1M(bxi->dm);
24372449
}
24382450
}
24392451

0 commit comments

Comments
 (0)