Skip to content

Commit 2c60d2c

Browse files
committed
simplify using template
1 parent 6eb7fbb commit 2c60d2c

File tree

1 file changed

+20
-61
lines changed

1 file changed

+20
-61
lines changed

ggml/src/ggml-cuda/conv2d-dw.cu

Lines changed: 20 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@ __device__ inline int calculate_input_coord(int out_coord, int kern_coord, int s
3232
return out_coord * stride + kern_coord * dilation - padding;
3333
}
3434

35-
// ───────────── Memory layout abstractions ─────────────
36-
3735
struct whcn_layout {
3836
__device__ static int input_index(int n, int c, int y, int x, const conv_params & params) {
3937
return n * (params.channels * params.in_w * params.in_h) + c * params.in_w * params.in_h + y * params.in_w + x;
@@ -80,40 +78,12 @@ struct cwhn_layout {
8078
}
8179
};
8280

83-
// ───────────── Generic convolution computation ─────────────
84-
8581
template <typename T, typename Layout>
86-
const __device__ inline T compute_conv2d_dw_pixel(const T * __restrict__ input, const T * __restrict__ kernel,
87-
const conv_params & params, int batch_idx, int channel_idx,
88-
int out_y_idx, int out_x_idx) {
89-
T accumulator = 0;
90-
kernel_bounds bounds = calculate_kernel_bounds(out_x_idx, out_y_idx, params);
91-
92-
for (int kern_y = bounds.y_min; kern_y < bounds.y_max; ++kern_y) {
93-
int in_y_idx = calculate_input_coord(out_y_idx, kern_y, params.stride_y, params.dilation_y, params.padding_y);
94-
95-
for (int kern_x = bounds.x_min; kern_x < bounds.x_max; ++kern_x) {
96-
int in_x_idx =
97-
calculate_input_coord(out_x_idx, kern_x, params.stride_x, params.dilation_x, params.padding_x);
98-
99-
const T input_val = input[Layout::input_index(batch_idx, channel_idx, in_y_idx, in_x_idx, params)];
100-
const T kernel_val = kernel[Layout::kernel_index(channel_idx, kern_y, kern_x, params)];
101-
102-
accumulator += input_val * kernel_val;
103-
}
104-
}
105-
106-
return accumulator;
107-
}
108-
109-
// ───────────── Kernel instantiations ─────────────
110-
111-
template <typename T>
112-
__global__ void conv2d_dw_whcn_kernel(const T * __restrict__ in, const T * __restrict__ kern, T * __restrict__ out,
113-
const int in_w, const int in_h, const int out_w, const int out_h,
114-
const int kernel_w, const int kernel_h, const int stride_x, const int stride_y,
115-
const int padding_x, const int padding_y, const int dilation_x,
116-
const int dilation_y, const int channels, const int batches) {
82+
__global__ void conv2d_dw_kernel(const T * __restrict__ input, const T * __restrict__ kernel, T * __restrict__ output,
83+
const int in_w, const int in_h, const int out_w, const int out_h,
84+
const int kernel_w, const int kernel_h, const int stride_x, const int stride_y,
85+
const int padding_x, const int padding_y, const int dilation_x, const int dilation_y,
86+
const int channels, const int batches) {
11787
int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
11888
int total_elements = batches * channels * out_h * out_w;
11989

@@ -125,42 +95,31 @@ __global__ void conv2d_dw_whcn_kernel(const T * __restrict__ in, const T * __res
12595
stride_y, padding_x, padding_y, dilation_x, dilation_y, channels, batches };
12696

12797
int batch_idx, channel_idx, out_y_idx, out_x_idx;
128-
whcn_layout::unpack_indices(global_idx, params, batch_idx, channel_idx, out_y_idx, out_x_idx);
98+
Layout::unpack_indices(global_idx, params, batch_idx, channel_idx, out_y_idx, out_x_idx);
12999

130-
T result = compute_conv2d_dw_pixel<T, whcn_layout>(in, kern, params, batch_idx, channel_idx, out_y_idx, out_x_idx);
131-
out[whcn_layout::output_index(batch_idx, channel_idx, out_y_idx, out_x_idx, params)] = result;
132-
}
100+
T accumulator = 0;
101+
kernel_bounds bounds = calculate_kernel_bounds(out_x_idx, out_y_idx, params);
133102

134-
template <typename T>
135-
__global__ void conv_2d_dw_cwhn_kernel(const T * __restrict__ in, const T * __restrict__ kern, T * __restrict__ out,
136-
const int in_w, const int in_h, const int out_w, const int out_h,
137-
const int kernel_w, const int kernel_h, const int stride_x, const int stride_y,
138-
const int padding_x, const int padding_y, const int dilation_x,
139-
const int dilation_y, const int channels, const int batches) {
140-
int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
141-
int total_elements = batches * channels * out_h * out_w;
103+
for (int kern_y = bounds.y_min; kern_y < bounds.y_max; ++kern_y) {
104+
int in_y_idx = calculate_input_coord(out_y_idx, kern_y, params.stride_y, params.dilation_y, params.padding_y);
142105

143-
if (global_idx >= total_elements) {
144-
return;
145-
}
106+
for (int kern_x = bounds.x_min; kern_x < bounds.x_max; ++kern_x) {
107+
int in_x_idx = calculate_input_coord(out_x_idx, kern_x, params.stride_x, params.dilation_x, params.padding_x);
146108

147-
conv_params params = { in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x,
148-
stride_y, padding_x, padding_y, dilation_x, dilation_y, channels, batches };
109+
const T input_val = input[Layout::input_index(batch_idx, channel_idx, in_y_idx, in_x_idx, params)];
110+
const T kernel_val = kernel[Layout::kernel_index(channel_idx, kern_y, kern_x, params)];
149111

150-
int batch_idx, channel_idx, out_y_idx, out_x_idx;
151-
cwhn_layout::unpack_indices(global_idx, params, batch_idx, channel_idx, out_y_idx, out_x_idx);
112+
accumulator += input_val * kernel_val;
113+
}
114+
}
152115

153-
const T result =
154-
compute_conv2d_dw_pixel<T, cwhn_layout>(in, kern, params, batch_idx, channel_idx, out_y_idx, out_x_idx);
155-
out[cwhn_layout::output_index(batch_idx, channel_idx, out_y_idx, out_x_idx, params)] = result;
116+
output[Layout::output_index(batch_idx, channel_idx, out_y_idx, out_x_idx, params)] = accumulator;
156117
}
157118

158-
// ───────────── dispatcher ─────────────
159119
void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
160120
const ggml_tensor * kernel = dst->src[0];
161121
const ggml_tensor * input = dst->src[1];
162122

163-
// Only F32→F32 for now
164123
GGML_ASSERT(kernel->type == GGML_TYPE_F32 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
165124
const float * w_d = (const float *) kernel->data;
166125
const float * x_d = (const float *) input->data;
@@ -189,11 +148,11 @@ void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
189148
const int blocks = (total + CUDA_CONV2D_DW_BLOCK_SIZE - 1) / CUDA_CONV2D_DW_BLOCK_SIZE;
190149

191150
if (ggml_is_contiguous(input)) {
192-
conv2d_dw_whcn_kernel<<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
151+
conv2d_dw_kernel<float, whcn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
193152
x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
194153
dilation_x, dilation_y, channels, batches);
195154
} else if (ggml_is_contiguous_channels(input)) {
196-
conv_2d_dw_cwhn_kernel<<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
155+
conv2d_dw_kernel<float, cwhn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
197156
x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
198157
dilation_x, dilation_y, channels, batches);
199158
} else {

0 commit comments

Comments
 (0)