Skip to content

Commit e8457c9

Browse files
committed
cuda : wip
1 parent 6b58ae9 commit e8457c9

File tree

1 file changed

+62
-5
lines changed

1 file changed

+62
-5
lines changed

ggml-cuda.cu

Lines changed: 62 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4582,12 +4582,43 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
45824582
}
45834583
}
45844584

4585-
// TODO: generalize for all quants
4586-
template <cpy_kernel_t cpy_blck>
4585+
// Quantize one block of QK4_0 floats into a single block_q4_0.
//
//   cxi   - source bytes, reinterpreted as QK4_0 consecutive floats
//   cdsti - destination bytes, reinterpreted as one block_q4_0
//
// The scale d is chosen so that the element with the largest magnitude maps
// to -8 before the +8 offset is applied. Every element is then stored as an
// unsigned 4-bit value min(15, (int8_t)(x/d + 8.5f)): element j goes into the
// low nibble of qs[j] and element j + QK4_0/2 into the high nibble.
// An all-zero input block gives d == 0 and id == 0, so every nibble becomes 8.
//
// NOTE(review): block_q4_0 and QK4_0 are declared outside this block; this
// assumes block_q4_0 has a scale member `d` assignable from float and a byte
// array `qs` with at least QK4_0/2 entries - confirm against the declaration.
static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
    const float * xi = (const float *) cxi;
    block_q4_0 * dsti = (block_q4_0 *) cdsti;

    float amax = 0.0f; // largest absolute value seen so far
    float vmax = 0.0f; // signed value that produced amax

    for (int j = 0; j < QK4_0; ++j) {
        const float v = xi[j];
        if (amax < fabsf(v)) {
            amax = fabsf(v);
            vmax = v;
        }
    }

    const float d  = vmax / -8;
    const float id = d ? 1.0f/d : 0.0f; // guard against division by zero for an all-zero block

    // store the scale in the destination block (the only output pointer this function has)
    dsti->d = d;

    for (int j = 0; j < QK4_0/2; ++j) {
        const float x0 = xi[0       + j]*id;
        const float x1 = xi[QK4_0/2 + j]*id;

        // +8 shifts the signed range to unsigned, +0.5f rounds; clamp to the 4-bit maximum
        const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
        const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));

        dsti->qs[j]  = xi0;
        dsti->qs[j] |= xi1 << 4;
    }
}
4616+
4617+
template <cpy_kernel_t cpy_blck, int qk>
45874618
static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
45884619
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
45894620
const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) {
4590-
const int i = (blockDim.x*blockIdx.x + threadIdx.x)*QK8_0;
4621+
const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
45914622

45924623
if (i >= ne) {
45934624
return;
@@ -4600,7 +4631,7 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
46004631

46014632
const int i12 = i / (ne10*ne11);
46024633
const int i11 = (i - i12*ne10*ne11) / ne10;
4603-
const int i10 = (i - i12*ne10*ne11 - i11*ne10)/QK8_0;
4634+
const int i10 = (i - i12*ne10*ne11 - i11*ne10)/qk;
46044635
const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12;
46054636

46064637
cpy_blck(cx + x_offset, cdst + dst_offset);
@@ -5791,7 +5822,29 @@ static void ggml_cpy_f32_q8_0_cuda(
57915822

57925823
GGML_ASSERT(ne % QK8_0 == 0);
57935824
const int num_blocks = ne / QK8_0;
5794-
cpy_f32_q<cpy_blck_f32_q8_0><<<num_blocks, 1, 0, stream>>>
5825+
cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
5826+
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
5827+
}
5828+
5829+
// Host-side launcher for the F32 -> Q4_0 copy.
//
// Enqueues cpy_f32_q<cpy_blck_f32_q4_0, QK4_0> on `stream` with ne / QK4_0
// CUDA blocks of one thread each and no dynamic shared memory.
// `ne` must be a multiple of QK4_0 (asserted). All remaining arguments are
// forwarded to the kernel unchanged.
static void ggml_cpy_f32_q4_0_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {

    // the kernel is launched per whole quantization block, so partial blocks are not allowed
    GGML_ASSERT(ne % QK4_0 == 0);

    const dim3 grid_dims(ne / QK4_0, 1, 1);
    const dim3 block_dims(1, 1, 1);

    cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<grid_dims, block_dims, 0, stream>>>(
        cx, cdst, ne,
        ne00, ne01, nb00, nb01, nb02,
        ne10, ne11, nb10, nb11, nb12);
}
5839+
5840+
// Host-side launcher for the F32 -> Q4_1 copy.
//
// Enqueues cpy_f32_q<cpy_blck_f32_q4_1, QK4_1> on `stream` with ne / QK4_1
// CUDA blocks of one thread each and no dynamic shared memory.
// `ne` must be a multiple of QK4_1 (asserted). All remaining arguments are
// forwarded to the kernel unchanged.
//
// NOTE(review): cpy_blck_f32_q4_1 is not defined in the visible part of this
// change - confirm it exists elsewhere in the file before relying on this path.
static void ggml_cpy_f32_q4_1_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {

    // the kernel is launched per whole quantization block, so partial blocks are not allowed
    GGML_ASSERT(ne % QK4_1 == 0);

    const dim3 grid_dims(ne / QK4_1, 1, 1);
    const dim3 block_dims(1, 1, 1);

    cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<grid_dims, block_dims, 0, stream>>>(
        cx, cdst, ne,
        ne00, ne01, nb00, nb01, nb02,
        ne10, ne11, nb10, nb11, nb12);
}
57975850

@@ -7836,6 +7889,10 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
78367889
ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
78377890
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
78387891
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
7892+
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
7893+
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
7894+
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
7895+
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
78397896
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
78407897
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
78417898
} else {

0 commit comments

Comments
 (0)