Commit 0d39844

ggml-vulkan: adds support for op CONV_TRANSPOSE_1D (#13813)
* ggml-vulkan: adds op CONV_TRANSPOSE_1D
* test-backend-ops: adds more sophisticated tests for CONV_TRANSPOSE_1D
* Missing barrier added to shader. Number of additional tests reduced to 108.
* Fixes typo in variable name.
* Removes extra whitespaces.
* Adds int64->int32 casts to prevent possible warnings.
* Problem size reduced in tests to pass tests with llvmpipe.
* supports_op condition moved from unintended position
1 parent 3e63a58 commit 0d39844

File tree

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
tests/test-backend-ops.cpp

4 files changed: +186 −2 lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 72 additions & 0 deletions
@@ -396,6 +396,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_count_equal_i32;
     vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
     vk_pipeline pipeline_timestep_embedding_f32;
+    vk_pipeline pipeline_conv_transpose_1d_f32;
     vk_pipeline pipeline_pool2d_f32;
     vk_pipeline pipeline_rwkv_wkv6_f32;
     vk_pipeline pipeline_rwkv_wkv7_f32;
@@ -706,6 +707,21 @@ struct vk_op_timestep_embedding_push_constants {
     uint32_t max_period;
 };
 
+struct vk_op_conv_transpose_1d_push_constants {
+    uint32_t Cout;
+    uint32_t Cin;
+    uint32_t K;
+    uint32_t L;
+    uint32_t KL;
+
+    uint32_t nb01;
+    uint32_t nb02;
+    uint32_t nb11;
+    uint32_t nb1;
+
+    int32_t s0;
+};
+
 struct vk_op_pool2d_push_constants {
     uint32_t IW; uint32_t IH;
     uint32_t OW; uint32_t OH;
@@ -2726,6 +2742,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, "conv_transpose_1d_f32", conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, "main", 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
@@ -6392,6 +6410,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_timestep_embedding_f32;
         }
         return nullptr;
+    case GGML_OP_CONV_TRANSPOSE_1D:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_conv_transpose_1d_f32;
+        }
+        return nullptr;
     case GGML_OP_POOL_2D:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_pool2d_f32;
@@ -6726,6 +6749,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
             uint32_t half_ceil = (dim + 1) / 2;
             elements = { half_ceil, (uint32_t)src0->ne[0], 1 };
         } break;
+    case GGML_OP_CONV_TRANSPOSE_1D:
+        {
+            elements = {uint32_t(src0->ne[1]), 1, 1}; // parallelize in {Cout, 1, 1}
+        } break;
     case GGML_OP_POOL_2D:
         {
             const uint32_t N = dst->ne[3];
@@ -7529,6 +7556,37 @@ static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context
     }, dryrun);
 }
 
+static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    // src0: (K, Cout, Cin, 1) -- kernel
+    // src1: (L, Cin, 1, 1) -- input
+    // dst: (*, Cout, 1, 1)
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    const int32_t s0 = dst->op_params[0];
+
+    vk_op_conv_transpose_1d_push_constants p{};
+    p.Cout = static_cast<uint32_t>(ne01);
+    p.Cin = static_cast<uint32_t>(ne02);
+    p.K = static_cast<uint32_t>(ne00);
+    p.L = static_cast<uint32_t>(ne10);
+    p.KL = static_cast<uint32_t>(ne0);
+    p.nb01 = static_cast<uint32_t>(nb01 / nb00);
+    p.nb02 = static_cast<uint32_t>(nb02 / nb00);
+    p.nb11 = static_cast<uint32_t>(nb11 / nb10);
+    p.nb1 = static_cast<uint32_t>(nb1 / nb0);
+    p.s0 = static_cast<uint32_t>(s0);
+
+    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p), dryrun);
+}
+
 static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     uint32_t op = static_cast<uint32_t>(dst->op_params[0]);
     const int32_t k1 = dst->op_params[1];
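Note (not part of the diff): a minimal CPU sketch of the semantics that the push constants above describe, assuming p0 = 0 and d0 = 1 (the only configuration the new shader and tests exercise) and strides already converted from bytes to float elements as in ggml_vk_conv_transpose_1d. The helper name conv_transpose_1d_ref is hypothetical.

#include <cstdint>

// Reference conv_transpose_1d (sketch): every input sample l scatters K products
// into the outputs l*s0 + k, which is what the shader accumulates in its shared window.
// kernel: [K, Cout, Cin], input: [L, Cin], dst: [KL, Cout] with KL = (L - 1)*s0 + K.
static void conv_transpose_1d_ref(const float * kernel, const float * input, float * dst,
                                  uint32_t Cout, uint32_t Cin, uint32_t K, uint32_t L,
                                  uint32_t nb01, uint32_t nb02, uint32_t nb11, uint32_t nb1,
                                  uint32_t s0) {
    const uint32_t KL = (L - 1) * s0 + K; // output width for p0 = 0, d0 = 1
    for (uint32_t co = 0; co < Cout; ++co) {
        for (uint32_t o = 0; o < KL; ++o) {
            dst[co * nb1 + o] = 0.0f;
        }
        for (uint32_t ci = 0; ci < Cin; ++ci) {
            for (uint32_t l = 0; l < L; ++l) {
                for (uint32_t k = 0; k < K; ++k) {
                    dst[co * nb1 + l * s0 + k] += kernel[ci * nb02 + co * nb01 + k] * input[ci * nb11 + l];
                }
            }
        }
    }
}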
@@ -8600,6 +8658,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_COUNT_EQUAL:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
+    case GGML_OP_CONV_TRANSPOSE_1D:
     case GGML_OP_POOL_2D:
     case GGML_OP_CONV_2D_DW:
     case GGML_OP_RWKV_WKV6:
@@ -8664,6 +8723,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_COUNT_EQUAL:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
+    case GGML_OP_CONV_TRANSPOSE_1D:
     case GGML_OP_POOL_2D:
     case GGML_OP_CONV_2D_DW:
     case GGML_OP_LEAKY_RELU:
@@ -8835,6 +8895,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_TIMESTEP_EMBEDDING:
         ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun);
 
+        break;
+    case GGML_OP_CONV_TRANSPOSE_1D:
+        ggml_vk_conv_transpose_1d(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_POOL_2D:
         ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);
@@ -8963,6 +9027,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_COUNT_EQUAL:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
+    case GGML_OP_CONV_TRANSPOSE_1D:
     case GGML_OP_POOL_2D:
     case GGML_OP_CONV_2D_DW:
     case GGML_OP_RWKV_WKV6:
@@ -10024,6 +10089,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_OPT_STEP_ADAMW:
             return true;
+        case GGML_OP_CONV_TRANSPOSE_1D:
+            return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
         default:
             return false;
     }
@@ -10515,6 +10582,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         const int32_t dim = tensor->op_params[0];
         const int32_t max_period = tensor->op_params[1];
         tensor_clone = ggml_timestep_embedding(ggml_ctx, src_clone[0], dim, max_period);
+    } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_1D){
+        const int32_t s0 = tensor->op_params[0];
+        const int32_t p0 = tensor->op_params[1];
+        const int32_t d0 = tensor->op_params[2];
+        tensor_clone = ggml_conv_transpose_1d(ggml_ctx, src_clone[0], src_clone[1], s0, p0, d0);
     } else if (tensor->op == GGML_OP_POOL_2D) {
         enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]);
         const int32_t k0 = tensor->op_params[1];
ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
+#version 450
+
+#include "types.comp"
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // src0 - kernel: [K, Cout, Cin]
+layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; // src1 - input: [L, Cin]
+layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; // dst - result [KL, Cout]
+
+layout(local_size_x = 128 , local_size_y = 1, local_size_z = 1) in;
+
+layout (push_constant) uniform parameter {
+    uint32_t Cout;
+    uint32_t Cin;
+    uint32_t K;
+    uint32_t L;
+    uint32_t KL;
+
+    uint32_t nb01;
+    uint32_t nb02;
+    uint32_t nb11;
+    uint32_t nb1;
+
+    int32_t s0;
+} p;
+
+
+uint32_t Cout_idx = gl_WorkGroupID.x;
+const uint32_t bs = gl_WorkGroupSize.x;
+uint32_t tid = gl_LocalInvocationID.x;
+// Code is more straightforward if we assume it is bs*s0+K instead of (bs-1)*s0+K.
+uint32_t tmp_len = bs*p.s0+p.K;
+shared D_TYPE tmp[4096];
+
+uint splitWork(uint workSize){
+    return (bs + workSize -1) / bs;
+}
+
+void main(){
+    for(uint32_t i = 0; i < splitWork(tmp_len); i++){
+        uint32_t idx = i*bs+tid;
+        if(idx < tmp_len){
+            tmp[idx] = 0.0;
+        }
+    }
+
+    uint32_t L_blocks = splitWork(p.L);
+    for(uint32_t L_block_id = 0; L_block_id < L_blocks; L_block_id++){
+        if(L_block_id > 0){
+            barrier();
+            // Shift values in tmp to the current processing window
+            for(int i = 0; i < splitWork(tmp_len); i++){
+                uint32_t idx = i*bs+tid;
+                if(idx >= bs*p.s0 && idx < tmp_len){
+                    tmp[idx-bs*p.s0] = tmp[idx];
+                    tmp[idx] = 0.0;
+                }else if(idx >= p.K && idx < bs*p.s0){
+                    tmp[idx] = 0.0;
+                }
+            }
+        }
+        barrier();
+
+        // Save contributions of the block to tmp
+        uint32_t L_idx = L_block_id*bs + tid;
+        for(uint32_t K_idx = 0; K_idx < p.K; K_idx++){
+            D_TYPE dp = 0.0;
+            for(uint32_t Cin_idx = 0; Cin_idx < p.Cin; Cin_idx++){
+                A_TYPE elemKrn = data_a[K_idx + Cout_idx * p.nb01 + Cin_idx * p.nb02];
+                if(L_idx < p.L){
+                    B_TYPE elemInp = data_b[L_idx + Cin_idx*p.nb11];
+                    dp = fma(elemKrn, elemInp, dp);
+                }
+            }
+            tmp[tid*p.s0 + K_idx] += dp;
+            barrier();
+        }
+
+        // Save the computed values except the last block that can have different size
+        uint32_t KLb_idx = L_block_id*bs*p.s0;
+        if(L_block_id < L_blocks-1){
+            for(uint32_t s0_idx = 0; s0_idx < p.s0; s0_idx++){
+                uint32_t sh_idx = p.s0*tid+s0_idx;
+                uint32_t KL_idx = KLb_idx+sh_idx;
+                if(KL_idx < p.KL){
+                    data_d[KL_idx + Cout_idx*p.nb1] = tmp[sh_idx];
+                }
+            }
+        }
+    }
+
+    for(uint32_t i = 0; i < splitWork(tmp_len); i++){
+        uint32_t idx = i*bs+tid;
+        uint32_t KL_idx = (L_blocks-1)*bs*p.s0+idx;
+        if(KL_idx < p.KL){
+            data_d[KL_idx + Cout_idx*p.nb1] = tmp[idx];
+        }
+    }
+}
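Note (not part of the diff): each workgroup handles one output channel (gl_WorkGroupID.x is Cout_idx) and walks the input in blocks of bs = 128 samples, accumulating into a shared window of bs*s0 + K floats that is shifted by bs*s0 between blocks. A rough sizing check in C++, with the helper name and the standalone main() being illustrative only:

#include <cassert>
#include <cstdint>

// Per-workgroup shared window the shader needs: tmp_len = bs*s0 + K floats.
static uint32_t conv_transpose_1d_window(uint32_t K, uint32_t s0) {
    const uint32_t bs = 128;  // local_size_x in conv_transpose_1d.comp
    return bs * s0 + K;       // mirrors tmp_len = bs*p.s0 + p.K
}

int main() {
    // largest case in the new tests: K = 1337, s0 = 3 -> 1721 floats, within tmp[4096]
    assert(conv_transpose_1d_window(1337, 3) == 1721);
    return 0;
}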

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 0 deletions
@@ -622,6 +622,8 @@ void process_shaders() {
 
     string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
 
+    string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
+
     string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
 
     string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));

tests/test-backend-ops.cpp

Lines changed: 14 additions & 2 deletions
@@ -2706,8 +2706,8 @@ struct test_conv_transpose_1d : public test_case {
         return VARS_TO_STR5(ne_input, ne_kernel, s0, p0, d0);
     }
 
-    test_conv_transpose_1d(std::array<int64_t, 4> ne_input = {197, 32, 1, 1}, // [input_width, input_height, input_channels, 1]
-                           std::array<int64_t, 4> ne_kernel = {16, 32, 32, 1}, // [kernel_width, kernel_height, input_channels, 1]
+    test_conv_transpose_1d(std::array<int64_t, 4> ne_input = {197, 32, 1, 1}, // [input_width, input_channels, 1 /* assert in cpu kernel*/, 1 (should be batch)]
+                           std::array<int64_t, 4> ne_kernel = {16, 32, 32, 1}, // [kernel_width, output_channels, input_channels, 1 (should be batch)]
                            int s0 = 1, int p0 = 0, int d0 = 1)
         : ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), p0(p0), d0(d0) {}
 
@@ -4029,6 +4029,18 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, false));
     test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, true));
 
+    for(uint32_t Cout : {1, 9}){
+        for(uint32_t Cin : {1, 7}){
+            for(uint32_t K : {1, 3, 1337}){
+                for(uint32_t L : {1, 2, 13}){
+                    for(uint32_t s0: {1, 2, 3}){
+                        test_cases.emplace_back(new test_conv_transpose_1d({L,Cin,1,1}, {K,Cout,Cin,1}, s0, 0, 1));
+                    }
+                }
+            }
+        }
+    }
+
     test_cases.emplace_back(new test_conv_transpose_1d());
     test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 3, 0, 1));
     test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 2, 0, 1));
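Note (not part of the diff): the nested loops add 2 × 2 × 3 × 3 × 3 = 108 extra combinations, matching the "reduced to 108" remark in the commit message. A sketch of what one such case computes through the public ggml API; the shapes follow the test parameters, but the standalone main() and memory size are illustrative, not part of the test harness:

#include "ggml.h"

int main() {
    struct ggml_init_params params = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
    struct ggml_context * ctx = ggml_init(params);

    // input {L, Cin, 1, 1} and kernel {K, Cout, Cin, 1}, as in the loops above
    struct ggml_tensor * input  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 13, 7);   // L = 13, Cin = 7
    struct ggml_tensor * kernel = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 3, 9, 7); // K = 3, Cout = 9, Cin = 7

    // s0 = 2, p0 = 0, d0 = 1 -> output width (L - 1)*s0 + K = 27, shape {27, 9, 1, 1}
    struct ggml_tensor * out = ggml_conv_transpose_1d(ctx, kernel, input, 2, 0, 1);
    (void) out; // build a graph around this node and run it on the Vulkan backend to execute

    ggml_free(ctx);
    return 0;
}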
