Skip to content

Commit c959f46

Browse files
authored
CUDA: add conv_2d_transpose (#14287)
* CUDA: add conv_2d_transpose * remove direct include of cuda_fp16 * Review: add brackets for readability, remove ggml_set_param and add asserts
1 parent 22015b2 commit c959f46

File tree

4 files changed

+134
-0
lines changed

4 files changed

+134
-0
lines changed
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
#include <algorithm>
2+
3+
#include "conv2d-transpose.cuh"
4+
#include "ggml.h"
5+
6+
__global__ void conv2d_transpose_kernel(const float * __restrict__ input, const half * __restrict__ kernel,
7+
float * __restrict__ output, const int in_w, const int in_h, const int out_w,
8+
const int out_h, const int kernel_w, const int kernel_h, const int stride,
9+
const int c_in, const int c_out, const int batches) {
10+
const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
11+
12+
const int total_elements = out_w * out_h * c_out * batches;
13+
14+
if (global_idx >= total_elements) {
15+
return;
16+
}
17+
18+
const int out_x_idx = global_idx % out_w;
19+
const int out_y_idx = (global_idx / out_w) % out_h;
20+
const int c_idx = (global_idx / (out_w * out_h)) % c_out;
21+
const int n_idx = global_idx / (out_w * out_h * c_out);
22+
23+
float accumulator = 0;
24+
// For each output idx, find the inputs that contribute to it by checking stride alignment and bounds
25+
26+
for (int c_in_idx = 0; c_in_idx < c_in; c_in_idx++) {
27+
for (int kh = 0; kh < kernel_h; ++kh) {
28+
int in_y = out_y_idx - kh;
29+
if (in_y < 0 || in_y % stride) continue;
30+
in_y /= stride;
31+
if (in_y >= in_h) continue;
32+
33+
for (int kw = 0; kw < kernel_w; ++kw) {
34+
int in_x = out_x_idx - kw;
35+
if (in_x < 0 || in_x % stride) continue;
36+
in_x /= stride;
37+
if (in_x >= in_w) continue;
38+
39+
const int input_idx = (in_w * in_h * c_in) * n_idx + (in_w * in_h) * c_in_idx + (in_w) *in_y + in_x;
40+
const int kernel_idx =
41+
(kernel_h * kernel_w * c_out) * c_in_idx + (kernel_h * kernel_w) * c_idx + (kernel_w) *kh + kw;
42+
43+
float input_val = input[input_idx];
44+
half kern_val = kernel[kernel_idx];
45+
46+
accumulator += input_val * (float) kern_val;
47+
}
48+
}
49+
}
50+
51+
output[(out_w * out_h * c_out) * n_idx + (out_w * out_h) * c_idx + (out_w) *out_y_idx + out_x_idx] = accumulator;
52+
}
53+
54+
//input is (W, H, C_in, N), Kernel is (W, H, C_out, C_in)
55+
void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
56+
const ggml_tensor * kernel = dst->src[0];
57+
const ggml_tensor * input = dst->src[1];
58+
59+
GGML_ASSERT(kernel->type == GGML_TYPE_F16 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
60+
61+
const float * input_data = (const float *) input->data;
62+
float * output_data = (float *) dst->data;
63+
const half * kernel_data = (const half *) kernel->data;
64+
65+
const int input_w = input->ne[0];
66+
const int input_h = input->ne[1];
67+
const int output_w = dst->ne[0];
68+
const int output_h = dst->ne[1];
69+
const int channels_in = input->ne[2];
70+
const int channels_out = kernel->ne[2];
71+
const int kernel_w = kernel->ne[0];
72+
const int kernel_h = kernel->ne[1];
73+
const int stride = dst->op_params[0];
74+
const int batches = input->ne[3];
75+
76+
GGML_ASSERT(channels_in == kernel->ne[3]);
77+
GGML_ASSERT(stride > 0);
78+
79+
cudaStream_t st = ctx.stream();
80+
81+
GGML_ASSERT(ggml_is_contiguous(input));
82+
GGML_ASSERT(ggml_is_contiguous(kernel));
83+
GGML_ASSERT(ggml_is_contiguous(dst));
84+
85+
const int total = (output_w * output_h * channels_out * batches);
86+
const int blocks = (total + CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE - 1) / CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE;
87+
88+
conv2d_transpose_kernel<<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>(
89+
input_data, kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, kernel_h, stride,
90+
channels_in, channels_out, batches);
91+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#include "common.cuh"
2+
3+
#define CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE 256
4+
void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "ggml-cuda/concat.cuh"
1313
#include "ggml-cuda/conv-transpose-1d.cuh"
1414
#include "ggml-cuda/conv2d-dw.cuh"
15+
#include "ggml-cuda/conv2d-transpose.cuh"
1516
#include "ggml-cuda/convert.cuh"
1617
#include "ggml-cuda/count-equal.cuh"
1718
#include "ggml-cuda/cpy.cuh"
@@ -2341,6 +2342,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
23412342
case GGML_OP_CONV_2D_DW:
23422343
ggml_cuda_op_conv2d_dw(ctx, dst);
23432344
break;
2345+
case GGML_OP_CONV_TRANSPOSE_2D:
2346+
ggml_cuda_conv_2d_transpose_p0(ctx, dst);
2347+
break;
23442348
case GGML_OP_CONV_TRANSPOSE_1D:
23452349
ggml_cuda_op_conv_transpose_1d(ctx,dst);
23462350
break;
@@ -3252,6 +3256,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
32523256
}
32533257
case GGML_OP_IM2COL:
32543258
case GGML_OP_CONV_2D_DW:
3259+
case GGML_OP_CONV_TRANSPOSE_2D:
32553260
case GGML_OP_POOL_2D:
32563261
case GGML_OP_SUM:
32573262
case GGML_OP_SUM_ROWS:

tests/test-backend-ops.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2725,6 +2725,35 @@ struct test_conv_transpose_1d : public test_case {
27252725
}
27262726
};
27272727

2728+
// GGML_OP_CONV_TRANSPOSE_2D
2729+
struct test_conv_transpose_2d : public test_case {
2730+
const std::array<int64_t, 4> ne_input;
2731+
const std::array<int64_t, 4> ne_kernel;
2732+
const int stride;
2733+
2734+
std::string vars() override {
2735+
return VARS_TO_STR3(ne_input, ne_kernel, stride);
2736+
}
2737+
2738+
test_conv_transpose_2d(std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
2739+
std::array<int64_t, 4> ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1]
2740+
int stride = 1)
2741+
: ne_input(ne_input), ne_kernel(ne_kernel), stride(stride){}
2742+
2743+
ggml_tensor * build_graph(ggml_context * ctx) override {
2744+
ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
2745+
ggml_set_name(input, "input");
2746+
2747+
ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne_kernel.data());
2748+
ggml_set_name(kernel, "kernel");
2749+
2750+
ggml_tensor * out = ggml_conv_transpose_2d_p0(ctx, kernel, input, stride);
2751+
ggml_set_name(out, "out");
2752+
2753+
return out;
2754+
}
2755+
};
2756+
27282757
// GGML_OP_IM2COL
27292758
struct test_im2col : public test_case {
27302759
const ggml_type type_input;
@@ -4050,6 +4079,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
40504079
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
40514080
test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
40524081

4082+
test_cases.emplace_back(new test_conv_transpose_2d({3, 2, 3, 1}, {2, 2, 1, 3}, 1));
4083+
test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2));
4084+
40534085
test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 500, 1, 1}));
40544086
test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 5000, 1, 1}));
40554087

@@ -4618,6 +4650,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
46184650
test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, false));
46194651
test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true));
46204652

4653+
test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1));
4654+
46214655
return test_cases;
46224656
}
46234657

0 commit comments

Comments
 (0)