@@ -155,25 +155,27 @@ static constexpr __device__ int get_mmq_y_device() {
155
155
// Expands to a tile_x_sizes braced initializer with three values computed from mmq_y,
// WARP_SIZE and QI6_K. The macro refers to mmq_y by name, so it can only be used where an
// identifier called mmq_y is in scope.
#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8 }
156
156
157
157
// mmq_get_dp4a_tile_x_sizes: selects a tile_x_sizes value for the given ggml_type.
//
// Parameters:
//   type  - quantization type used as the selector.
//   mmq_y - not used directly in the body; the MMQ_DP4A_TXS_* macros pick it up by name
//           (MMQ_DP4A_TXS_Q6_K above does; the other macros are presumably analogous,
//           their definitions are not visible in this hunk).
// Returns: the macro-generated tile_x_sizes for a listed type, tile_x_sizes{0, 0, 0} for
//          any type that is not listed.
//
// Mapping notes (identical in the removed and the added version):
//   - Q5_0 shares the Q8_0 sizes and Q5_1 shares the Q8_1 sizes.
//   - IQ2_XXS, IQ3_XXS, IQ3_S, IQ1_S, IQ4_XS and IQ4_NL share the Q8_0 sizes.
//   - IQ2_XS and IQ2_S share the Q8_0_16 sizes.
//
// This span is shown as a diff: '-' lines are the removed nested conditional expression,
// '+' lines are the switch statement that replaces it.
// NOTE(review): a switch inside a constexpr function requires C++14 or later.
static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes (ggml_type type, int mmq_y) {
158
// Removed implementation: one chained ?: expression, falling through to the zero sizes.
- return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
159
- type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
160
- type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q8_0 :
161
- type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q8_1 :
162
- type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
163
- type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
164
- type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
165
- type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
166
- type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
167
- type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
168
- type == GGML_TYPE_IQ2_XXS ? MMQ_DP4A_TXS_Q8_0 :
169
- type == GGML_TYPE_IQ2_XS ? MMQ_DP4A_TXS_Q8_0_16 :
170
- type == GGML_TYPE_IQ2_S ? MMQ_DP4A_TXS_Q8_0_16 :
171
- type == GGML_TYPE_IQ3_XXS ? MMQ_DP4A_TXS_Q8_0 :
172
- type == GGML_TYPE_IQ3_S ? MMQ_DP4A_TXS_Q8_0 :
173
- type == GGML_TYPE_IQ1_S ? MMQ_DP4A_TXS_Q8_0 :
174
- type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q8_0 :
175
- type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q8_0 :
176
- tile_x_sizes{0 , 0 , 0 };
158
// Added implementation: the same type-to-macro table as a switch; default gives the zero sizes.
+ switch (type) {
159
+ case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0;
160
+ case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1;
161
+ case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
162
+ case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1;
163
+ case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0;
164
+ case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K;
165
+ case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K;
166
+ case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K;
167
+ case GGML_TYPE_Q5_K: return MMQ_DP4A_TXS_Q5_K;
168
+ case GGML_TYPE_Q6_K: return MMQ_DP4A_TXS_Q6_K;
169
+ case GGML_TYPE_IQ2_XXS: return MMQ_DP4A_TXS_Q8_0;
170
+ case GGML_TYPE_IQ2_XS: return MMQ_DP4A_TXS_Q8_0_16;
171
+ case GGML_TYPE_IQ2_S: return MMQ_DP4A_TXS_Q8_0_16;
172
+ case GGML_TYPE_IQ3_XXS: return MMQ_DP4A_TXS_Q8_0;
173
+ case GGML_TYPE_IQ3_S: return MMQ_DP4A_TXS_Q8_0;
174
+ case GGML_TYPE_IQ1_S: return MMQ_DP4A_TXS_Q8_0;
175
+ case GGML_TYPE_IQ4_XS: return MMQ_DP4A_TXS_Q8_0;
176
+ case GGML_TYPE_IQ4_NL: return MMQ_DP4A_TXS_Q8_0;
177
+ default : return tile_x_sizes{0 , 0 , 0 };
178
+ }
177
179
}
178
180
179
181
// Integer expression built from WARP_SIZE and QI8_0 plus a constant 4.
// NOTE(review): the trailing "+ 4" presumably provides the padding that the nearby
// "% 8 == 4" static_asserts check for the other tile sizes - TODO confirm.
#define MMQ_MMA_TILE_X_K_Q8_0 (2 *WARP_SIZE + 2 *WARP_SIZE/QI8_0 + 4 )
@@ -189,25 +191,27 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
189
191
// Compile-time check: MMQ_MMA_TILE_X_K_Q6_K must leave remainder 4 when divided by 8;
// otherwise compilation fails with the message given here.
static_assert (MMQ_MMA_TILE_X_K_Q6_K % 8 == 4 , " Wrong padding." );
190
192
191
193
// mmq_get_mma_tile_x_k: selects one of the MMQ_MMA_TILE_X_K_* integer constants for the
// given ggml_type.
//
// Parameters:
//   type - quantization type used as the selector.
// Returns: the selected constant, or 0 for any type that is not listed.
//
// Mapping (identical in the removed and the added version):
//   - Q4_0, Q5_0, Q8_0, IQ2_XXS, IQ3_XXS, IQ3_S, IQ1_S, IQ4_XS, IQ4_NL -> ..._Q8_0
//   - Q4_1, Q5_1, Q4_K, Q5_K                                           -> ..._Q8_1
//   - Q2_K                                                             -> ..._Q2_K
//   - Q3_K, IQ2_XS, IQ2_S                                              -> ..._Q3_K
//   - Q6_K                                                             -> ..._Q6_K
//
// This span is shown as a diff: '-' lines are the removed nested conditional expression,
// '+' lines are the switch statement that replaces it.
// NOTE(review): a switch inside a constexpr function requires C++14 or later.
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k (ggml_type type) {
192
// Removed implementation: one chained ?: expression, falling through to 0.
- return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
193
- type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
194
- type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
195
- type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
196
- type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
197
- type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
198
- type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
199
- type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q8_1 :
200
- type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q8_1 :
201
- type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
202
- type == GGML_TYPE_IQ2_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
203
- type == GGML_TYPE_IQ2_XS ? MMQ_MMA_TILE_X_K_Q3_K :
204
- type == GGML_TYPE_IQ2_S ? MMQ_MMA_TILE_X_K_Q3_K :
205
- type == GGML_TYPE_IQ3_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
206
- type == GGML_TYPE_IQ3_S ? MMQ_MMA_TILE_X_K_Q8_0 :
207
- type == GGML_TYPE_IQ1_S ? MMQ_MMA_TILE_X_K_Q8_0 :
208
- type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q8_0 :
209
- type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q8_0 :
210
- 0 ;
194
// Added implementation: the same type-to-constant table as a switch; default gives 0.
+ switch (type) {
195
+ case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0;
196
+ case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1;
197
+ case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
198
+ case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
199
+ case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
200
+ case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
201
+ case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
202
+ case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
203
+ case GGML_TYPE_Q5_K: return MMQ_MMA_TILE_X_K_Q8_1;
204
+ case GGML_TYPE_Q6_K: return MMQ_MMA_TILE_X_K_Q6_K;
205
+ case GGML_TYPE_IQ2_XXS: return MMQ_MMA_TILE_X_K_Q8_0;
206
+ case GGML_TYPE_IQ2_XS: return MMQ_MMA_TILE_X_K_Q3_K;
207
+ case GGML_TYPE_IQ2_S: return MMQ_MMA_TILE_X_K_Q3_K;
208
+ case GGML_TYPE_IQ3_XXS: return MMQ_MMA_TILE_X_K_Q8_0;
209
+ case GGML_TYPE_IQ3_S: return MMQ_MMA_TILE_X_K_Q8_0;
210
+ case GGML_TYPE_IQ1_S: return MMQ_MMA_TILE_X_K_Q8_0;
211
+ case GGML_TYPE_IQ4_XS: return MMQ_MMA_TILE_X_K_Q8_0;
212
+ case GGML_TYPE_IQ4_NL: return MMQ_MMA_TILE_X_K_Q8_0;
213
+ default : return 0 ;
214
+ }
211
215
}
212
216
213
217
// WARP_SIZE plus WARP_SIZE divided by QI8_1.
// NOTE(review): presumably the y-tile counterpart of the MMQ_MMA_TILE_X_K_* constants -
// confirm against the code that uses it (not visible in this hunk).
#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
0 commit comments