Commit a6196ee

CUDA: add conv_2d_transpose
1 parent 10bb545 commit a6196ee

File tree

4 files changed (+135, -0 lines)
ggml/src/ggml-cuda/conv2d-transpose.cu

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
#include <cuda_fp16.h>

#include <algorithm>

#include "conv2d-transpose.cuh"
#include "ggml.h"

__global__ void conv2d_transpose_kernel(const float * __restrict__ input, const half * __restrict__ kernel,
                                        float * __restrict__ output, const int in_w, const int in_h, const int out_w,
                                        const int out_h, const int kernel_w, const int kernel_h, const int stride,
                                        const int c_in, const int c_out, const int batches) {
    // One thread per output element.
    const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;

    const int total_elements = out_w * out_h * c_out * batches;

    if (global_idx >= total_elements) {
        return;
    }

    // Decompose the flat index into (x, y, channel, batch) output coordinates.
    const int out_x_idx = global_idx % out_w;
    const int out_y_idx = global_idx / out_w % out_h;
    const int c_idx     = global_idx / (out_w * out_h) % c_out;
    const int n_idx     = global_idx / (out_w * out_h * c_out);

    float accumulator = 0;
    // For each output idx, find the inputs that contribute to it by checking stride alignment and bounds.

    for (int c_in_idx = 0; c_in_idx < c_in; c_in_idx++) {
        for (int kh = 0; kh < kernel_h; ++kh) {
            int in_y = out_y_idx - kh;
            if (in_y < 0 || in_y % stride) continue;
            in_y /= stride;
            if (in_y >= in_h) continue;

            for (int kw = 0; kw < kernel_w; ++kw) {
                int in_x = out_x_idx - kw;
                if (in_x < 0 || in_x % stride) continue;
                in_x /= stride;
                if (in_x >= in_w) continue;

                const int input_idx  = (in_w * in_h * c_in) * n_idx + (in_w * in_h) * c_in_idx + in_w * in_y + in_x;
                const int kernel_idx =
                    (kernel_h * kernel_w * c_out) * c_in_idx + (kernel_h * kernel_w) * c_idx + kernel_w * kh + kw;

                float input_val = input[input_idx];
                half  kern_val  = kernel[kernel_idx];

                accumulator += input_val * (float) kern_val;
            }
        }
    }

    output[(out_w * out_h * c_out) * n_idx + (out_w * out_h) * c_idx + out_w * out_y_idx + out_x_idx] = accumulator;
}

// Input is (W, H, C_in, N); kernel is (W, H, C_out, C_in).
void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * kernel = dst->src[0];
    const ggml_tensor * input  = dst->src[1];

    GGML_ASSERT(kernel->type == GGML_TYPE_F16 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);

    const float * input_data  = (const float *) input->data;
    float *       output_data = (float *) dst->data;
    const half *  kernel_data = (const half *) kernel->data;

    const int input_w      = input->ne[0];
    const int input_h      = input->ne[1];
    const int output_w     = dst->ne[0];
    const int output_h     = dst->ne[1];
    const int channels_in  = input->ne[2];
    const int channels_out = kernel->ne[2];
    const int kernel_w     = kernel->ne[0];
    const int kernel_h     = kernel->ne[1];
    const int stride       = dst->op_params[0];
    const int batches      = input->ne[3];

    GGML_ASSERT(channels_in == kernel->ne[3]);
    GGML_ASSERT(stride > 0);

    cudaStream_t st = ctx.stream();

    const int total  = output_w * output_h * channels_out * batches;
    const int blocks = (total + CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE - 1) / CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE;

    conv2d_transpose_kernel<<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>(
        input_data, kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, kernel_h, stride,
        channels_in, channels_out, batches);
}
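The kernel is a gather: each thread owns exactly one output element and accumulates every input-times-kernel product that maps onto it, skipping positions that are out of bounds or not stride-aligned, so no atomics are needed. For a host-side mental model, here is a minimal CPU sketch of the same computation; the function name and flat std::vector layout are illustrative rather than part of the commit, and the weights are taken as float instead of half for simplicity:

#include <vector>

// Hypothetical CPU reference mirroring conv2d_transpose_kernel (names illustrative).
// Layouts match the CUDA kernel: input (W, H, C_in, N), kernel (W, H, C_out, C_in),
// output (W, H, C_out, N); for the p0 (zero-padding) variant,
// out_w = (in_w - 1) * stride + kernel_w, and likewise for the height.
static std::vector<float> conv2d_transpose_ref(const std::vector<float> & input,
                                               const std::vector<float> & kernel,
                                               int in_w, int in_h, int kernel_w, int kernel_h,
                                               int c_in, int c_out, int batches, int stride) {
    const int out_w = (in_w - 1) * stride + kernel_w;
    const int out_h = (in_h - 1) * stride + kernel_h;
    std::vector<float> output((size_t) out_w * out_h * c_out * batches, 0.0f);

    for (int n = 0; n < batches; ++n) {
        for (int c = 0; c < c_out; ++c) {
            for (int y = 0; y < out_h; ++y) {
                for (int x = 0; x < out_w; ++x) {
                    float acc = 0.0f;
                    for (int ci = 0; ci < c_in; ++ci) {
                        for (int kh = 0; kh < kernel_h; ++kh) {
                            for (int kw = 0; kw < kernel_w; ++kw) {
                                const int ty = y - kh;  // candidate strided input row
                                const int tx = x - kw;  // candidate strided input column
                                // Only stride-aligned, in-bounds input positions contribute.
                                if (ty < 0 || tx < 0 || ty % stride || tx % stride) continue;
                                if (ty / stride >= in_h || tx / stride >= in_w) continue;
                                acc += input[((n * c_in + ci) * in_h + ty / stride) * in_w + tx / stride] *
                                       kernel[((ci * c_out + c) * kernel_h + kh) * kernel_w + kw];
                            }
                        }
                    }
                    output[((n * c_out + c) * out_h + y) * out_w + x] = acc;
                }
            }
        }
    }
    return output;
}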
ggml/src/ggml-cuda/conv2d-transpose.cuh

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
#include "common.cuh"

#define CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE 256
void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 5 additions & 0 deletions
@@ -11,6 +11,7 @@
 #include "ggml-cuda/clamp.cuh"
 #include "ggml-cuda/concat.cuh"
 #include "ggml-cuda/conv-transpose-1d.cuh"
+#include "ggml-cuda/conv2d-transpose.cuh"
 #include "ggml-cuda/convert.cuh"
 #include "ggml-cuda/count-equal.cuh"
 #include "ggml-cuda/cpy.cuh"
@@ -2310,6 +2311,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_IM2COL:
             ggml_cuda_op_im2col(ctx, dst);
             break;
+        case GGML_OP_CONV_TRANSPOSE_2D:
+            ggml_cuda_conv_2d_transpose_p0(ctx, dst);
+            break;
         case GGML_OP_CONV_TRANSPOSE_1D:
             ggml_cuda_op_conv_transpose_1d(ctx,dst);
             break;
@@ -3209,6 +3213,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
         }
         case GGML_OP_IM2COL:
+        case GGML_OP_CONV_TRANSPOSE_2D:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
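For context, the op reaches this dispatch through the usual ggml graph API. A minimal, hypothetical host-side construction, mirroring what the test below does and assuming a ggml_context `ctx` with enough memory:

// Input is (W, H, C_in, N); kernel is (W, H, C_out, C_in) and must be F16.
ggml_tensor * input  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 10, 10, 3, 1);
ggml_tensor * kernel = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, 16, 3);

// p0 = zero padding; with stride 2 each spatial dim becomes (10 - 1) * 2 + 3 = 21.
ggml_tensor * out = ggml_conv_transpose_2d_p0(ctx, kernel, input, /*stride =*/ 2);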

tests/test-backend-ops.cpp

Lines changed: 37 additions & 0 deletions
@@ -560,6 +560,7 @@ struct test_case {
        }

        double err = nmse(f1.data(), f2.data(), f1.size());
+
        if (err > ud->max_err) {
            printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err);
            //for (int i = 0; i < (int) f1.size(); i++) {
@@ -2725,6 +2726,37 @@ struct test_conv_transpose_1d : public test_case {
     }
 };

+// GGML_OP_CONV_TRANSPOSE_2D
+struct test_conv_transpose_2d : public test_case {
+    const std::array<int64_t, 4> ne_input;
+    const std::array<int64_t, 4> ne_kernel;
+    const int stride;
+
+    std::string vars() override {
+        return VARS_TO_STR3(ne_input, ne_kernel, stride);
+    }
+
+    test_conv_transpose_2d(std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
+                           std::array<int64_t, 4> ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, output_channels, input_channels]
+                           int stride = 1)
+        : ne_input(ne_input), ne_kernel(ne_kernel), stride(stride) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
+        ggml_set_param(input);
+        ggml_set_name(input, "input");
+
+        ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne_kernel.data());
+        ggml_set_param(kernel);
+        ggml_set_name(kernel, "kernel");
+
+        ggml_tensor * out = ggml_conv_transpose_2d_p0(ctx, kernel, input, stride);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
 // GGML_OP_IM2COL
 struct test_im2col : public test_case {
     const ggml_type type_input;
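Assuming the standard test-backend-ops CLI (not shown in this diff), the new cases can be exercised in isolation with something like `./test-backend-ops test -o CONV_TRANSPOSE_2D`, filtering by the op name that ggml_op_desc reports.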
@@ -4050,6 +4082,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
     test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));

+    test_cases.emplace_back(new test_conv_transpose_2d({3, 2, 3, 1}, {2, 2, 1, 3}, 1));
+    test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2));
+
     test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 500, 1, 1}));
     test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 5000, 1, 1}));
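Of the two eval cases, the stride-1 case covers plain overlap-add, while the stride-2 case maps a 10x10, 9-channel input through a 3x3 kernel to a (10 - 1) * 2 + 3 = 21x21 single-channel output, exercising the kernel's stride-alignment skip.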

@@ -4618,6 +4653,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, false));
     test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true));

+    test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1));
+
     return test_cases;
 }
