-
Notifications
You must be signed in to change notification settings - Fork 12.2k
CUDA: add conv_2d_transpose #14287
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
CUDA: add conv_2d_transpose #14287
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
#include <algorithm> | ||
|
||
#include "conv2d-transpose.cuh" | ||
#include "ggml.h" | ||
|
||
// Transposed 2D convolution, gather formulation: one thread per output element.
// input is (in_w, in_h, c_in, batches) F32, kernel is (kernel_w, kernel_h, c_out, c_in) F16,
// output is (out_w, out_h, c_out, batches) F32; all tensors contiguous.
// Launched with a 1D grid covering out_w * out_h * c_out * batches threads.
__global__ void conv2d_transpose_kernel(const float * __restrict__ input, const half * __restrict__ kernel,
                                        float * __restrict__ output, const int in_w, const int in_h, const int out_w,
                                        const int out_h, const int kernel_w, const int kernel_h, const int stride,
                                        const int c_in, const int c_out, const int batches) {
    const int tid     = blockIdx.x * blockDim.x + threadIdx.x;
    const int n_total = out_w * out_h * c_out * batches;

    // Grid tail guard: the grid is rounded up to a whole number of blocks.
    if (tid >= n_total) {
        return;
    }

    // Decompose the flat index: W is fastest, then H, then C_out, then N.
    const int ox = tid % out_w;
    const int oy = (tid / out_w) % out_h;
    const int oc = (tid / (out_w * out_h)) % c_out;
    const int n  = tid / (out_w * out_h * c_out);

    float sum = 0.0f;

    // Instead of scattering each input pixel (as the math defines the op), each
    // output pixel gathers every (c_in, kh, kw) combination whose strided
    // scatter would land on it: the source coordinate must be non-negative,
    // stride-aligned, and inside the input.
    for (int ic = 0; ic < c_in; ++ic) {
        const int in_base  = (in_w * in_h) * (c_in * n + ic);
        const int ker_base = (kernel_h * kernel_w) * (c_out * ic + oc);

        for (int kh = 0; kh < kernel_h; ++kh) {
            const int sy = oy - kh;
            if (sy < 0 || sy % stride != 0) {
                continue;
            }
            const int iy = sy / stride;
            if (iy >= in_h) {
                continue;
            }

            for (int kw = 0; kw < kernel_w; ++kw) {
                const int sx = ox - kw;
                if (sx < 0 || sx % stride != 0) {
                    continue;
                }
                const int ix = sx / stride;
                if (ix >= in_w) {
                    continue;
                }

                const float in_val  = input[in_base + in_w * iy + ix];
                const half  ker_val = kernel[ker_base + kernel_w * kh + kw];

                // Widen the fp16 weight and accumulate in fp32.
                sum += in_val * (float) ker_val;
            }
        }
    }

    output[(out_w * out_h) * (c_out * n + oc) + out_w * oy + ox] = sum;
}
|
||
// Host launcher for the transposed 2D convolution.
// dst->src[0] is the kernel, laid out (W, H, C_out, C_in) in F16;
// dst->src[1] is the input, laid out (W, H, C_in, N) in F32.
void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * kernel = dst->src[0];
    const ggml_tensor * input  = dst->src[1];

    // F16 weights with F32 activations and output.
    GGML_ASSERT(kernel->type == GGML_TYPE_F16 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);

    const float * src_d = (const float *) input->data;
    const half  * ker_d = (const half *) kernel->data;
    float *       dst_d = (float *) dst->data;

    const int in_w   = input->ne[0];
    const int in_h   = input->ne[1];
    const int c_in   = input->ne[2];
    const int n      = input->ne[3];
    const int out_w  = dst->ne[0];
    const int out_h  = dst->ne[1];
    const int ker_w  = kernel->ne[0];
    const int ker_h  = kernel->ne[1];
    const int c_out  = kernel->ne[2];
    const int stride = dst->op_params[0];  // single symmetric stride

    // Kernel's C_in dimension must match the input's channel count.
    GGML_ASSERT(c_in == kernel->ne[3]);
    GGML_ASSERT(stride > 0);

    cudaStream_t st = ctx.stream();

    // The kernel indexes with dense row-major math; no strided views allowed.
    GGML_ASSERT(ggml_is_contiguous(input));
    GGML_ASSERT(ggml_is_contiguous(kernel));
    GGML_ASSERT(ggml_is_contiguous(dst));

    // One thread per output element; round the grid up to whole blocks.
    const int n_out    = out_w * out_h * c_out * n;
    const int n_blocks = (n_out + CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE - 1) / CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE;

    conv2d_transpose_kernel<<<n_blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>(
        src_d, ker_d, dst_d, in_w, in_h, out_w, out_h, ker_w, ker_h, stride, c_in, c_out, n);
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
#include "common.cuh"

// Threads per block used when launching the conv2d-transpose kernel.
#define CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE 256

// Computes a transposed 2D convolution for dst on the context's CUDA stream.
// Operands come from dst->src: src[0] is the F16 kernel, src[1] the F32 input.
void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.