Skip to content

Commit 970b5ab

Browse files
committed
ggml-cuda : add TQ2_0 support
1 parent 5cd85b5 commit 970b5ab

File tree

11 files changed

+241
-2
lines changed

11 files changed

+241
-2
lines changed

ggml/src/ggml-common.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,9 @@ typedef sycl::half2 ggml_half2;
126126
#define QI6_K (QK_K / (4*QR6_K))
127127
#define QR6_K 2
128128

129+
#define QI2_0 (QK_K / (4*QR2_0))
130+
#define QR2_0 4
131+
129132
#define QI2_XXS (QK_K / (4*QR2_XXS))
130133
#define QR2_XXS 4
131134

ggml/src/ggml-cuda/common.cuh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_Q6_K> {
440440
static constexpr int qi = QI6_K;
441441
};
442442

443+
// CUDA type traits for TQ2_0 (ternary weights stored as 2-bit quants):
//   qk — number of weights per block (QK_K super-block size)
//   qr — QR2_0 = 4: each quant byte packs four 2-bit values
//        (unpacked with shifts of 0/2/4/6 in dequantize_block_tq2_0)
//   qi — QI2_0 = QK_K/(4*QR2_0): 32-bit ints of quant data per block
template<>
struct ggml_cuda_type_traits<GGML_TYPE_TQ2_0> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR2_0;
    static constexpr int qi = QI2_0;
};
449+
443450
template<>
444451
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XXS> {
445452
static constexpr int qk = QK_K;

ggml/src/ggml-cuda/convert.cu

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,26 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
277277
y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
278278
}
279279

280+
// Dequantize one TQ2_0 block (QK_K weights) per CUDA block.
// Launch configuration: 64 threads per block (see dequantize_row_tq2_0_cuda);
// each thread reads one packed byte and writes the 4 values it contains,
// strided 32 apart so consecutive threads write consecutive outputs.
template<typename dst_t>
static __global__ void dequantize_block_tq2_0(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int64_t i = blockIdx.x;
    const block_tq2_0 * x = (const block_tq2_0 *) vx;

    const int64_t tid = threadIdx.x; // 0..63 (64 threads per block)
    const int64_t n = tid/32; // 0 or 1: which half of the 64 qs bytes
    const int64_t l = tid - 32*n; // 0..31: byte index within the half

    const uint8_t q = x[i].qs[32*n + l];
    // each half of the qs bytes expands to 128 output values
    dst_t * y = yy + i*QK_K + 128*n;

    float d = __half2float(x[i].d);
    // each byte holds four 2-bit values in {0,1,2,3}; subtracting d maps
    // the stored q+1 encoding back to {-d, 0, +d} (ternary weights)
    y[l+ 0] = d * ((q >> 0) & 3) - d;
    y[l+32] = d * ((q >> 2) & 3) - d;
    y[l+64] = d * ((q >> 4) & 3) - d;
    y[l+96] = d * ((q >> 6) & 3) - d;
}
299+
280300
template<typename dst_t>
281301
static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
282302

@@ -515,6 +535,12 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t k
515535
dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
516536
}
517537

538+
// Dequantize k TQ2_0-packed values from vx into y on the given stream.
// k is expected to be a multiple of QK_K; one 64-thread CUDA block
// handles each quant block.
template<typename dst_t>
static void dequantize_row_tq2_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int num_blocks = k / QK_K;
    dequantize_block_tq2_0<<<num_blocks, 64, 0, stream>>>(vx, y);
}
543+
518544
template<typename dst_t>
519545
static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
520546
const int nb = k / QK_K;
@@ -613,6 +639,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
613639
return dequantize_row_q5_K_cuda;
614640
case GGML_TYPE_Q6_K:
615641
return dequantize_row_q6_K_cuda;
642+
case GGML_TYPE_TQ2_0:
643+
return dequantize_row_tq2_0_cuda;
616644
case GGML_TYPE_IQ2_XXS:
617645
return dequantize_row_iq2_xxs_cuda;
618646
case GGML_TYPE_IQ2_XS:
@@ -660,6 +688,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
660688
return dequantize_row_q5_K_cuda;
661689
case GGML_TYPE_Q6_K:
662690
return dequantize_row_q6_K_cuda;
691+
case GGML_TYPE_TQ2_0:
692+
return dequantize_row_tq2_0_cuda;
663693
case GGML_TYPE_IQ2_XXS:
664694
return dequantize_row_iq2_xxs_cuda;
665695
case GGML_TYPE_IQ2_XS:

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2860,6 +2860,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
28602860
case GGML_TYPE_Q5_K:
28612861
case GGML_TYPE_Q6_K:
28622862
case GGML_TYPE_Q8_K:
2863+
case GGML_TYPE_TQ2_0:
28632864
case GGML_TYPE_IQ1_M:
28642865
case GGML_TYPE_IQ1_S:
28652866
case GGML_TYPE_IQ2_S:

ggml/src/ggml-cuda/mmq.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ void ggml_cuda_op_mul_mat_q(
6161
case GGML_TYPE_Q6_K:
6262
mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream);
6363
break;
64+
case GGML_TYPE_TQ2_0:
65+
mul_mat_q_case<GGML_TYPE_TQ2_0>(ctx, args, stream);
66+
break;
6467
case GGML_TYPE_IQ2_XXS:
6568
mul_mat_q_case<GGML_TYPE_IQ2_XXS>(ctx, args, stream);
6669
break;
@@ -113,6 +116,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
113116
case GGML_TYPE_Q4_K:
114117
case GGML_TYPE_Q5_K:
115118
case GGML_TYPE_Q6_K:
119+
case GGML_TYPE_TQ2_0:
116120
case GGML_TYPE_IQ2_XXS:
117121
case GGML_TYPE_IQ2_XS:
118122
case GGML_TYPE_IQ2_S:

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
6363
case GGML_TYPE_Q5_K:
6464
return MMQ_Q8_1_DS_LAYOUT_DS4;
6565
case GGML_TYPE_Q6_K:
66+
case GGML_TYPE_TQ2_0:
6667
case GGML_TYPE_IQ2_XXS:
6768
case GGML_TYPE_IQ2_XS:
6869
case GGML_TYPE_IQ2_S:
@@ -139,6 +140,9 @@ static constexpr __device__ int get_mmq_y_device() {
139140
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
140141
}
141142

143+
// tile_x_sizes{qs, dm, sc}
144+
145+
// TODO: TQ2_0 to minimize shared mem
142146
#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
143147
#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
144148
#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0}
@@ -161,6 +165,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
161165
type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
162166
type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
163167
type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
168+
type == GGML_TYPE_TQ2_0 ? MMQ_DP4A_TXS_Q8_0 :
164169
type == GGML_TYPE_IQ2_XXS ? MMQ_DP4A_TXS_Q8_0 :
165170
type == GGML_TYPE_IQ2_XS ? MMQ_DP4A_TXS_Q8_0_16 :
166171
type == GGML_TYPE_IQ2_S ? MMQ_DP4A_TXS_Q8_0_16 :
@@ -195,6 +200,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
195200
type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q8_1 :
196201
type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q8_1 :
197202
type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
203+
type == GGML_TYPE_TQ2_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
198204
type == GGML_TYPE_IQ2_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
199205
type == GGML_TYPE_IQ2_XS ? MMQ_MMA_TILE_X_K_Q3_K :
200206
type == GGML_TYPE_IQ2_S ? MMQ_MMA_TILE_X_K_Q3_K :
@@ -1808,6 +1814,103 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
18081814
#endif // INT8_MMA_AVAILABLE
18091815
}
18101816

1817+
// This is the first "simple" type with a block size of 256
//
// Loads a tile of TQ2_0 quants and their per-block scales into shared memory
// for mul_mat_q. Two layouts:
//   - INT8_MMA_AVAILABLE: quants are unpacked to signed int8 ({-1,0,1} via
//     __vsub4) at load time, so the MMA vec-dot can reuse the q8_0 path.
//   - dp4a fallback: the packed 2-bit quants are stored as-is and unpacked
//     later in vec_dot_tq2_0_q8_1_dp4a.
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_tq2_0(
    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

#ifdef INT8_MMA_AVAILABLE
    int   * x_qs = (int   *) x_tile;
    float * x_df = (float *) (x_tile + 2*WARP_SIZE);
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_TQ2_0, mmq_y);
    int   * x_qs = (int   *) x_tile;
    float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE

    // which packed int of the block this thread loads
    const int kqsx = threadIdx.x % QI2_0;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI2_0) {
        int i = i0 + threadIdx.y*(WARP_SIZE/QI2_0) + threadIdx.x/QI2_0;

        if (need_check) {
            // clamp instead of branching out so all threads keep loading
            i = min(i, i_max);
        }

        const block_tq2_0 * bxi = (const block_tq2_0 *) x + kbx0 + i*stride;
        const int qs0 = get_int_b2(bxi->qs, kqsx);

#ifdef INT8_MMA_AVAILABLE

#pragma unroll
        for (int l = 0; l < QR2_0; ++l) {
            // k positions written per l, for kqsx 0..15:
            // 0..7, 32..39
            // 8..15, 40..47
            // 16..23, 48..55
            // 24..31, 56..63
            // FIXME: this might assume WARP_SIZE is >= 32
            const int k = (kqsx/8)*32 + l*8 + kqsx % 8;

            // unpack 4x 2-bit lanes and recenter {0,1,2} -> {-1,0,1}
            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k] = __vsub4((qs0 >> (2*l)) & 0x03030303, 0x01010101);
        }
#else
        // dp4a path stores the still-packed quants (stride padded by +1)
        x_qs[i*(2*WARP_SIZE + 1) + kqsx] = qs0;
#endif // INT8_MMA_AVAILABLE
    }

    // Second pass: load the per-block scales (one half-precision d per block).
    // TODO: does this work with WARP_SIZE != 32?
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_0/2)) {
        int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_0) + threadIdx.x/(QI2_0/2);

        if (need_check) {
            i = min(i, i_max);
        }

        const block_tq2_0 * bxi = (const block_tq2_0 *) x + kbx0 + i*stride;

        const int k = threadIdx.x % (QI2_0/2);

#ifdef INT8_MMA_AVAILABLE

        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k] = bxi->d;
#else
        x_df[i*(WARP_SIZE/4) + i/4 + k] = bxi->d;
#endif // INT8_MMA_AVAILABLE
    }
}
1882+
1883+
// dp4a tile vec-dot for TQ2_0 x q8_1: accumulates partial dot products of the
// shared-memory x tile (packed 2-bit quants, stored by load_tiles_tq2_0) with
// the q8_1 y tile into sum. k00 is the base offset into the tile along k.
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_tq2_0_q8_1_dp4a(
    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {

    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_TQ2_0, mmq_y);
    const int   * x_qs = (const int   *) x;
    const float * x_df = (const float *) x_qs + txs.qs;
    const int   * y_qs = (const int   *) y + 4;
    const float * y_df = (const float *) y;

    // each step covers QR2_0*VDR ints of y, since every packed x int
    // expands to QR2_0 unpacked ints
#pragma unroll
    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR2_0*VDR_TQ2_0_Q8_1_MMQ) {
        const int k0 = k00 + k01;

#pragma unroll
        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

#pragma unroll
            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
                const int i = i0 + threadIdx.x;

                // k0/QR2_0: x is still packed, so 4 y ints map to 1 x int
                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_tq2_0_q8_1_impl<VDR_TQ2_0_Q8_1_MMQ>(
                    &x_qs[i*(2*WARP_SIZE + 1) + k0/QR2_0], &y_qs[j*MMQ_TILE_Y_K + k01],
                    x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2)], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
                    // x_df[i*(WARP_SIZE/QI2_0) + i/QI2_0], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
            }
        }
    }
}
1913+
18111914
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
18121915
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
18131916

@@ -2427,6 +2530,14 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
24272530
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
24282531
};
24292532

2533+
// mul_mat_q dispatch for TQ2_0. The MMA path reuses the generic q8_0 vec-dot
// because load_tiles_tq2_0 already unpacks the 2-bit quants to int8 in that
// configuration; only the dp4a path needs a TQ2_0-specific vec-dot.
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_TQ2_0> {
    static constexpr int              vdr          = VDR_TQ2_0_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_tq2_0<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_tq2_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
2540+
24302541
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
24312542
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XXS> {
24322543
static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ;
@@ -2916,6 +3027,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
29163027
extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
29173028
extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
29183029
extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);
3030+
extern DECL_MMQ_CASE(GGML_TYPE_TQ2_0);
29193031
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS);
29203032
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XS);
29213033
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_S);

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
1414
type == GGML_TYPE_Q4_K ? vec_dot_q4_K_q8_1 :
1515
type == GGML_TYPE_Q5_K ? vec_dot_q5_K_q8_1 :
1616
type == GGML_TYPE_Q6_K ? vec_dot_q6_K_q8_1 :
17+
type == GGML_TYPE_TQ2_0 ? vec_dot_tq2_0_q8_1 :
1718
type == GGML_TYPE_IQ2_XXS ? vec_dot_iq2_xxs_q8_1 :
1819
type == GGML_TYPE_IQ2_XS ? vec_dot_iq2_xs_q8_1 :
1920
type == GGML_TYPE_IQ2_S ? vec_dot_iq2_s_q8_1 :
@@ -37,6 +38,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
3738
type == GGML_TYPE_Q4_K ? VDR_Q4_K_Q8_1_MMVQ :
3839
type == GGML_TYPE_Q5_K ? VDR_Q5_K_Q8_1_MMVQ :
3940
type == GGML_TYPE_Q6_K ? VDR_Q6_K_Q8_1_MMVQ :
41+
type == GGML_TYPE_TQ2_0 ? VDR_TQ2_0_Q8_1_MMVQ :
4042
type == GGML_TYPE_IQ2_XXS ? VDR_IQ2_XXS_Q8_1_MMVQ :
4143
type == GGML_TYPE_IQ2_XS ? VDR_IQ2_XS_Q8_1_MMVQ :
4244
type == GGML_TYPE_IQ2_S ? VDR_IQ2_S_Q8_1_MMVQ :
@@ -271,6 +273,13 @@ static void mul_mat_vec_q6_K_q8_1_cuda(
271273
mul_mat_vec_q_cuda<GGML_TYPE_Q6_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
272274
}
273275

276+
// Host-side launcher for the TQ2_0 x q8_1 mat-vec product; forwards to the
// generic quantized mul_mat_vec kernel instantiated with the TQ2_0 traits.
static void mul_mat_vec_tq2_0_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<GGML_TYPE_TQ2_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
282+
274283
static void mul_mat_vec_iq2_xxs_q8_1_cuda(
275284
const void * vx, const void * vy, float * dst,
276285
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
@@ -385,6 +394,9 @@ void ggml_cuda_op_mul_mat_vec_q(
385394
case GGML_TYPE_Q6_K:
386395
mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
387396
break;
397+
case GGML_TYPE_TQ2_0:
398+
mul_mat_vec_tq2_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
399+
break;
388400
case GGML_TYPE_IQ2_XXS:
389401
mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
390402
break;

ggml/src/ggml-cuda/template-instances/generate_cu_files.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
TYPES_MMQ = [
2424
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
2525
"GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
26+
"GGML_TYPE_TQ2_0",
2627
"GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",
2728
"GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS"
2829
]
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
2+
3+
#include "../mmq.cuh"
4+
5+
DECL_MMQ_CASE(GGML_TYPE_TQ2_0);

ggml/src/ggml-cuda/vecdotq.cuh

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,36 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
524524
return d6 * sumf_d;
525525
}
526526

527+
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q

#define VDR_TQ2_0_Q8_1_MMVQ 2
#define VDR_TQ2_0_Q8_1_MMQ  8

// Can use the same for both mmvq and mmq, because there are no sub-scales in a TQ2_0 block
//
// v  — vdr packed ints of 2-bit TQ2_0 quants (4 lanes per byte)
// u  — QR2_0*vdr ints of q8_1 quants, grouped per 2-bit sub-position
// d2 — the TQ2_0 block scale
// d8 — per-q8_1-block scales, one per sub-position group
template <int vdr> static __device__ __forceinline__ float vec_dot_tq2_0_q8_1_impl(
    const int * __restrict__ v, const int * __restrict__ u, const float & d2, const float * __restrict__ d8) {

    float sumf = 0.0f;

#pragma unroll
    for (int i0 = 0; i0 < QR2_0; ++i0) {
        int sumi = 0;

#pragma unroll
        for (int i = 0; i < vdr; ++i) {
            // extract the i0-th 2-bit lane of each byte, then recenter
            // {0,1,2} -> {-1,0,1} before the 4x int8 dot product
            const int vi = (v[i] >> (2*i0)) & 0x03030303;

            sumi = ggml_cuda_dp4a(__vsub4(vi, 0x01010101), u[vdr*i0 + i], sumi); // SIMD dot product
        }

        // TODO: batch subtract by using d8 sum
        sumf += d8[i0] * sumi;
    }

    return d2 * sumf;
}
556+
527557
static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
528558
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
529559

@@ -786,6 +816,37 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
786816
return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);
787817
}
788818

819+
// mul_mat_vec_q entry point for TQ2_0 x q8_1: gathers the packed quants and
// the matching q8_1 ints/scales into registers, then defers to
// vec_dot_tq2_0_q8_1_impl. kbx selects the TQ2_0 block, iqs the int offset
// within it (this thread handles VDR_TQ2_0_Q8_1_MMVQ consecutive ints).
static __device__ __forceinline__ float vec_dot_tq2_0_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

    const block_tq2_0 * btq2_0 = (const block_tq2_0 *) vbq + kbx;

    // iqs 0..7 all need bq8_offset 0, 1, 2, 3
    // iqs 8..15 all need bq8_offset 4, 5, 6, 7
    const int bq8_offset = QR2_0 * (iqs / 8);

    int v[VDR_TQ2_0_Q8_1_MMVQ];
    int u[QR2_0*VDR_TQ2_0_Q8_1_MMVQ];
    float d8[QR2_0];

#pragma unroll
    for (int i = 0; i < VDR_TQ2_0_Q8_1_MMVQ; ++i) {
        v[i] = get_int_b2(btq2_0->qs, iqs + i);
    }

    // one q8_1 block (and scale) per 2-bit sub-position of the packed quants
#pragma unroll
    for (int i0 = 0; i0 < QR2_0; ++i0) {
        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i0;

        for (int i = 0; i < VDR_TQ2_0_Q8_1_MMVQ; ++i) {
            u[VDR_TQ2_0_Q8_1_MMVQ*i0 + i] = get_int_b4(bq8i->qs, (iqs % QI8_1) + i);
        }
        // low half of ds is the q8_1 scale d
        d8[i0] = __low2float(bq8i->ds);
    }

    return vec_dot_tq2_0_q8_1_impl<VDR_TQ2_0_Q8_1_MMVQ>(v, u, btq2_0->d, d8);
}
849+
789850
#define VDR_IQ2_XXS_Q8_1_MMVQ 2
790851
#define VDR_IQ2_XXS_Q8_1_MMQ 2
791852

0 commit comments

Comments
 (0)