Commit fb43d5e

ggml-cuda : cleanup TQ2_0
This also removes the custom TQ2_0 mmq dp4a, because re-using the one from Q8_0 avoids repeatedly unpacking the 2-bit values to 8-bit: the unpacking now happens only once per tile.
1 parent 970b5ab commit fb43d5e
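The idea, as a standalone sketch (illustrative code only, not the actual kernels; the helper names below are made up for the example, and sm_61+ is assumed for __dp4a): TQ2_0 packs ternary weights as 2-bit values, and load_tiles_tq2_0 now expands them to signed 8-bit ints once, in the same tile layout Q8_0 uses, so the unchanged Q8_0 dp4a dot product can run on the tile. The removed vec_dot_tq2_0_q8_1_dp4a instead re-unpacked the 2-bit values on every dot-product call.

    // Illustrative sketch only (hypothetical helper names, sm_61+ assumed).
    // One 32-bit word of TQ2_0 data packs 16 values from {0, 1, 2}.

    // Extract bit pair 2*l..2*l+1 from each byte, then map {0,1,2} to the
    // ternary weights {-1,0,1} with a per-byte subtraction.
    __device__ int unpack_tq2_0(int qs, int l) {
        return __vsub4((qs >> (2*l)) & 0x03030303, 0x01010101);
    }

    // Before: every dot product unpacked the 2-bit values again.
    __device__ int dot_unpacking_each_time(int qs, const int * q8, int sumi) {
        for (int l = 0; l < 4; ++l) {
            sumi = __dp4a(unpack_tq2_0(qs, l), q8[l], sumi);
        }
        return sumi;
    }

    // After: unpack once while loading the shared-memory tile ...
    __device__ void load_tile_word(int qs, int * tile) {
        for (int l = 0; l < 4; ++l) {
            tile[l] = unpack_tq2_0(qs, l);
        }
    }

    // ... so every later dot product is the plain 8-bit path used for Q8_0.
    __device__ int dot_q8(const int * tile, const int * q8, int sumi) {
        for (int l = 0; l < 4; ++l) {
            sumi = __dp4a(tile[l], q8[l], sumi);
        }
        return sumi;
    }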

File tree

2 files changed: +7 -47 lines
ggml/src/ggml-cuda/mmq.cuh

Lines changed: 7 additions & 43 deletions
@@ -140,9 +140,6 @@ static constexpr __device__ int get_mmq_y_device() {
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 }
 
-// tile_x_sizes{qs, dm, sc}
-
-// TODO: TQ2_0 to minimize shared mem
 #define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
 #define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
 #define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0}
@@ -1814,7 +1811,6 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 #endif // INT8_MMA_AVAILABLE
 }
 
-// This is the first "simple" type with a block size of 256
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_tq2_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
@@ -1840,22 +1836,22 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_tq2_0(
         const block_tq2_0 * bxi = (const block_tq2_0 *) x + kbx0 + i*stride;
         const int qs0 = get_int_b2(bxi->qs, kqsx);
 
-#ifdef INT8_MMA_AVAILABLE
-
 #pragma unroll
         for (int l = 0; l < QR2_0; ++l) {
             // 0..7, 32..39
             // 8..15, 40..47
             // 16..23, 48..55
             // 24..31, 56..63
-            // FIXME: this might assume WARP_SIZE is >= 32
             const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
+            const int q = __vsub4((qs0 >> (2*l)) & 0x03030303, 0x01010101);
 
-            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k] = __vsub4((qs0 >> (2*l)) & 0x03030303, 0x01010101);
-        }
+#ifdef INT8_MMA_AVAILABLE
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k] = q;
 #else
-        x_qs[i*(2*WARP_SIZE + 1) + kqsx] = qs0;
+            // NOTE: this might assume WARP_SIZE is >= 32
+            x_qs[i*(2*WARP_SIZE + 1) + k] = q;
 #endif // INT8_MMA_AVAILABLE
+        }
     }
 
     // TODO: does this work with WARP_SIZE != 32?
@@ -1872,45 +1868,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_tq2_0(
         const int k = threadIdx.x % (QI2_0/2);
 
 #ifdef INT8_MMA_AVAILABLE
-
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k] = bxi->d;
 #else
         x_df[i*(WARP_SIZE/4) + i/4 + k] = bxi->d;
 #endif // INT8_MMA_AVAILABLE
     }
 }
 
-template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_tq2_0_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
-
-    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_TQ2_0, mmq_y);
-    const int   * x_qs = (const int   *) x;
-    const float * x_df = (const float *) x_qs + txs.qs;
-    const int   * y_qs = (const int   *) y + 4;
-    const float * y_df = (const float *) y;
-
-#pragma unroll
-    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR2_0*VDR_TQ2_0_Q8_1_MMQ) {
-        const int k0 = k00 + k01;
-
-#pragma unroll
-        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
-            const int j = j0 + threadIdx.y;
-
-#pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
-                const int i = i0 + threadIdx.x;
-
-                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_tq2_0_q8_1_impl<VDR_TQ2_0_Q8_1_MMQ>(
-                    &x_qs[i*(2*WARP_SIZE + 1) + k0/QR2_0], &y_qs[j*MMQ_TILE_Y_K + k01],
-                    x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2)], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
-                    // x_df[i*(WARP_SIZE/QI2_0) + i/QI2_0], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
-            }
-        }
-    }
-}
-
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
@@ -2535,7 +2499,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_TQ2_0> {
     static constexpr int vdr = VDR_TQ2_0_Q8_1_MMQ;
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_tq2_0<mmq_y, nwarps, need_check>;
     static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
-    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_tq2_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 };
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
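A note on the unpack arithmetic above (hand-worked here for illustration; not part of the diff): (qs0 >> (2*l)) & 0x03030303 selects bit pair 2*l..2*l+1 from each of the four bytes of qs0, and __vsub4(..., 0x01010101) subtracts 1 from each byte, mapping the stored values {0, 1, 2} to the ternary weights {-1, 0, 1}. For example, a byte 0x92 = 0b10'01'00'10 yields 2, 0, 1, 2 across l = 0..3, stored as 1, -1, 0, 1. The index k = (kqsx/8)*32 + l*8 + kqsx % 8 then scatters the unpacked words exactly as the loop comments indicate (for l = 0, kqsx = 0..15 lands on k = 0..7 and 32..39, and so on), producing the Q8_0 tile layout that lets vec_dot_q8_0_q8_1_dp4a be reused unchanged.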

ggml/src/ggml-cuda/vecdotq.cuh

Lines changed: 0 additions & 4 deletions
@@ -524,9 +524,6 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
     return d6 * sumf_d;
 }
 
-// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
-// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
-
 #define VDR_TQ2_0_Q8_1_MMVQ 2
 #define VDR_TQ2_0_Q8_1_MMQ  8
 
@@ -547,7 +544,6 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_tq2_0_q8_1_impl(
             sumi = ggml_cuda_dp4a(__vsub4(vi, 0x01010101), u[vdr*i0 + i], sumi); // SIMD dot product
         }
 
-        // TODO: batch subtract by using d8 sum
         sumf += d8[i0] * sumi;
     }
 
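For reference, ggml_cuda_dp4a wraps the hardware dp4a operation: both 32-bit operands are treated as four packed signed 8-bit integers, multiplied lane-wise, and the four products are added to the accumulator. In vec_dot_tq2_0_q8_1_impl above, __vsub4(vi, 0x01010101) performs the per-byte shift from {0, 1, 2} to {-1, 0, 1} immediately before that dot product. A minimal self-contained demo (an assumption-laden sketch using the standard CUDA __dp4a intrinsic, sm_61 or newer; this is not ggml code):

    #include <cstdio>
    #include <cstring>

    // dp4a: four packed int8 multiplies summed into an int accumulator.
    __global__ void dp4a_demo(int a, int b, int * out) {
        *out = __dp4a(a, b, 0);
    }

    int main() {
        const signed char av[4] = {1, -1, 0,  2}; // e.g. unpacked ternary weights
        const signed char bv[4] = {3,  5, 7, -2}; // e.g. 8-bit activations
        int a, b, h_out, * d_out;
        memcpy(&a, av, sizeof(a));
        memcpy(&b, bv, sizeof(b));
        cudaMalloc(&d_out, sizeof(int));
        dp4a_demo<<<1, 1>>>(a, b, d_out);
        cudaMemcpy(&h_out, d_out, sizeof(int), cudaMemcpyDeviceToHost);
        printf("%d\n", h_out); // 1*3 + (-1)*5 + 0*7 + 2*(-2) = -6
        cudaFree(d_out);
        return 0;
    }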