Skip to content

CUDA: add conv_2d_transpose #14287

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions ggml/src/ggml-cuda/conv2d-transpose.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#include <algorithm>

#include "conv2d-transpose.cuh"
#include "ggml.h"

__global__ void conv2d_transpose_kernel(const float * __restrict__ input, const half * __restrict__ kernel,
float * __restrict__ output, const int in_w, const int in_h, const int out_w,
const int out_h, const int kernel_w, const int kernel_h, const int stride,
const int c_in, const int c_out, const int batches) {
const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;

const int total_elements = out_w * out_h * c_out * batches;

if (global_idx >= total_elements) {
return;
}

const int out_x_idx = global_idx % out_w;
const int out_y_idx = (global_idx / out_w) % out_h;
const int c_idx = (global_idx / (out_w * out_h)) % c_out;
const int n_idx = global_idx / (out_w * out_h * c_out);

float accumulator = 0;
// For each output idx, find the inputs that contribute to it by checking stride alignment and bounds

for (int c_in_idx = 0; c_in_idx < c_in; c_in_idx++) {
for (int kh = 0; kh < kernel_h; ++kh) {
int in_y = out_y_idx - kh;
if (in_y < 0 || in_y % stride) continue;
in_y /= stride;
if (in_y >= in_h) continue;

for (int kw = 0; kw < kernel_w; ++kw) {
int in_x = out_x_idx - kw;
if (in_x < 0 || in_x % stride) continue;
in_x /= stride;
if (in_x >= in_w) continue;

const int input_idx = (in_w * in_h * c_in) * n_idx + (in_w * in_h) * c_in_idx + (in_w) *in_y + in_x;
const int kernel_idx =
(kernel_h * kernel_w * c_out) * c_in_idx + (kernel_h * kernel_w) * c_idx + (kernel_w) *kh + kw;

float input_val = input[input_idx];
half kern_val = kernel[kernel_idx];

accumulator += input_val * (float) kern_val;
}
}
}

output[(out_w * out_h * c_out) * n_idx + (out_w * out_h) * c_idx + (out_w) *out_y_idx + out_x_idx] = accumulator;
}

//input is (W, H, C_in, N), Kernel is (W, H, C_out, C_in)
void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * kernel = dst->src[0];
const ggml_tensor * input = dst->src[1];

GGML_ASSERT(kernel->type == GGML_TYPE_F16 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);

const float * input_data = (const float *) input->data;
float * output_data = (float *) dst->data;
const half * kernel_data = (const half *) kernel->data;

const int input_w = input->ne[0];
const int input_h = input->ne[1];
const int output_w = dst->ne[0];
const int output_h = dst->ne[1];
const int channels_in = input->ne[2];
const int channels_out = kernel->ne[2];
const int kernel_w = kernel->ne[0];
const int kernel_h = kernel->ne[1];
const int stride = dst->op_params[0];
const int batches = input->ne[3];

GGML_ASSERT(channels_in == kernel->ne[3]);
GGML_ASSERT(stride > 0);

cudaStream_t st = ctx.stream();

GGML_ASSERT(ggml_is_contiguous(input));
GGML_ASSERT(ggml_is_contiguous(kernel));
GGML_ASSERT(ggml_is_contiguous(dst));

const int total = (output_w * output_h * channels_out * batches);
const int blocks = (total + CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE - 1) / CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE;

conv2d_transpose_kernel<<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>(
input_data, kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, kernel_h, stride,
channels_in, channels_out, batches);
}
4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda/conv2d-transpose.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#include "common.cuh"

#define CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE 256
void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
5 changes: 5 additions & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "ggml-cuda/concat.cuh"
#include "ggml-cuda/conv-transpose-1d.cuh"
#include "ggml-cuda/conv2d-dw.cuh"
#include "ggml-cuda/conv2d-transpose.cuh"
#include "ggml-cuda/convert.cuh"
#include "ggml-cuda/count-equal.cuh"
#include "ggml-cuda/cpy.cuh"
Expand Down Expand Up @@ -2314,6 +2315,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_CONV_2D_DW:
ggml_cuda_op_conv2d_dw(ctx, dst);
break;
case GGML_OP_CONV_TRANSPOSE_2D:
ggml_cuda_conv_2d_transpose_p0(ctx, dst);
break;
case GGML_OP_CONV_TRANSPOSE_1D:
ggml_cuda_op_conv_transpose_1d(ctx,dst);
break;
Expand Down Expand Up @@ -3214,6 +3218,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
}
case GGML_OP_IM2COL:
case GGML_OP_CONV_2D_DW:
case GGML_OP_CONV_TRANSPOSE_2D:
case GGML_OP_POOL_2D:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
Expand Down
34 changes: 34 additions & 0 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2725,6 +2725,35 @@ struct test_conv_transpose_1d : public test_case {
}
};

// GGML_OP_CONV_TRANSPOSE_2D
struct test_conv_transpose_2d : public test_case {
const std::array<int64_t, 4> ne_input;
const std::array<int64_t, 4> ne_kernel;
const int stride;

std::string vars() override {
return VARS_TO_STR3(ne_input, ne_kernel, stride);
}

test_conv_transpose_2d(std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
std::array<int64_t, 4> ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1]
int stride = 1)
: ne_input(ne_input), ne_kernel(ne_kernel), stride(stride){}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
ggml_set_name(input, "input");

ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne_kernel.data());
ggml_set_name(kernel, "kernel");

ggml_tensor * out = ggml_conv_transpose_2d_p0(ctx, kernel, input, stride);
ggml_set_name(out, "out");

return out;
}
};

// GGML_OP_IM2COL
struct test_im2col : public test_case {
const ggml_type type_input;
Expand Down Expand Up @@ -4050,6 +4079,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));

test_cases.emplace_back(new test_conv_transpose_2d({3, 2, 3, 1}, {2, 2, 1, 3}, 1));
test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2));

test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 500, 1, 1}));
test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 5000, 1, 1}));

Expand Down Expand Up @@ -4618,6 +4650,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, false));
test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true));

test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1));

return test_cases;
}

Expand Down
Loading