Skip to content

Commit a0deeee

Browse files
committed
Vulkan: Implement accumulator switch for specific mul mat mat shaders
1 parent 6f2c49c commit a0deeee

File tree

1 file changed

+85
-139
lines changed

1 file changed

+85
-139
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 85 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -1779,6 +1779,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
17791779

17801780
device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
17811781

1782+
if (device->vendor_id == VK_VENDOR_ID_INTEL) {
1783+
// Intel drivers don't support coopmat properly yet
1784+
device->coopmat_support = false;
1785+
}
1786+
17821787
std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
17831788

17841789
// Try to find a non-graphics compute queue and transfer-focused queues
@@ -1945,9 +1950,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
19451950
ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer }, false);
19461951

19471952
// Shaders
1948-
// Disable matmul tile sizes early if not supported
1953+
// Disable matmul tile sizes early if performance is low or not supported
19491954
switch (device->vendor_id) {
19501955
case VK_VENDOR_ID_AMD:
1956+
case VK_VENDOR_ID_INTEL:
19511957
device->mul_mat_l = false;
19521958
device->mul_mat_m = true;
19531959
device->mul_mat_s = true;
@@ -1963,14 +1969,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
19631969
device->mul_mat_id_m = true;
19641970
device->mul_mat_id_s = false;
19651971
break;
1966-
case VK_VENDOR_ID_INTEL:
1967-
device->mul_mat_l = false;
1968-
device->mul_mat_m = false;
1969-
device->mul_mat_s = true;
1970-
device->mul_mat_id_l = false;
1971-
device->mul_mat_id_m = false;
1972-
device->mul_mat_id_s = true;
1973-
break;
19741972
default:
19751973
device->mul_mat_l = true;
19761974
device->mul_mat_m = true;
@@ -2050,6 +2048,11 @@ static void ggml_vk_print_gpu_info(size_t idx) {
20502048
}
20512049
}
20522050

2051+
if (props2.properties.vendorID == VK_VENDOR_ID_INTEL) {
2052+
// Intel drivers don't support coopmat properly yet
2053+
coopmat_support = false;
2054+
}
2055+
20532056
const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
20542057
bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
20552058

@@ -3025,20 +3028,33 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int
30253028
return split_k;
30263029
}
30273030

3028-
static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
3031+
static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type type_a) {
30293032
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")");
3033+
3034+
// On F32 matmuls, selecting this way increases performance significantly. On quants or fp16, it reduces performance.
3035+
// Maybe because it reduces checks and uses more vector loads, but why is fp16 worse?
3036+
if (type_a == GGML_TYPE_F32) {
3037+
if ((ctx->device->mul_mat_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_s)) {
3038+
return aligned ? mmp->a_l : mmp->l;
3039+
}
3040+
if ((ctx->device->mul_mat_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s) {
3041+
return aligned ? mmp->a_m : mmp->m;
3042+
}
3043+
return aligned ? mmp->a_s : mmp->s;
3044+
}
3045+
30303046
if ((ctx->device->mul_mat_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_l)) {
30313047
return aligned ? mmp->a_s : mmp->s;
30323048
}
3033-
if ((ctx->device->mul_mat_m && (m <= 64 || n <= 64 || ctx->device->coopmat_support)) || !ctx->device->mul_mat_l) {
3049+
if ((ctx->device->mul_mat_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l) {
30343050
return aligned ? mmp->a_m : mmp->m;
30353051
}
30363052
return aligned ? mmp->a_l : mmp->l;
30373053
}
30383054

3039-
static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) {
3055+
static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type type_a) {
30403056
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")");
3041-
return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true)->align;
3057+
return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, type_a)->align;
30423058
}
30433059

30443060
static void ggml_vk_matmul(
@@ -3227,10 +3243,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
32273243
const int y_ne = ne11 * ne10;
32283244
const int d_ne = ne11 * ne01;
32293245

3230-
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11));
3246+
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, src0->type));
32313247
const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
32323248

3233-
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned);
3249+
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, src0->type);
32343250

32353251
const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
32363252

@@ -5521,13 +5537,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
55215537
vk_pipeline p;
55225538
std::string shname;
55235539
if (shader_size == 0) {
5524-
p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->a_s;
5540+
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_s;
55255541
shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
55265542
} else if (shader_size == 1) {
5527-
p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->a_m;
5543+
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_m;
55285544
shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
55295545
} else if (shader_size == 2) {
5530-
p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->a_l;
5546+
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_l;
55315547
shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
55325548
} else {
55335549
GGML_ASSERT(0);
@@ -5537,13 +5553,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
55375553

55385554
if (k != kpad) {
55395555
if (shader_size == 0) {
5540-
p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->s;
5556+
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->s;
55415557
shname = std::string(ggml_type_name(quant)) + "_S";
55425558
} else if (shader_size == 1) {
5543-
p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->m;
5559+
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->m;
55445560
shname = std::string(ggml_type_name(quant)) + "_M";
55455561
} else if (shader_size == 2) {
5546-
p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->l;
5562+
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->l;
55475563
shname = std::string(ggml_type_name(quant)) + "_L";
55485564
} else {
55495565
GGML_ASSERT(0);
@@ -5593,16 +5609,16 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
55935609
ggml_vk_buffer_write(y_buf, 0, y, y_sz);
55945610

55955611
vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
5612+
ggml_vk_ctx_begin(ctx->device, subctx);
55965613
for (size_t i = 0; i < num_it; i++) {
5597-
ggml_vk_ctx_begin(ctx->device, subctx);
55985614
ggml_vk_matmul(
55995615
ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k),
56005616
m, n, k,
56015617
k, k, m, k*m, k*n, m*n,
56025618
split_k, batch, batch, batch, 1, 1
56035619
);
5604-
ggml_vk_ctx_end(subctx);
56055620
}
5621+
ggml_vk_ctx_end(subctx);
56065622

56075623
auto begin = std::chrono::high_resolution_clock::now();
56085624

@@ -5702,109 +5718,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
57025718

57035719
static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
57045720
#if defined(GGML_VULKAN_RUN_TESTS)
5705-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_F32);
5706-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_0);
5707-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_1);
5708-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_0);
5709-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_1);
5710-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q8_0);
5711-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q2_K);
5712-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q3_K);
5713-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_K);
5714-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_K);
5715-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q6_K);
5716-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_IQ4_NL);
5717-
5718-
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, 4, 4, 4, 1, 1, 1, 0);
5719-
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, 16, 16, 16, 1, 1, 1, 0);
5720-
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, 32, 32, 16, 1, 1, 1, 0);
5721-
5722-
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, 512, 512, 100, 32, 100, 1, 2);
5723-
5724-
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, 128, 512, 512, 2, 100, 1, 0);
5725-
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, 128, 512, 512, 2, 100, 1, 1);
5726-
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, 128, 512, 512, 2, 100, 1, 2);
5727-
// ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 0);
5728-
// ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 1);
5729-
// ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 2);
5730-
5731-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_0);
5732-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_0);
5733-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_0);
5734-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_0);
5735-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_0);
5736-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_0);
5737-
5738-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_1);
5739-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_1);
5740-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_1);
5741-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_1);
5742-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_1);
5743-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_1);
5744-
5745-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_0);
5746-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_0);
5747-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_0);
5748-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_0);
5749-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_0);
5750-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_0);
5751-
5752-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_1);
5753-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_1);
5754-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_1);
5755-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_1);
5756-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_1);
5757-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_1);
5758-
5759-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q8_0);
5760-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q8_0);
5761-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q8_0);
5762-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q8_0);
5763-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q8_0);
5764-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q8_0);
5765-
5766-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q2_K);
5767-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q2_K);
5768-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q2_K);
5769-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q2_K);
5770-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q2_K);
5771-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q2_K);
5772-
5773-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q3_K);
5774-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q3_K);
5775-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q3_K);
5776-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q3_K);
5777-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q3_K);
5778-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q3_K);
5779-
5780-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_K);
5781-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_K);
5782-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_K);
5783-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_K);
5784-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_K);
5785-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_K);
5786-
5787-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_K);
5788-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_K);
5789-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_K);
5790-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_K);
5791-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_K);
5792-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_K);
5793-
5794-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q6_K);
5795-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q6_K);
5796-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q6_K);
5797-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q6_K);
5798-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q6_K);
5799-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q6_K);
5800-
5801-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_IQ4_NL);
5802-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_IQ4_NL);
5803-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_IQ4_NL);
5804-
5805-
std::cerr << std::endl;
5806-
58075721
const std::vector<size_t> vals {
5722+
512, 512, 128,
5723+
128, 512, 512,
5724+
4096, 512, 4096,
5725+
11008, 512, 4096,
5726+
4096, 512, 11008,
5727+
32000, 512, 4096,
58085728
8, 8, 8,
58095729
100, 46, 576,
58105730
623, 111, 128,
@@ -5817,25 +5737,51 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
58175737
49, 49, 128,
58185738
128, 49, 49,
58195739
4096, 49, 4096,
5820-
11008, 49, 4096,
5821-
4096, 49, 11008,
5822-
32000, 49, 4096,
5823-
512, 512, 128,
5824-
128, 512, 512,
5825-
4096, 512, 4096,
5826-
11008, 512, 4096,
5827-
4096, 512, 11008,
5828-
32000, 512, 4096,
58295740
};
5830-
const size_t num_it = 1;
5741+
const size_t num_it = 100;
58315742
for (size_t i = 0; i < vals.size(); i += 3) {
58325743
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);
58335744
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1);
58345745
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2);
5835-
// ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0);
5836-
// ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1);
5837-
// ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2);
5838-
std::cerr << std::endl;
5746+
std::cerr << '\n';
5747+
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0);
5748+
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1);
5749+
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2);
5750+
std::cerr << '\n';
5751+
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0);
5752+
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1);
5753+
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2);
5754+
std::cerr << '\n' << std::endl;
5755+
5756+
if (vals[i + 2] % 32 == 0) {
5757+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_0);
5758+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_0);
5759+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_0);
5760+
std::cerr << '\n';
5761+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_0);
5762+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_0);
5763+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_0);
5764+
std::cerr << '\n';
5765+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_0);
5766+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_0);
5767+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_0);
5768+
std::cerr << '\n' << std::endl;
5769+
}
5770+
5771+
if (vals[i + 2] % 256 == 0) {
5772+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_K);
5773+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_K);
5774+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_K);
5775+
std::cerr << '\n';
5776+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_K);
5777+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_K);
5778+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_K);
5779+
std::cerr << '\n';
5780+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_K);
5781+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_K);
5782+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_K);
5783+
std::cerr << '\n' << std::endl;
5784+
}
58395785
}
58405786

58415787
GGML_ABORT("fatal error");

0 commit comments

Comments
 (0)