ggml-vulkan: adds support for op CONV_TRANSPOSE_1D #13813

Merged: 5 commits, Jun 4, 2025
72 changes: 72 additions & 0 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -398,6 +398,7 @@ struct vk_device_struct {
vk_pipeline pipeline_count_equal_i32;
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
vk_pipeline pipeline_timestep_embedding_f32;
vk_pipeline pipeline_conv_transpose_1d_f32;
vk_pipeline pipeline_pool2d_f32;
vk_pipeline pipeline_rwkv_wkv6_f32;
vk_pipeline pipeline_rwkv_wkv7_f32;
@@ -706,6 +707,21 @@ struct vk_op_timestep_embedding_push_constants {
uint32_t max_period;
};

struct vk_op_conv_transpose_1d_push_constants {
uint32_t Cout;
uint32_t Cin;
uint32_t K;
uint32_t L;
uint32_t KL;

uint32_t nb01;
uint32_t nb02;
uint32_t nb11;
uint32_t nb1;

int32_t s0;
};

struct vk_op_pool2d_push_constants {
uint32_t IW; uint32_t IH;
uint32_t OW; uint32_t OH;
@@ -2727,6 +2743,8 @@ static void ggml_vk_load_shaders(vk_device& device) {

ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);

ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, "conv_transpose_1d_f32", conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, "main", 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1);

ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);

ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
@@ -6391,6 +6409,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_timestep_embedding_f32;
}
return nullptr;
case GGML_OP_CONV_TRANSPOSE_1D:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_conv_transpose_1d_f32;
}
return nullptr;
case GGML_OP_POOL_2D:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_pool2d_f32;
@@ -6725,6 +6748,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
uint32_t half_ceil = (dim + 1) / 2;
elements = { half_ceil, (uint32_t)src0->ne[0], 1 };
} break;
case GGML_OP_CONV_TRANSPOSE_1D:
{
elements = {uint32_t(src0->ne[1]), 1, 1}; // parallelize in {Cout, 1, 1}
} break;
case GGML_OP_POOL_2D:
{
const uint32_t N = dst->ne[3];
@@ -7528,6 +7555,37 @@ static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context
}, dryrun);
}

static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
// src0: (K, Cout, Cin, 1) -- kernel
// src1: (L, Cin, 1, 1) -- input
// dst: (*, Cout, 1, 1)

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

GGML_TENSOR_BINARY_OP_LOCALS

GGML_ASSERT(nb00 == sizeof(float));
GGML_ASSERT(nb10 == sizeof(float));

const int32_t s0 = dst->op_params[0];

vk_op_conv_transpose_1d_push_constants p{};
p.Cout = static_cast<uint32_t>(ne01);
p.Cin = static_cast<uint32_t>(ne02);
p.K = static_cast<uint32_t>(ne00);
p.L = static_cast<uint32_t>(ne10);
p.KL = static_cast<uint32_t>(ne0);
p.nb01 = static_cast<uint32_t>(nb01 / nb00);
p.nb02 = static_cast<uint32_t>(nb02 / nb00);
p.nb11 = static_cast<uint32_t>(nb11 / nb10);
p.nb1 = static_cast<uint32_t>(nb1 / nb0);
p.s0 = static_cast<uint32_t>(s0);

ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p), dryrun);
}

static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
uint32_t op = static_cast<uint32_t>(dst->op_params[0]);
const int32_t k1 = dst->op_params[1];
@@ -8599,6 +8657,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_COUNT_EQUAL:
case GGML_OP_IM2COL:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_CONV_TRANSPOSE_1D:
case GGML_OP_POOL_2D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_RWKV_WKV6:
@@ -8663,6 +8722,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_COUNT_EQUAL:
case GGML_OP_IM2COL:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_CONV_TRANSPOSE_1D:
case GGML_OP_POOL_2D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_LEAKY_RELU:
@@ -8834,6 +8894,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_TIMESTEP_EMBEDDING:
ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun);

break;
case GGML_OP_CONV_TRANSPOSE_1D:
ggml_vk_conv_transpose_1d(ctx, compute_ctx, src0, src1, node, dryrun);

break;
case GGML_OP_POOL_2D:
ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);
@@ -8962,6 +9026,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
case GGML_OP_COUNT_EQUAL:
case GGML_OP_IM2COL:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_CONV_TRANSPOSE_1D:
case GGML_OP_POOL_2D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_RWKV_WKV6:
@@ -9971,6 +10036,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_LEAKY_RELU:
case GGML_OP_OPT_STEP_ADAMW:
return true;
case GGML_OP_CONV_TRANSPOSE_1D:
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
default:
return false;
}
@@ -10462,6 +10529,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
const int32_t dim = tensor->op_params[0];
const int32_t max_period = tensor->op_params[1];
tensor_clone = ggml_timestep_embedding(ggml_ctx, src_clone[0], dim, max_period);
} else if (tensor->op == GGML_OP_CONV_TRANSPOSE_1D){
const int32_t s0 = tensor->op_params[0];
const int32_t p0 = tensor->op_params[1];
const int32_t d0 = tensor->op_params[2];
tensor_clone = ggml_conv_transpose_1d(ggml_ctx, src_clone[0], src_clone[1], s0, p0, d0);
} else if (tensor->op == GGML_OP_POOL_2D) {
enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]);
const int32_t k0 = tensor->op_params[1];
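For orientation, the host-side wrapper above packs the tensor shapes and element strides into the push constants and dispatches one workgroup per output channel (elements = {Cout, 1, 1} with a {1, 1, 1} workgroup denominator). Semantically, for a kernel of shape [K, Cout, Cin] and an input of shape [L, Cin], the op produces an output of length KL = (L-1)*s0 + K per output channel when p0 = 0 and d0 = 1, which is what the tests in this PR use. A minimal CPU sketch of these semantics, assuming contiguous F32 tensors; conv_transpose_1d_ref is a hypothetical helper for illustration, not ggml's CPU kernel:

// Hypothetical reference for the CONV_TRANSPOSE_1D semantics (not ggml's CPU kernel).
// kernel: [K, Cout, Cin] contiguous, input: [L, Cin] contiguous,
// output: [(L-1)*s0 + K, Cout], assuming p0 = 0 and d0 = 1.
#include <cstdint>
#include <vector>

std::vector<float> conv_transpose_1d_ref(const std::vector<float> & kernel,
                                         const std::vector<float> & input,
                                         uint32_t K, uint32_t Cout, uint32_t Cin,
                                         uint32_t L, uint32_t s0) {
    const uint32_t KL = (L - 1) * s0 + K;       // output length per channel
    std::vector<float> out(KL * Cout, 0.0f);
    for (uint32_t cout = 0; cout < Cout; ++cout) {
        for (uint32_t l = 0; l < L; ++l) {
            for (uint32_t k = 0; k < K; ++k) {
                float acc = 0.0f;
                for (uint32_t cin = 0; cin < Cin; ++cin) {
                    // kernel element (k, cout, cin), input element (l, cin)
                    acc += kernel[k + cout * K + cin * K * Cout] * input[l + cin * L];
                }
                // each input position l scatters a length-K window into the output
                out[cout * KL + l * s0 + k] += acc;
            }
        }
    }
    return out;
}

The push constants nb01, nb02, nb11 and nb1 are the same tensor strides expressed in elements rather than bytes (hence the divisions by nb00, nb10 and nb0 in the wrapper), which is what the shader indexes with.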
98 changes: 98 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp
@@ -0,0 +1,98 @@
#version 450

#include "types.comp"

layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // src0 - kernel: [K, Cout, Cin]
layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; // src1 - input: [L, Cin]
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; // dst - result [KL, Cout]

layout(local_size_x = 128 , local_size_y = 1, local_size_z = 1) in;

layout (push_constant) uniform parameter {
uint32_t Cout;
uint32_t Cin;
uint32_t K;
uint32_t L;
uint32_t KL;

uint32_t nb01;
uint32_t nb02;
uint32_t nb11;
uint32_t nb1;

int32_t s0;
} p;


uint32_t Cout_idx = gl_WorkGroupID.x;
const uint32_t bs = gl_WorkGroupSize.x;
uint32_t tid = gl_LocalInvocationID.x;
// Code is more straightforward if we assume it is bs*s0+K instead of (bs-1)*s0+K.
uint32_t tmp_len = bs*p.s0+p.K;
shared D_TYPE tmp[4096];

uint splitWork(uint workSize){
return (bs + workSize -1) / bs;
}

void main(){
for(uint32_t i = 0; i < splitWork(tmp_len); i++){
uint32_t idx = i*bs+tid;
if(idx < tmp_len){
tmp[idx] = 0.0;
}
}

uint32_t L_blocks = splitWork(p.L);
for(uint32_t L_block_id = 0; L_block_id < L_blocks; L_block_id++){
if(L_block_id > 0){
barrier();
// Shift values in tmp to the current processing window
for(int i = 0; i < splitWork(tmp_len); i++){
uint32_t idx = i*bs+tid;
if(idx >= bs*p.s0 && idx < tmp_len){
tmp[idx-bs*p.s0] = tmp[idx];
tmp[idx] = 0.0;
}else if(idx >= p.K && idx < bs*p.s0){
tmp[idx] = 0.0;
}
}
}
barrier();

// Save contributions of the block to tmp
uint32_t L_idx = L_block_id*bs + tid;
for(uint32_t K_idx = 0; K_idx < p.K; K_idx++){
D_TYPE dp = 0.0;
for(uint32_t Cin_idx = 0; Cin_idx < p.Cin; Cin_idx++){
A_TYPE elemKrn = data_a[K_idx + Cout_idx * p.nb01 + Cin_idx * p.nb02];
if(L_idx < p.L){
B_TYPE elemInp = data_b[L_idx + Cin_idx*p.nb11];
dp = fma(elemKrn, elemInp, dp);
}
}
tmp[tid*p.s0 + K_idx] += dp;
barrier();
}

// Save the computed values except the last block that can have different size
uint32_t KLb_idx = L_block_id*bs*p.s0;
if(L_block_id < L_blocks-1){
for(uint32_t s0_idx = 0; s0_idx < p.s0; s0_idx++){
uint32_t sh_idx = p.s0*tid+s0_idx;
uint32_t KL_idx = KLb_idx+sh_idx;
if(KL_idx < p.KL){
data_d[KL_idx + Cout_idx*p.nb1] = tmp[sh_idx];
}
}
}
}

for(uint32_t i = 0; i < splitWork(tmp_len); i++){
uint32_t idx = i*bs+tid;
uint32_t KL_idx = (L_blocks-1)*bs*p.s0+idx;
if(KL_idx < p.KL){
data_d[KL_idx + Cout_idx*p.nb1] = tmp[idx];
}
}
}
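The shader assigns one workgroup of 128 threads to each output channel and walks the input in blocks of bs = 128 positions. Contributions accumulate into a shared-memory window tmp of bs*s0 + K elements: once a block has been accumulated, the first bs*s0 entries of the window are final and are written to the output, and the next iteration shifts the remaining K entries to the front so the following block can add its overlapping contributions. A serial, single-channel sketch of the same bookkeeping; conv_transpose_1d_one_channel is illustrative only and mirrors the shader rather than being part of this change:

// Serial, single-channel sketch of the shader's sliding-window strategy.
// Illustrative only: the names mirror conv_transpose_1d.comp.
#include <algorithm>
#include <cstdint>
#include <vector>

void conv_transpose_1d_one_channel(const float * kernel,  // [Cin][Cout][K] elements
                                   const float * input,   // [Cin][L] elements
                                   float * dst,           // [Cout][KL] elements
                                   uint32_t cout, uint32_t K, uint32_t Cout,
                                   uint32_t Cin, uint32_t L, uint32_t KL, uint32_t s0) {
    const uint32_t bs      = 128;              // gl_WorkGroupSize.x in the shader
    const uint32_t tmp_len = bs * s0 + K;      // window length, see the shader's tmp[]
    std::vector<float> tmp(tmp_len, 0.0f);     // stands in for the shared array

    const uint32_t L_blocks = (L + bs - 1) / bs;
    for (uint32_t b = 0; b < L_blocks; ++b) {
        if (b > 0) {
            // shift the carried-over tail to the front and clear the rest
            for (uint32_t i = 0; i < K; ++i) {
                tmp[i] = tmp[i + bs * s0];
            }
            std::fill(tmp.begin() + K, tmp.end(), 0.0f);
        }
        // accumulate this block's contributions (one shader thread per input position)
        for (uint32_t tid = 0; tid < bs; ++tid) {
            const uint32_t l = b * bs + tid;
            if (l >= L) {
                continue;
            }
            for (uint32_t k = 0; k < K; ++k) {
                float dp = 0.0f;
                for (uint32_t cin = 0; cin < Cin; ++cin) {
                    dp += kernel[k + cout * K + cin * K * Cout] * input[l + cin * L];
                }
                tmp[tid * s0 + k] += dp;
            }
        }
        // the first bs*s0 window entries are final for every block except the last
        if (b + 1 < L_blocks) {
            const uint32_t base = b * bs * s0;
            for (uint32_t i = 0; i < bs * s0 && base + i < KL; ++i) {
                dst[cout * KL + base + i] = tmp[i];
            }
        }
    }
    // flush whatever remains of the window after the last block
    const uint32_t base = (L_blocks - 1) * bs * s0;
    for (uint32_t i = 0; i < tmp_len && base + i < KL; ++i) {
        dst[cout * KL + base + i] = tmp[i];
    }
}

Since the shared array is declared as a fixed tmp[4096], the window size bs*s0 + K is implicitly bounded by it.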
2 changes: 2 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -622,6 +622,8 @@ void process_shaders() {

string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));

string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});

string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));

string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
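The new string_to_spv entry compiles conv_transpose_1d.comp with A_TYPE, B_TYPE and D_TYPE all set to float, producing the conv_transpose_1d_f32_len / conv_transpose_1d_f32_data arrays consumed by the pipeline creation in ggml_vk_load_shaders.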
16 changes: 14 additions & 2 deletions tests/test-backend-ops.cpp
@@ -2706,8 +2706,8 @@ struct test_conv_transpose_1d : public test_case {
return VARS_TO_STR5(ne_input, ne_kernel, s0, p0, d0);
}

test_conv_transpose_1d(std::array<int64_t, 4> ne_input = {197, 32, 1, 1}, // [input_width, input_height, input_channels, 1]
std::array<int64_t, 4> ne_kernel = {16, 32, 32, 1}, // [kernel_width, kernel_height, input_channels, 1]
test_conv_transpose_1d(std::array<int64_t, 4> ne_input = {197, 32, 1, 1}, // [input_width, input_channels, 1 /* assert in cpu kernel*/, 1 (should be batch)]
std::array<int64_t, 4> ne_kernel = {16, 32, 32, 1}, // [kernel_width, output_channels, input_channels, 1 (should be batch)]
int s0 = 1, int p0 = 0, int d0 = 1)
: ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), p0(p0), d0(d0) {}

Expand Down Expand Up @@ -4029,6 +4029,18 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, false));
test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, true));

for(uint32_t Cout : {1, 9}){
for(uint32_t Cin : {1, 7}){
for(uint32_t K : {1, 3, 1337}){
for(uint32_t L : {1, 2, 13}){
for(uint32_t s0: {1, 2, 3}){
test_cases.emplace_back(new test_conv_transpose_1d({L,Cin,1,1}, {K,Cout,Cin,1}, s0, 0, 1));
}
}
}
}
}

test_cases.emplace_back(new test_conv_transpose_1d());
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 3, 0, 1));
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 2, 0, 1));
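The nested sweep added above runs each (Cout, Cin, K, L, s0) combination through test-backend-ops alongside the pre-existing cases, comparing the Vulkan results against the CPU backend. For completeness, a sketch of how the op is reached through the public ggml API, with shapes matching the test cases; build_conv_transpose_1d is a hypothetical helper, and p0 = 0, d0 = 1 mirror what the tests use:

// Hypothetical helper showing how the op is constructed via the public ggml API.
#include "ggml.h"

static ggml_tensor * build_conv_transpose_1d(ggml_context * ctx,
                                             int64_t K, int64_t Cout, int64_t Cin,
                                             int64_t L, int s0) {
    // kernel: [K, Cout, Cin, 1], input: [L, Cin, 1, 1], both F32
    ggml_tensor * kernel = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, K, Cout, Cin);
    ggml_tensor * input  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, L, Cin);
    // result: [(L-1)*s0 + K, Cout, 1, 1]; p0 = 0 and d0 = 1 as in the tests
    return ggml_conv_transpose_1d(ctx, kernel, input, s0, /*p0=*/0, /*d0=*/1);
}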