@@ -32,8 +32,6 @@ __device__ inline int calculate_input_coord(int out_coord, int kern_coord, int s
     return out_coord * stride + kern_coord * dilation - padding;
 }
 
-// ───────────── Memory layout abstractions ─────────────
-
 struct whcn_layout {
     __device__ static int input_index(int n, int c, int y, int x, const conv_params & params) {
         return n * (params.channels * params.in_w * params.in_h) + c * params.in_w * params.in_h + y * params.in_w + x;
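As an aside, the coordinate math in `calculate_input_coord` is the standard convolution index mapping: output position `out_coord`, sampled at kernel tap `kern_coord`, reads input position `out_coord * stride + kern_coord * dilation - padding`. The host-side sketch below (hypothetical, not part of the patch) mirrors that helper and the WHCN flattened offset computed by `whcn_layout::input_index`:

```cuda
// Standalone illustration of the index math above; all sizes are made up.
#include <cstdio>

static int input_coord(int out_coord, int kern_coord, int stride, int dilation, int padding) {
    return out_coord * stride + kern_coord * dilation - padding;
}

// Flattened offset for a WHCN tensor (x fastest, then y, then channel, then batch),
// matching whcn_layout::input_index.
static int whcn_offset(int n, int c, int y, int x, int channels, int in_w, int in_h) {
    return n * (channels * in_w * in_h) + c * (in_w * in_h) + y * in_w + x;
}

int main() {
    // stride 2, dilation 1, padding 1: output x = 3 at kernel tap 0 reads input x = 5
    printf("in_x   = %d\n", input_coord(3, 0, 2, 1, 1));
    // offset of element (n = 1, c = 2, y = 0, x = 4) in an 8x6 input with 3 channels
    printf("offset = %d\n", whcn_offset(1, 2, 0, 4, 3, 8, 6));
    return 0;
}
```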
@@ -80,40 +78,12 @@ struct cwhn_layout {
     }
 };
 
-// ───────────── Generic convolution computation ─────────────
-
 template <typename T, typename Layout>
-const __device__ inline T compute_conv2d_dw_pixel(const T * __restrict__ input, const T * __restrict__ kernel,
-                                                  const conv_params & params, int batch_idx, int channel_idx,
-                                                  int out_y_idx, int out_x_idx) {
-    T accumulator = 0;
-    kernel_bounds bounds = calculate_kernel_bounds(out_x_idx, out_y_idx, params);
-
-    for (int kern_y = bounds.y_min; kern_y < bounds.y_max; ++kern_y) {
-        int in_y_idx = calculate_input_coord(out_y_idx, kern_y, params.stride_y, params.dilation_y, params.padding_y);
-
-        for (int kern_x = bounds.x_min; kern_x < bounds.x_max; ++kern_x) {
-            int in_x_idx =
-                calculate_input_coord(out_x_idx, kern_x, params.stride_x, params.dilation_x, params.padding_x);
-
-            const T input_val  = input[Layout::input_index(batch_idx, channel_idx, in_y_idx, in_x_idx, params)];
-            const T kernel_val = kernel[Layout::kernel_index(channel_idx, kern_y, kern_x, params)];
-
-            accumulator += input_val * kernel_val;
-        }
-    }
-
-    return accumulator;
-}
-
-// ───────────── Kernel instantiations ─────────────
-
-template <typename T>
-__global__ void conv2d_dw_whcn_kernel(const T * __restrict__ in, const T * __restrict__ kern, T * __restrict__ out,
-                                      const int in_w, const int in_h, const int out_w, const int out_h,
-                                      const int kernel_w, const int kernel_h, const int stride_x, const int stride_y,
-                                      const int padding_x, const int padding_y, const int dilation_x,
-                                      const int dilation_y, const int channels, const int batches) {
+__global__ void conv2d_dw_kernel(const T * __restrict__ input, const T * __restrict__ kernel, T * __restrict__ output,
+                                 const int in_w, const int in_h, const int out_w, const int out_h,
+                                 const int kernel_w, const int kernel_h, const int stride_x, const int stride_y,
+                                 const int padding_x, const int padding_y, const int dilation_x, const int dilation_y,
+                                 const int channels, const int batches) {
     int global_idx     = blockIdx.x * blockDim.x + threadIdx.x;
     int total_elements = batches * channels * out_h * out_w;
 
@@ -125,42 +95,31 @@ __global__ void conv2d_dw_whcn_kernel(const T * __restrict__ in, const T * __res
                            stride_y, padding_x, padding_y, dilation_x, dilation_y, channels, batches };
 
     int batch_idx, channel_idx, out_y_idx, out_x_idx;
-    whcn_layout::unpack_indices(global_idx, params, batch_idx, channel_idx, out_y_idx, out_x_idx);
+    Layout::unpack_indices(global_idx, params, batch_idx, channel_idx, out_y_idx, out_x_idx);
 
-    T result = compute_conv2d_dw_pixel<T, whcn_layout>(in, kern, params, batch_idx, channel_idx, out_y_idx, out_x_idx);
-    out[whcn_layout::output_index(batch_idx, channel_idx, out_y_idx, out_x_idx, params)] = result;
-}
+    T accumulator = 0;
+    kernel_bounds bounds = calculate_kernel_bounds(out_x_idx, out_y_idx, params);
 
-template <typename T>
-__global__ void conv_2d_dw_cwhn_kernel(const T * __restrict__ in, const T * __restrict__ kern, T * __restrict__ out,
-                                       const int in_w, const int in_h, const int out_w, const int out_h,
-                                       const int kernel_w, const int kernel_h, const int stride_x, const int stride_y,
-                                       const int padding_x, const int padding_y, const int dilation_x,
-                                       const int dilation_y, const int channels, const int batches) {
-    int global_idx     = blockIdx.x * blockDim.x + threadIdx.x;
-    int total_elements = batches * channels * out_h * out_w;
+    for (int kern_y = bounds.y_min; kern_y < bounds.y_max; ++kern_y) {
+        int in_y_idx = calculate_input_coord(out_y_idx, kern_y, params.stride_y, params.dilation_y, params.padding_y);
 
-    if (global_idx >= total_elements) {
-        return;
-    }
+        for (int kern_x = bounds.x_min; kern_x < bounds.x_max; ++kern_x) {
+            int in_x_idx = calculate_input_coord(out_x_idx, kern_x, params.stride_x, params.dilation_x, params.padding_x);
 
-    conv_params params = { in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x,
-                           stride_y, padding_x, padding_y, dilation_x, dilation_y, channels, batches };
+            const T input_val  = input[Layout::input_index(batch_idx, channel_idx, in_y_idx, in_x_idx, params)];
+            const T kernel_val = kernel[Layout::kernel_index(channel_idx, kern_y, kern_x, params)];
 
-    int batch_idx, channel_idx, out_y_idx, out_x_idx;
-    cwhn_layout::unpack_indices(global_idx, params, batch_idx, channel_idx, out_y_idx, out_x_idx);
+            accumulator += input_val * kernel_val;
+        }
+    }
 
-    const T result =
-        compute_conv2d_dw_pixel<T, cwhn_layout>(in, kern, params, batch_idx, channel_idx, out_y_idx, out_x_idx);
-    out[cwhn_layout::output_index(batch_idx, channel_idx, out_y_idx, out_x_idx, params)] = result;
+    output[Layout::output_index(batch_idx, channel_idx, out_y_idx, out_x_idx, params)] = accumulator;
 }
 
-// ───────────── dispatcher ─────────────
 void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * kernel = dst->src[0];
     const ggml_tensor * input  = dst->src[1];
 
-    // Only F32→F32 for now
     GGML_ASSERT(kernel->type == GGML_TYPE_F32 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
     const float * w_d = (const float *) kernel->data;
     const float * x_d = (const float *) input->data;
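With the duplicate per-layout kernels gone, everything layout-specific is funnelled through the `Layout` template parameter. The interface that parameter has to satisfy is implicit in the calls above; the declaration-only skeleton below spells it out for reference (the struct name is illustrative and not part of the patch; `whcn_layout` and `cwhn_layout` are the two real implementations in this file):

```cuda
// Sketch of the policy interface assumed by conv2d_dw_kernel<T, Layout>;
// conv_params is the parameter struct defined earlier in this file.
struct example_layout {
    // flattened offset of input element (n, c, y, x)
    __device__ static int input_index(int n, int c, int y, int x, const conv_params & params);
    // flattened offset of kernel element (c, kern_y, kern_x)
    __device__ static int kernel_index(int c, int kern_y, int kern_x, const conv_params & params);
    // flattened offset of output element (n, c, y, x)
    __device__ static int output_index(int n, int c, int y, int x, const conv_params & params);
    // decompose a linear thread index into (batch, channel, out_y, out_x)
    __device__ static void unpack_indices(int global_idx, const conv_params & params,
                                          int & n, int & c, int & out_y, int & out_x);
};
```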
@@ -189,11 +148,11 @@ void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     const int blocks = (total + CUDA_CONV2D_DW_BLOCK_SIZE - 1) / CUDA_CONV2D_DW_BLOCK_SIZE;
 
     if (ggml_is_contiguous(input)) {
-        conv2d_dw_whcn_kernel<<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
+        conv2d_dw_kernel<float, whcn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
             x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
             dilation_x, dilation_y, channels, batches);
     } else if (ggml_is_contiguous_channels(input)) {
-        conv_2d_dw_cwhn_kernel<<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
+        conv2d_dw_kernel<float, cwhn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
             x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
             dilation_x, dilation_y, channels, batches);
     } else {
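One note on the launch configuration, which is unchanged by this patch: the dispatcher maps one thread to one output element and rounds the block count up with the usual ceil-division. A tiny host-side sketch of that arithmetic (the tensor sizes are made up; the block-size value is an assumption standing in for the constant used above):

```cuda
#include <cstdio>

// Same constant as used by the dispatcher above; 256 is assumed here for illustration.
#define CUDA_CONV2D_DW_BLOCK_SIZE 256

int main() {
    // hypothetical depthwise-conv output: 4 batches, 32 channels, 56x56 spatial
    const int batches = 4, channels = 32, out_h = 56, out_w = 56;
    const int total  = batches * channels * out_h * out_w;
    const int blocks = (total + CUDA_CONV2D_DW_BLOCK_SIZE - 1) / CUDA_CONV2D_DW_BLOCK_SIZE;
    // one thread per output element; threads past total_elements exit early in the kernel
    printf("total = %d, blocks = %d\n", total, blocks);
    return 0;
}
```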