#include "conv2d-dw.cuh"

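// Convolution geometry for one launch, shared by the device helpers below.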
struct conv_params {
    int in_w, in_h;
    int out_w, out_h;
    int kernel_w, kernel_h;
    int stride_x, stride_y;
    int padding_x, padding_y;
    int dilation_x, dilation_y;
    int channels, batches;
};

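// Half-open kernel index ranges [y_min, y_max) x [x_min, x_max) for one output
// element; taps that would fall outside the input are excluded up front so the
// inner loops need no per-iteration bounds checks.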
struct kernel_bounds {
    int y_min, y_max;
    int x_min, x_max;
};

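// The smallest in-bounds kernel index satisfies out*stride + k*dilation - padding >= 0,
// i.e. k_min = ceil((padding - out*stride) / dilation), computed below with the
// (a + d - 1) / d integer-ceiling idiom; k_max is derived the same way from the
// input extent, and the max/min clamp both to [0, kernel size].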
__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int out_x, int out_y, const conv_params & params) {
    kernel_bounds bounds;
    bounds.y_min = max(0, (params.padding_y - out_y * params.stride_y + params.dilation_y - 1) / params.dilation_y);
    bounds.y_max =
        min(params.kernel_h,
            (params.in_h + params.padding_y - out_y * params.stride_y + params.dilation_y - 1) / params.dilation_y);
    bounds.x_min = max(0, (params.padding_x - out_x * params.stride_x + params.dilation_x - 1) / params.dilation_x);
    bounds.x_max =
        min(params.kernel_w,
            (params.in_w + params.padding_x - out_x * params.stride_x + params.dilation_x - 1) / params.dilation_x);
    return bounds;
}

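// Maps an (output coordinate, kernel tap) pair to the input coordinate it reads.
// For example, out_coord = 3, stride = 2, kern_coord = 1, dilation = 1,
// padding = 1 reads input coordinate 3 * 2 + 1 * 1 - 1 = 6.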
__device__ __forceinline__ int calculate_input_coord(int out_coord, int kern_coord, int stride, int dilation, int padding) {
    return out_coord * stride + kern_coord * dilation - padding;
}

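// Index helpers for the WHCN layout (width fastest), which is how a fully
// contiguous ggml tensor is laid out (ne[0] = width varies fastest).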
struct whcn_layout {
    __device__ static int input_index(int n, int c, int y, int x, const conv_params & params) {
        return n * (params.channels * params.in_w * params.in_h) + c * params.in_w * params.in_h + y * params.in_w + x;
    }

    __device__ static int kernel_index(int c, int ky, int kx, const conv_params & params) {
        return c * params.kernel_h * params.kernel_w + ky * params.kernel_w + kx;
    }

    __device__ static int output_index(int n, int c, int y, int x, const conv_params & params) {
        return n * (params.channels * params.out_w * params.out_h) + c * params.out_w * params.out_h +
               y * params.out_w + x;
    }

    __device__ static void unpack_indices(int global_idx, const conv_params & params, int & n, int & c, int & out_y,
                                          int & out_x) {
        out_x = global_idx % params.out_w;
        out_y = (global_idx / params.out_w) % params.out_h;
        c     = (global_idx / (params.out_w * params.out_h)) % params.channels;
        n     = global_idx / (params.out_w * params.out_h * params.channels);
    }
};

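// Index helpers for the CWHN layout (channels fastest), the case matched by
// ggml_is_contiguous_channels() in the dispatcher below.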
struct cwhn_layout {
    __device__ static int input_index(int n, int c, int y, int x, const conv_params & params) {
        return n * (params.channels * params.in_w * params.in_h) + (y * params.in_w + x) * params.channels + c;
    }

    __device__ static int kernel_index(int c, int ky, int kx, const conv_params & params) {
        return (ky * params.kernel_w + kx) * params.channels + c;
    }

    __device__ static int output_index(int n, int c, int y, int x, const conv_params & params) {
        return n * (params.channels * params.out_w * params.out_h) + y * (params.out_w * params.channels) +
               x * params.channels + c;
    }

    __device__ static void unpack_indices(int global_idx, const conv_params & params, int & n, int & c, int & out_y,
                                          int & out_x) {
        c     = global_idx % params.channels;
        out_x = (global_idx / params.channels) % params.out_w;
        out_y = (global_idx / (params.channels * params.out_w)) % params.out_h;
        n     = global_idx / (params.channels * params.out_w * params.out_h);
    }
};

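// One thread per output element: each thread unpacks its flat global index into
// (batch, channel, y, x), accumulates the depthwise dot product over the
// pre-clamped kernel window, and writes a single output value.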
template <typename T, typename Layout>
__global__ void conv2d_dw_kernel(const T * __restrict__ input, const T * __restrict__ kernel, T * __restrict__ output,
                                 const int in_w, const int in_h, const int out_w, const int out_h,
                                 const int kernel_w, const int kernel_h, const int stride_x, const int stride_y,
                                 const int padding_x, const int padding_y, const int dilation_x, const int dilation_y,
                                 const int channels, const int batches) {
    const int global_idx     = blockIdx.x * blockDim.x + threadIdx.x;
    const int total_elements = batches * channels * out_h * out_w;

    if (global_idx >= total_elements) {
        return;
    }

    conv_params params = { in_w,     in_h,      out_w,     out_h,      kernel_w,   kernel_h, stride_x,
                           stride_y, padding_x, padding_y, dilation_x, dilation_y, channels, batches };

    int batch_idx, channel_idx, out_y_idx, out_x_idx;
    Layout::unpack_indices(global_idx, params, batch_idx, channel_idx, out_y_idx, out_x_idx);

    T             accumulator = 0;
    kernel_bounds bounds      = calculate_kernel_bounds(out_x_idx, out_y_idx, params);

    for (int kern_y = bounds.y_min; kern_y < bounds.y_max; ++kern_y) {
        int in_y_idx = calculate_input_coord(out_y_idx, kern_y, params.stride_y, params.dilation_y, params.padding_y);

        for (int kern_x = bounds.x_min; kern_x < bounds.x_max; ++kern_x) {
            int in_x_idx = calculate_input_coord(out_x_idx, kern_x, params.stride_x, params.dilation_x, params.padding_x);

            const T input_val  = input[Layout::input_index(batch_idx, channel_idx, in_y_idx, in_x_idx, params)];
            const T kernel_val = kernel[Layout::kernel_index(channel_idx, kern_y, kern_x, params)];

            accumulator += input_val * kernel_val;
        }
    }

    output[Layout::output_index(batch_idx, channel_idx, out_y_idx, out_x_idx, params)] = accumulator;
}

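// Host entry point: src[0] is the depthwise kernel, src[1] the input, both F32.
// op_params packs { stride_x, stride_y, padding_x, padding_y, dilation_x,
// dilation_y }, and the launch dispatches on the input's memory layout.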
void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * kernel = dst->src[0];
    const ggml_tensor * input  = dst->src[1];

    GGML_ASSERT(kernel->type == GGML_TYPE_F32 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
    const float * w_d = (const float *) kernel->data;
    const float * x_d = (const float *) input->data;
    float *       y_d = (float *) dst->data;

    const int32_t * p          = (const int32_t *) dst->op_params;
    const int       stride_x   = p[0];
    const int       stride_y   = p[1];
    const int       padding_x  = p[2];
    const int       padding_y  = p[3];
    const int       dilation_x = p[4];
    const int       dilation_y = p[5];

    const int in_w     = input->ne[0];
    const int in_h     = input->ne[1];
    const int kernel_w = kernel->ne[0];
    const int kernel_h = kernel->ne[1];
    const int out_w    = dst->ne[0];
    const int out_h    = dst->ne[1];
    const int channels = dst->ne[2];
    const int batches  = dst->ne[3];

    cudaStream_t st = ctx.stream();

    const int total  = batches * channels * out_h * out_w;
    const int blocks = (total + CUDA_CONV2D_DW_BLOCK_SIZE - 1) / CUDA_CONV2D_DW_BLOCK_SIZE;

    if (ggml_is_contiguous(input)) {
        conv2d_dw_kernel<float, whcn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
            x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
            dilation_x, dilation_y, channels, batches);
    } else if (ggml_is_contiguous_channels(input)) {
        conv2d_dw_kernel<float, cwhn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
            x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
            dilation_x, dilation_y, channels, batches);
    } else {
        GGML_ABORT("Unsupported memory layout for conv_2d_dw");
    }
}
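
// Note: the kernel relies on the graph builder having sized dst consistently.
// A sketch of the assumed relation (standard convolution output arithmetic,
// not checked here):
//   out_w == (in_w + 2 * padding_x - dilation_x * (kernel_w - 1) - 1) / stride_x + 1
// and analogously for out_h with the y-direction parameters.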