Skip to content

Commit 3d92acf

Browse files
Gan Feng
authored and committed
Add cpy_q_f16.
1 parent d01b3c4 commit 3d92acf

File tree

4 files changed

+392
-172
lines changed

4 files changed

+392
-172
lines changed

ggml-cuda.cu

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5554,6 +5554,40 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
55545554
cpy_1(cx + x_offset, cdst + dst_offset);
55555555
}
55565556

5557+
// Quantize one block of QK8_0 contiguous f16 values (at cxi) into one
// block_q8_0 (at cdsti): per-block absolute-max scaling to the int8 range.
static __device__ void cpy_blck_f16_q8_0(const char * cxi, char * cdsti) {
    const half * xi = (const half *) cxi;
    block_q8_0 * dsti = (block_q8_0 *) cdsti;

    half amax = 0.0f; // absolute max over the block

    for (int j = 0; j < QK8_0; j++) {
        const half v = xi[j];
        amax = __hmax(amax, __habs(v));
    }

    // scale d maps amax to 127; id is the inverse scale (0 for an all-zero block)
    const half d  = amax / (half)((1 << 7) - 1);
    const half id = d ? ((half)1.0f)/d : (half)0.0f;

    dsti->d = d;

    for (int j = 0; j < QK8_0; ++j) {
        const half x0 = xi[j]*id;

        // round to nearest (__half2int_rn), matching the roundf used by the
        // f32 quantization path; truncation (__half2int_rz) would bias the
        // quantized values toward zero and increase quantization error
        dsti->qs[j] = __half2int_rn(x0);
    }
}
5579+
5580+
// Dequantize one block_q8_0 (at cxi) into QK8_0 contiguous f16 values
// (at cdsti): each int8 weight is scaled by the block's shared scale.
static __device__ void cpy_blck_q8_0_f16(const char * cxi, char * cdsti) {
    const block_q8_0 * src = (const block_q8_0 *) cxi;
    half             * dst = (half *) cdsti;

    const half scale = src->d;

    for (int k = 0; k < QK8_0; ++k) {
        dst[k] = (half)src->qs[k] * scale;
    }
}
5590+
55575591
static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
55585592
const float * xi = (const float *) cxi;
55595593
block_q8_0 * dsti = (block_q8_0 *) cdsti;
@@ -5573,7 +5607,7 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
55735607
for (int j = 0; j < QK8_0; ++j) {
55745608
const float x0 = xi[j]*id;
55755609

5576-
dsti->qs[j] = roundf(x0);
5610+
dsti->qs[j] = __half2int_rz(x0);
55775611
}
55785612
}
55795613

@@ -5667,6 +5701,32 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
56675701
cpy_blck(cx + x_offset, cdst + dst_offset);
56685702
}
56695703

5704+
// Copy/dequantize a quantized tensor: each (block, thread) handles one
// quantized block of qk elements via cpy_blck. i is the flat index of the
// block's first element; threads past the end exit on the bounds check.
// nb00 is the byte stride of one quantized block (hence i00/qk); the
// destination strides nb10..nb13 are per-element.
template <cpy_kernel_t cpy_blck, int qk>
static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
    const int nb12, const int nb13) {
    const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;

    if (i >= ne) {
        return;
    }

    // decompose the flat element index into source coordinates
    const int i03 = i / (ne00*ne01*ne02);
    int rem = i - i03*ne00*ne01*ne02;
    const int i02 = rem / (ne00*ne01);
    rem -= i02*ne00*ne01;
    const int i01 = rem / ne00;
    const int i00 = rem - i01*ne00;
    const int x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03*nb03;

    // decompose the same flat index into destination coordinates
    const int i13 = i / (ne10*ne11*ne12);
    int remd = i - i13*ne10*ne11*ne12;
    const int i12 = remd / (ne10*ne11);
    remd -= i12*ne10*ne11;
    const int i11 = remd / ne10;
    const int i10 = remd - i11*ne10;
    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;

    cpy_blck(cx + x_offset, cdst + dst_offset);
}
5729+
56705730
static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
56715731
const float y = (i0 / 2 - low) / max(0.001f, high - low);
56725732
return 1.0f - min(1.0f, max(0.0f, y));
@@ -7382,7 +7442,26 @@ static void ggml_cpy_f16_f16_cuda(
73827442
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
73837443
}
73847444

7445+
// Host launcher: quantize an f16 tensor into q8_0 on `stream`.
// One CUDA block (single thread) handles one q8_0 block of QK8_0 elements,
// so ne must be a multiple of QK8_0.
static void ggml_cpy_f16_q8_0_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {

    GGML_ASSERT(ne % QK8_0 == 0);

    const int n_blocks = ne / QK8_0;
    cpy_f32_q<cpy_blck_f16_q8_0, QK8_0><<<n_blocks, 1, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
73857455

7456+
// Host launcher: dequantize a q8_0 tensor into f16 on `stream`.
// cpy_q_f32 consumes one q8_0 block (QK8_0 elements) per CUDA block
// (i = blockIdx.x*qk with one thread per block), so ne/QK8_0 blocks cover
// the tensor. The previous launch of `ne` blocks was functionally correct
// (excess blocks exit on the bounds check) but wasted 31/32 of the grid.
// Asserts ne % QK8_0 == 0, consistent with ggml_cpy_f16_q8_0_cuda.
static void ggml_cpy_q8_0_f16_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {

    GGML_ASSERT(ne % QK8_0 == 0);

    const int num_blocks = ne / QK8_0;
    cpy_q_f32<cpy_blck_q8_0_f16, QK8_0><<<num_blocks, 1, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
73867465

73877466
static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
73887467
const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
@@ -10373,6 +10452,10 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
1037310452
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
1037410453
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
1037510454
ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
10455+
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_Q8_0) {
10456+
ggml_cpy_f16_q8_0_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
10457+
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F16) {
10458+
ggml_cpy_q8_0_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
1037610459
} else {
1037710460
fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
1037810461
ggml_type_name(src0->type), ggml_type_name(src1->type));

ggml-metal.m

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,8 @@
177177
//GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
178178
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
179179
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
180+
GGML_METAL_KERNEL_TYPE_CPY_F16_Q8_0,
181+
GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16,
180182
GGML_METAL_KERNEL_TYPE_CONCAT,
181183
GGML_METAL_KERNEL_TYPE_SQR,
182184
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
@@ -602,6 +604,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
602604
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
603605
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
604606
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
607+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_Q8_0, cpy_f16_q8_0, true);
608+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16, cpy_q8_0_f16, true);
605609
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
606610
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
607611
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
@@ -747,6 +751,14 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
747751
switch (op->type) {
748752
case GGML_TYPE_F16:
749753
case GGML_TYPE_F32:
754+
case GGML_TYPE_Q8_0:
755+
return true;
756+
default:
757+
return false;
758+
}
759+
case GGML_TYPE_Q8_0:
760+
switch (op->type) {
761+
case GGML_TYPE_F16:
750762
return true;
751763
default:
752764
return false;
@@ -2446,6 +2458,14 @@ static enum ggml_status ggml_metal_graph_compute(
24462458
switch (dstt) {
24472459
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
24482460
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
2461+
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_Q8_0].pipeline; break;
2462+
default: GGML_ASSERT(false && "not implemented");
2463+
};
2464+
} break;
2465+
case GGML_TYPE_Q8_0:
2466+
{
2467+
switch (dstt) {
2468+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16].pipeline; break;
24492469
default: GGML_ASSERT(false && "not implemented");
24502470
};
24512471
} break;

ggml-metal.metal

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2031,6 +2031,110 @@ kernel void kernel_leaky_relu_f32(
20312031
dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
20322032
}
20332033

2034+
// Dequantize q8_0 -> f16. One threadgroup per source row (i01, i02, i03);
// the threads of the group stride across the row.
// NOTE(review): the inner loop reads the source at i00*nb00 (block stride)
// and writes the destination at i00*QK8_0 + j, i.e. i00 is treated as a
// q8_0 *block* index while the loop bound is ne00 — confirm that ne00 is
// passed in units of blocks for this kernel, otherwise the loop over-runs.
kernel void kernel_cpy_q8_0_f16(
        device const void * src0,
        device       half * dst,
        constant  int64_t & ne00,
        constant  int64_t & ne01,
        constant  int64_t & ne02,
        constant  int64_t & ne03,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant uint64_t & nb03,
        constant  int64_t & ne0,
        constant  int64_t & ne1,
        constant  int64_t & ne2,
        constant  int64_t & ne3,
        constant uint64_t & nb0,
        constant uint64_t & nb1,
        constant uint64_t & nb2,
        constant uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]) {
    const int64_t i03 = tgpig[2];
    const int64_t i02 = tgpig[1];
    const int64_t i01 = tgpig[0];

    // flat index of the first element of this row in the source
    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

    // re-decompose the flat index into destination coordinates
    const int64_t i3 = n/(ne2*ne1*ne0);
    const int64_t i2 = (n - i3*ne2*ne1*ne0)/(ne1*ne0);
    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0)/ne0;
    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)*QK8_0;

    device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
        device const block_q8_0 * xb = (device block_q8_0 *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);

        const half d = xb->d;

        for (int j = 0; j < QK8_0; ++j) {
            dst_data[i00*QK8_0 + j] = xb->qs[j] * d;
        }
    }
}
2079+
2080+
// Quantize f16 -> q8_0. One threadgroup per source row (i01, i02, i03);
// each thread quantizes whole QK8_0-element chunks of the row, striding by
// ntg.x*QK8_0 elements. Per-chunk absolute-max scaling to the int8 range,
// rounded to nearest.
kernel void kernel_cpy_f16_q8_0(
        device const half * src0,
        device       void * dst,
        constant  int64_t & ne00,
        constant  int64_t & ne01,
        constant  int64_t & ne02,
        constant  int64_t & ne03,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant uint64_t & nb03,
        constant  int64_t & ne0,
        constant  int64_t & ne1,
        constant  int64_t & ne2,
        constant  int64_t & ne3,
        constant uint64_t & nb0,
        constant uint64_t & nb1,
        constant uint64_t & nb2,
        constant uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]) {
    const int64_t i03 = tgpig[2];
    const int64_t i02 = tgpig[1];
    const int64_t i01 = tgpig[0];

    // flat index of the first element of this row in the source
    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

    // re-decompose into destination coordinates; i0 is in q8_0 blocks
    const int64_t i3 = n/(ne2*ne1*ne0);
    const int64_t i2 = (n - i3*ne2*ne1*ne0)/(ne1*ne0);
    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0)/ne0;
    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0;

    device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

    for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) {
        device half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);

        // absolute max over this chunk of QK8_0 values
        half amax = (half) 0.0f;
        for (int j = 0; j < QK8_0; j++) {
            const half v = src[j];
            amax = MAX(amax, fabs(v));
        }

        // scale maps amax to 127; id is the inverse scale (0 for an all-zero chunk)
        const half d  = amax / ((1 << 7) - 1);
        const half id = d ? 1.0f/d : 0.0f;

        dst_data[i00/QK8_0].d = d;

        for (int j = 0; j < QK8_0; ++j) {
            const half x0 = src[j]*id;
            dst_data[i00/QK8_0].qs[j] = round(x0);
        }
    }
}
2137+
20342138
kernel void kernel_cpy_f16_f16(
20352139
device const half * src0,
20362140
device half * dst,

0 commit comments

Comments (0)