Skip to content

Commit 0e0590a

Browse files
authored
cuda : update supports_op for matrix multiplication (#8245)
1 parent a9f3b10 commit 0e0590a

File tree

2 files changed

+31
-17
lines changed

2 files changed

+31
-17
lines changed

ggml/src/ggml-cuda.cu

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2711,27 +2711,40 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
27112711
case GGML_OP_MUL_MAT:
27122712
case GGML_OP_MUL_MAT_ID:
27132713
{
2714-
struct ggml_tensor * a;
2715-
struct ggml_tensor * b;
2714+
struct ggml_tensor * a = op->src[0];
27162715
if (op->op == GGML_OP_MUL_MAT) {
2717-
a = op->src[0];
2718-
b = op->src[1];
2719-
} else {
2720-
a = op->src[2];
2721-
b = op->src[1];
2722-
}
2723-
if (a->ne[3] != b->ne[3]) {
2724-
return false;
2725-
}
2726-
ggml_type a_type = a->type;
2727-
if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS ||
2728-
a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ3_S ||
2729-
a_type == GGML_TYPE_IQ1_M || a_type == GGML_TYPE_IQ2_S || a_type == GGML_TYPE_IQ4_XS) {
2730-
if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
2716+
struct ggml_tensor * b = op->src[1];
2717+
if (a->ne[3] != b->ne[3]) {
27312718
return false;
27322719
}
27332720
}
2734-
return true;
2721+
switch (a->type) {
2722+
case GGML_TYPE_F32:
2723+
case GGML_TYPE_F16:
2724+
case GGML_TYPE_Q4_0:
2725+
case GGML_TYPE_Q4_1:
2726+
case GGML_TYPE_Q5_0:
2727+
case GGML_TYPE_Q5_1:
2728+
case GGML_TYPE_Q8_0:
2729+
case GGML_TYPE_Q2_K:
2730+
case GGML_TYPE_Q3_K:
2731+
case GGML_TYPE_Q4_K:
2732+
case GGML_TYPE_Q5_K:
2733+
case GGML_TYPE_Q6_K:
2734+
case GGML_TYPE_Q8_K:
2735+
case GGML_TYPE_IQ1_M:
2736+
case GGML_TYPE_IQ1_S:
2737+
case GGML_TYPE_IQ2_S:
2738+
case GGML_TYPE_IQ2_XS:
2739+
case GGML_TYPE_IQ2_XXS:
2740+
case GGML_TYPE_IQ3_S:
2741+
case GGML_TYPE_IQ3_XXS:
2742+
case GGML_TYPE_IQ4_NL:
2743+
case GGML_TYPE_IQ4_XS:
2744+
return true;
2745+
default:
2746+
return false;
2747+
}
27352748
} break;
27362749
case GGML_OP_GET_ROWS:
27372750
{

tests/test-backend-ops.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2052,6 +2052,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
20522052
GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
20532053
GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
20542054
GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
2055+
GGML_TYPE_BF16,
20552056
};
20562057

20572058
// unary ops

0 commit comments

Comments
 (0)