Skip to content

sync : ggml #13268

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,8 @@ struct vk_device_struct {
vk_pipeline pipeline_rwkv_wkv6_f32;
vk_pipeline pipeline_rwkv_wkv7_f32;
vk_pipeline pipeline_opt_step_adamw_f32;
vk_pipeline pipeline_conv2d_dw_whcn_f32;
vk_pipeline pipeline_conv2d_dw_cwhn_f32;

// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
Expand Down Expand Up @@ -701,6 +703,24 @@ struct vk_op_rwkv_wkv7_push_constants {
uint32_t H;
};

struct vk_op_conv2d_dw_push_constants {
uint32_t ne;
uint32_t batches;
uint32_t channels;
uint32_t dst_w;
uint32_t dst_h;
uint32_t src_w;
uint32_t src_h;
uint32_t knl_w;
uint32_t knl_h;
int32_t stride_x;
int32_t stride_y;
int32_t pad_x;
int32_t pad_y;
int32_t dilation_x;
int32_t dilation_y;
};

struct vk_op_upscale_push_constants {
uint32_t ne; uint32_t a_offset; uint32_t d_offset;
uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
Expand Down Expand Up @@ -2610,6 +2630,9 @@ static void ggml_vk_load_shaders(vk_device& device) {

ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);

for (auto &c : compiles) {
c.wait();
}
Expand Down Expand Up @@ -6137,6 +6160,15 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_leaky_relu_f32;
}
return nullptr;
case GGML_OP_CONV_2D_DW:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
if (ggml_is_contiguous(src1)) {
return ctx->device->pipeline_conv2d_dw_whcn_f32;
} else if (ggml_is_contiguous_channels(src1)) {
return ctx->device->pipeline_conv2d_dw_cwhn_f32;
}
}
return nullptr;
default:
return nullptr;
}
Expand All @@ -6163,6 +6195,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
case GGML_OP_REPEAT_BACK:
case GGML_OP_ROPE:
case GGML_OP_RMS_NORM:
case GGML_OP_CONV_2D_DW:
return true;
default:
return false;
Expand Down Expand Up @@ -6459,6 +6492,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
case GGML_OP_CONCAT:
case GGML_OP_UPSCALE:
case GGML_OP_UNARY:
case GGML_OP_CONV_2D_DW:
{
const uint32_t ne = ggml_nelements(dst);
if (ne > 262144) {
Expand Down Expand Up @@ -7245,6 +7279,30 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c
}, dryrun);
}

static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
vk_op_conv2d_dw_push_constants p{};
p.ne = ggml_nelements(dst);
p.channels = dst->ne[2];
p.batches = dst->ne[3];
p.dst_w = dst->ne[0];
p.dst_h = dst->ne[1];
p.src_w = src1->ne[0];
p.src_h = src1->ne[1];
p.knl_w = src0->ne[0];
p.knl_h = src0->ne[1];
p.stride_x = dst->op_params[0];
p.stride_y = dst->op_params[1];
p.pad_x = dst->op_params[2];
p.pad_y = dst->op_params[3];
p.dilation_x = dst->op_params[4];
p.dilation_y = dst->op_params[5];

GGML_ASSERT(src0->ne[3] == p.channels);
GGML_ASSERT(src1->ne[3] == p.batches);

ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D_DW, std::move(p), dryrun);
}

static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
const float * op_params = (const float *)dst->op_params;
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun);
Expand Down Expand Up @@ -8265,6 +8323,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_IM2COL:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_LEAKY_RELU:
Expand Down Expand Up @@ -8328,6 +8387,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_IM2COL:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_LEAKY_RELU:
{
// These operations all go through ggml_vk_op_f32, so short-circuit and
Expand Down Expand Up @@ -8501,6 +8561,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_POOL_2D:
ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);

break;
case GGML_OP_CONV_2D_DW:
ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun);

break;
case GGML_OP_LEAKY_RELU:
ggml_vk_leaky_relu(ctx, compute_ctx, src0, node, dryrun);
Expand Down Expand Up @@ -8622,6 +8686,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
case GGML_OP_IM2COL:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_LEAKY_RELU:
Expand Down Expand Up @@ -9599,6 +9664,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_COUNT_EQUAL:
case GGML_OP_IM2COL:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_CONV_2D_DW:
case GGML_OP_POOL_2D:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
Expand Down
105 changes: 105 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
#version 450

#include "types.comp"

layout (push_constant) uniform parameter
{
uint ne;
uint batches;
uint channels;
uint dst_w;
uint dst_h;
uint src_w;
uint src_h;
uint knl_w;
uint knl_h;
int stride_x;
int stride_y;
int pad_x;
int pad_y;
int dilation_x;
int dilation_y;
} p;

layout (binding = 0) readonly buffer A {A_TYPE knl_data[];};
layout (binding = 1) readonly buffer B {B_TYPE src_data[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst_data[];};

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

FLOAT_TYPE conv_2d_dw_whcn(uint idx) {
uint i0 = idx / p.dst_w;
uint dst_x = idx - i0 * p.dst_w;
uint i1 = i0 / p.dst_h;
uint dst_y = i0 - i1 * p.dst_h;
uint n = i1 / p.channels;
uint c = i1 - n * p.channels;

uint src_i = n * p.channels * p.src_h * p.src_w + c * p.src_h * p.src_w;
uint knl_i = c * p.knl_h * p.knl_w;

FLOAT_TYPE sum = 0.0;
for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {
uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int
continue;
}
for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {
uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int
continue;
}
FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * p.src_w + src_x]);
FLOAT_TYPE k = FLOAT_TYPE(knl_data[knl_i + knl_y * p.knl_w + knl_x]);
sum = fma(v, k, sum);
}
}
return sum;
}

FLOAT_TYPE conv_2d_dw_cwhn(uint idx) {
uint i0 = idx / p.channels;
uint c = idx - i0 * p.channels;
uint i1 = i0 / p.dst_w;
uint dst_x = i0 - i1 * p.dst_w;
uint n = i1 / p.dst_h;
uint dst_y = i1 - n * p.dst_h;

uint src_i = n * p.channels * p.src_h * p.src_w;
uint src_row = p.src_w * p.channels;
uint knl_row = p.knl_w * p.channels;

FLOAT_TYPE sum = 0.0;
for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {
uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int
continue;
}
for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {
uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int
continue;
}
FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * src_row + src_x * p.channels + c]);
FLOAT_TYPE k = FLOAT_TYPE(knl_data[ knl_y * knl_row + knl_x * p.channels + c]);
sum = fma(v, k, sum);
}
}
return sum;
}

void main() {
uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (idx >= p.ne) {
return;
}

FLOAT_TYPE result =
#ifdef WHCN
conv_2d_dw_whcn(idx);
#else
conv_2d_dw_cwhn(idx);
#endif
dst_data[idx] = D_TYPE(result);
}

3 changes: 3 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,9 @@ void process_shaders() {

string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));

string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));

for (auto &c : compiles) {
c.wait();
}
Expand Down
2 changes: 1 addition & 1 deletion scripts/sync-ggml.last
Original file line number Diff line number Diff line change
@@ -1 +1 @@
f3a375f20bf56860b30e7c511d03593a1e393345
0482de9c63b9134eb462c7732888c0ee0dbc2755
50 changes: 50 additions & 0 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2765,6 +2765,48 @@ struct test_im2col : public test_case {
}
};

// GGML_OP_CONV_2D_DW
struct test_conv_2d_dw : public test_case {
const std::array<int64_t, 4> ne_input;
const std::array<int64_t, 4> ne_kernel;
const int stride;
const int padding;
const int dilation;
const bool cwhn;

std::string vars() override {
return VARS_TO_STR6(ne_input, ne_kernel, stride, padding, dilation, cwhn);
}

test_conv_2d_dw(std::array<int64_t, 4> ne_input = {64, 64, 16, 1},
std::array<int64_t, 4> ne_kernel = {3, 3, 1, 16},
int stride = 1, int padding = 0, int dilation = 1, bool cwhn = false)
: ne_input(ne_input), ne_kernel(ne_kernel), stride(stride), padding(padding), dilation(dilation), cwhn(cwhn) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
ggml_set_name(input, "input");

ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_kernel.data());
ggml_set_name(kernel, "kernel");

if (cwhn) {
// change memory layout to channel-most-contiguous (CWHN),
// then permute it back so NE matches the original input
input = ggml_cont(ctx, ggml_permute(ctx, input, 1, 2, 0, 3));
input = ggml_permute(ctx, input, 2, 0, 1, 3);
kernel = ggml_cont(ctx, ggml_permute(ctx, kernel, 2, 3, 1, 0));
kernel = ggml_permute(ctx, kernel, 3, 2, 0, 1);
}

ggml_tensor * out = ggml_conv_2d_dw_direct(
ctx, kernel, input,
stride, stride, padding, padding, dilation, dilation);
ggml_set_name(out, "out");
return out;
}
};

// GGML_OP_CONCAT
struct test_concat : public test_case {
const ggml_type type;
Expand Down Expand Up @@ -3975,6 +4017,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
// test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {1024, 1024, 256, 1}, {3, 3, 256, 1}, 1, 1, 1, 1, 1, 1, true));
// test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {1024, 1024, 256, 1}, {3, 3, 256, 1}, 1, 1, 1, 1, 1, 1, true));

test_cases.emplace_back(new test_conv_2d_dw({17, 34, 9, 1}, {3, 3, 1, 9}, 1, 0, 1, false));
test_cases.emplace_back(new test_conv_2d_dw({17, 34, 9, 1}, {3, 3, 1, 9}, 1, 0, 1, true));
test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, false));
test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, true));

test_cases.emplace_back(new test_conv_transpose_1d());
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 3, 0, 1));
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 2, 0, 1));
Expand Down Expand Up @@ -4549,6 +4596,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
}
}

test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, false));
test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true));

return test_cases;
}

Expand Down
Loading