@@ -140,9 +140,6 @@ static constexpr __device__ int get_mmq_y_device() {
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
}

-// tile_x_sizes{qs, dm, sc}
-
-// TODO: TQ2_0 to minimize shared mem
#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_0   + mmq_y/QI4_0, 0}
#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_1   + mmq_y/QI4_1, 0}
#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0}
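
The removed comment documented the field order of `tile_x_sizes` (qs, dm, sc), which the macros above brace-initialize. For reference, a minimal sketch of the struct consistent with those initializers (the upstream definition in mmq.cuh should match, modulo comments):

```cuda
// Shared-memory tile sizes per quantization type, counted in ints/floats:
//   qs: quantized values, dm: block scales d (or packed d+m), sc: sub-block scales.
struct tile_x_sizes {
    int qs;
    int dm;
    int sc;
};
```

Dropping the "TODO: TQ2_0 to minimize shared mem" is consistent with the rest of this commit: TQ2_0 keeps the full Q8_0 tile shape instead of getting a smaller dedicated layout.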
@@ -1814,7 +1811,6 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
#endif // INT8_MMA_AVAILABLE
}

-// This is the first "simple" type with a block size of 256
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_tq2_0(
    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
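
`load_tiles_tq2_0` assumes the TQ2_0 block layout: 256 ternary weights packed four per byte, plus a single fp16 scale (about 2.06 bits per weight). A sketch of the block struct implied by the loads that follow, assuming QK_K == 256 as elsewhere in ggml (`ggml_half` written out as a stand-in for ggml's fp16 typedef):

```cuda
#include <cstdint>
#include <cuda_fp16.h>

typedef half ggml_half; // stand-in for the typedef in ggml-common.h

#define QK_K 256

typedef struct {
    uint8_t   qs[QK_K/4]; // 256 ternary values, 2 bits each, 4 per byte
    ggml_half d;          // one scale for the whole block
} block_tq2_0;
```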
@@ -1840,22 +1836,22 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_tq2_0(
        const block_tq2_0 * bxi = (const block_tq2_0 *) x + kbx0 + i*stride;
        const int qs0 = get_int_b2(bxi->qs, kqsx);

-#ifdef INT8_MMA_AVAILABLE
-
#pragma unroll
        for (int l = 0; l < QR2_0; ++l) {
            //  0.. 7, 32..39
            //  8..15, 40..47
            // 16..23, 48..55
            // 24..31, 56..63
-            // FIXME: this might assume WARP_SIZE is >= 32
            const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
+            const int q = __vsub4((qs0 >> (2*l)) & 0x03030303, 0x01010101);

-            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k] = __vsub4((qs0 >> (2*l)) & 0x03030303, 0x01010101);
-        }
+#ifdef INT8_MMA_AVAILABLE
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k] = q;
#else
-        x_qs[i*(2*WARP_SIZE + 1) + kqsx] = qs0;
+            // NOTE: this might assume WARP_SIZE is >= 32
+            x_qs[i*(2*WARP_SIZE + 1) + k] = q;
#endif // INT8_MMA_AVAILABLE
+        }
    }

// TODO: does this work with WARP_SIZE != 32?
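
To make the index comments above concrete: each 32-bit word returned by `get_int_b2` carries 16 ternary values, one 2-bit field per byte lane per iteration `l`, and `__vsub4(..., 0x01010101)` subtracts 1 from every byte to map the stored {0, 1, 2} onto {-1, 0, +1}. A standalone host-side re-implementation of the same bit math (hypothetical example for illustration, not part of mmq.cuh):

```cuda
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t qs0 = 0xA5196492u; // example word; every 2-bit field is in {0,1,2}
    for (int l = 0; l < 4; ++l) {     // QR2_0 == 4 two-bit fields per byte
        const uint32_t q = (qs0 >> (2*l)) & 0x03030303u;
        for (int b = 0; b < 4; ++b) {
            // byte-wise equivalent of __vsub4(q, 0x01010101)
            const int v = (int) ((q >> (8*b)) & 0xFF) - 1;
            printf("l=%d byte=%d -> %+d\n", l, b, v);
        }
    }
    return 0;
}
```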
@@ -1872,45 +1868,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_tq2_0(
        const int k = threadIdx.x % (QI2_0/2);

#ifdef INT8_MMA_AVAILABLE
-
        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k] = bxi->d;
#else
        x_df[i*(WARP_SIZE/4) + i/4 + k] = bxi->d;
#endif // INT8_MMA_AVAILABLE
    }
}

-template <int mmq_x, int mmq_y, int nwarps>
-static __device__ __forceinline__ void vec_dot_tq2_0_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
-
-    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_TQ2_0, mmq_y);
-    const int   * x_qs = (const int   *) x;
-    const float * x_df = (const float *) x_qs + txs.qs;
-    const int   * y_qs = (const int   *) y + 4;
-    const float * y_df = (const float *) y;
-
-#pragma unroll
-    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR2_0*VDR_TQ2_0_Q8_1_MMQ) {
-        const int k0 = k00 + k01;
-
-#pragma unroll
-        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
-            const int j = j0 + threadIdx.y;
-
-#pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
-                const int i = i0 + threadIdx.x;
-
-                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_tq2_0_q8_1_impl<VDR_TQ2_0_Q8_1_MMQ>(
-                    &x_qs[i*(2*WARP_SIZE + 1) + k0/QR2_0], &y_qs[j*MMQ_TILE_Y_K + k01],
-                    x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2)], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
-                    // x_df[i*(WARP_SIZE/QI2_0) + i/QI2_0], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
-            }
-        }
-    }
-}
-
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
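
Note on the deletion above: after the load_tiles change, `x_qs` holds sign-adjusted int8 values in the Q8_0 layout on both the MMA and the dp4a path, so `vec_dot_tq2_0_q8_1_dp4a`, which still indexed the tile as raw 2-bit words (`k0/QR2_0`, and the old `qs0` fallback), no longer matches, and the generic Q8_0 kernel can take over. The primitive that generic path reduces with is `__dp4a`; a minimal sketch (simplified, hypothetical helper; the real `vec_dot_q8_0_q8_1_dp4a` in this file adds the tiling and scale handling):

```cuda
// Accumulate dot products over n packed int8x4 words (needs sm_61+ for __dp4a).
static __device__ __forceinline__ int dot_int8x4(const int * va, const int * vb, const int n) {
    int sumi = 0;
    for (int i = 0; i < n; ++i) {
        sumi = __dp4a(va[i], vb[i], sumi); // 4-way int8 multiply-accumulate
    }
    return sumi;
}
```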
@@ -2535,7 +2499,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_TQ2_0> {
    static constexpr int              vdr          = VDR_TQ2_0_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_tq2_0<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
-    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_tq2_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
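
The traits swap is the visible payoff of the commit: TQ2_0 now routes through the same dp4a epilogue as Q8_0. Per partial sum, that epilogue only has to multiply the integer dot product by the two block scales, roughly as below (illustrative names, not the exact mmq.cuh signatures):

```cuda
// d_x: TQ2_0 block scale (bxi->d, staged into x_df); d_y: q8_1 scale from y_df.
static __device__ __forceinline__ float scale_partial(const int sumi, const float d_x, const float d_y) {
    return d_x * d_y * (float) sumi; // weights are already sign-adjusted int8 in {-1, 0, +1}
}
```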