@@ -4582,12 +4582,43 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
     }
 }
 
-// TODO: generalize for all quants
-template <cpy_kernel_t cpy_blck>
+static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_q4_0 * dsti = (block_q4_0 *) cdsti;
+
+    float amax = 0.0f;
+    float max  = 0.0f;
+
+    for (int j = 0; j < QK4_0; ++j) {
+        const float v = xi[j];
+        if (amax < fabsf(v)) {
+            amax = fabsf(v);
+            max  = v;
+        }
+    }
+
+    const float d  = max / -8;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dsti->d = d;
+
+    for (int j = 0; j < QK4_0/2; ++j) {
+        const float x0 = xi[0       + j]*id;
+        const float x1 = xi[QK4_0/2 + j]*id;
+
+        const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
+        const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));
+
+        dsti->qs[j]  = xi0;
+        dsti->qs[j] |= xi1 << 4;
+    }
+}
+
+template <cpy_kernel_t cpy_blck, int qk>
 static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
                                  const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
                                  const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) {
-    const int i = (blockDim.x*blockIdx.x + threadIdx.x)*QK8_0;
+    const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
 
     if (i >= ne) {
         return;
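
For orientation, here is the destination block layout the new kernel writes into. This is paraphrased from ggml's own `block_q4_0` definition (QK4_0 = 32, scale stored as fp16); it is shown for reference only and is not part of the diff:

```c
#include <cuda_fp16.h>  // half
#include <stdint.h>

#define QK4_0 32

typedef struct {
    half    d;            // fp16 scale: v ~= d * (q - 8)
    uint8_t qs[QK4_0/2];  // 16 bytes, two 4-bit quants per byte; low nibble = first half of block
} block_q4_0;             // 18 bytes per 32 input floats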
@@ -4600,7 +4631,7 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
 
     const int i12 = i / (ne10*ne11);
     const int i11 = (i - i12*ne10*ne11) / ne10;
-    const int i10 = (i - i12*ne10*ne11 - i11*ne10)/QK8_0;
+    const int i10 = (i - i12*ne10*ne11 - i11*ne10)/qk;
     const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12;
 
     cpy_blck(cx + x_offset, cdst + dst_offset);
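
The index arithmetic above is the one subtle part of the kernel: `i` counts source floats, but `i10` is divided by `qk` because the destination advances in whole quant blocks (`nb10` is presumably the byte stride of one block). A standalone host-side check of the mapping, using made-up sizes:

```c
#include <assert.h>

// Reproduce cpy_f32_q's flat-index -> (plane i12, row i11, quant block i10)
// mapping on the host, with illustrative sizes ne10 = 64, ne11 = 4, qk = 32.
int main(void) {
    const int ne10 = 64, ne11 = 4, qk = 32;
    const int i   = 300;                                   // sample flat element index
    const int i12 = i / (ne10*ne11);                       // plane:       300/256 = 1
    const int i11 = (i - i12*ne10*ne11) / ne10;            // row:          44/64  = 0
    const int i10 = (i - i12*ne10*ne11 - i11*ne10) / qk;   // quant block:  44/32  = 1
    assert(i12 == 1 && i11 == 0 && i10 == 1);
    return 0;
}
```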
@@ -5791,7 +5822,29 @@ static void ggml_cpy_f32_q8_0_cuda(
 
     GGML_ASSERT(ne % QK8_0 == 0);
     const int num_blocks = ne / QK8_0;
-    cpy_f32_q<cpy_blck_f32_q8_0><<<num_blocks, 1, 0, stream>>>
+    cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
+static void ggml_cpy_f32_q4_0_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+    GGML_ASSERT(ne % QK4_0 == 0);
+    const int num_blocks = ne / QK4_0;
+    cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
+static void ggml_cpy_f32_q4_1_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+    GGML_ASSERT(ne % QK4_1 == 0);
+    const int num_blocks = ne / QK4_1;
+    cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
         (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
 }
 
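For the round trip, recall why the kernel uses `d = max / -8` and the `+ 8.5f` rounding offset: codes live in [0, 15] and dequantize as v ≈ d * (q - 8), so the largest-magnitude input lands exactly on one end of the code range. A minimal host-side inverse, assuming the nibble order written by cpy_blck_f32_q4_0; `dequant_block_q4_0` is a hypothetical helper for spot-checking kernel output, not ggml's API:

```c
#include <stdint.h>

#define QK4_0 32

// Expand one Q4_0 block (scale d + 16 packed bytes) back to 32 floats,
// mirroring the nibble layout the kernel writes (low nibble first).
static void dequant_block_q4_0(float d, const uint8_t qs[QK4_0/2], float out[QK4_0]) {
    for (int j = 0; j < QK4_0/2; ++j) {
        const int q0 = (qs[j] & 0x0F) - 8;  // low nibble  -> first half of block
        const int q1 = (qs[j] >>   4) - 8;  // high nibble -> second half
        out[j]           = q0 * d;
        out[QK4_0/2 + j] = q1 * d;
    }
}
```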
@@ -7836,6 +7889,10 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
         ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
+        ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
+        ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
     } else {
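
With the two new dispatch branches, an F32 → Q4_0/Q4_1 copy built through the normal ggml graph API can land on these CUDA paths. A hedged sketch of constructing such a node (the tensor size is made up; backend setup and offloading are elided):

```c
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /* .mem_size   = */ 16*1024*1024,
        /* .mem_buffer = */ NULL,
        /* .no_alloc   = */ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // ne must be a multiple of QK4_0 (the launcher asserts this)
    struct ggml_tensor * src = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,  4096);
    struct ggml_tensor * dst = ggml_new_tensor_1d(ctx, GGML_TYPE_Q4_0, 4096);

    // ggml_cpy quantizes on the fly when dst has a quantized type
    struct ggml_tensor * node = ggml_cpy(ctx, src, dst);
    (void) node;

    ggml_free(ctx);
    return 0;
}
```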