Skip to content

Commit c6e8cc2

Browse files
Acly authored and ggerganov committed
ggml : Depthwise 2D convolution (ggml/1152)
* ggml-cpu : kernels for faster depthwise 2D convolution * fix compile: remove static after moving to ops.cpp * add dilation for depthwise_conv_2d * review: rename to ggml_conv_2d_dw_direct, remove redundant struct keywords, pass by ref, whitespace * review: rename depthwise_conv_2d -> conv_2d_dw everywhere
1 parent b10d8bf commit c6e8cc2

File tree

5 files changed

+250
-3
lines changed

5 files changed

+250
-3
lines changed

ggml/include/ggml.h

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,7 @@ extern "C" {
481481
GGML_OP_CONV_TRANSPOSE_1D,
482482
GGML_OP_IM2COL,
483483
GGML_OP_IM2COL_BACK,
484+
GGML_OP_CONV_2D_DW,
484485
GGML_OP_CONV_TRANSPOSE_2D,
485486
GGML_OP_POOL_1D,
486487
GGML_OP_POOL_2D,
@@ -677,6 +678,9 @@ extern "C" {
677678
GGML_API bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1
678679
GGML_API bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2
679680

681+
// true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN
682+
GGML_API bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor);
683+
680684
GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
681685
GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
682686

@@ -1660,7 +1664,7 @@ extern "C" {
16601664
struct ggml_tensor * a,
16611665
struct ggml_tensor * b);
16621666

1663-
// depthwise
1667+
// depthwise (via im2col and mul_mat)
16641668
GGML_API struct ggml_tensor * ggml_conv_2d_dw(
16651669
struct ggml_context * ctx,
16661670
struct ggml_tensor * a, // convolution kernel
@@ -1672,6 +1676,22 @@ extern "C" {
16721676
int d0, // dilation dimension 0
16731677
int d1); // dilation dimension 1
16741678

1679+
// Depthwise 2D convolution
1680+
// may be faster than ggml_conv_2d_dw, but not available in all backends
1681+
// a: KW KH 1 C convolution kernel
1682+
// b: W H C N input data
1683+
// res: W_out H_out C N
1684+
GGML_API struct ggml_tensor * ggml_conv_2d_dw_direct(
1685+
struct ggml_context * ctx,
1686+
struct ggml_tensor * a,
1687+
struct ggml_tensor * b,
1688+
int stride0,
1689+
int stride1,
1690+
int pad0,
1691+
int pad1,
1692+
int dilation0,
1693+
int dilation1);
1694+
16751695
GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0(
16761696
struct ggml_context * ctx,
16771697
struct ggml_tensor * a,

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1932,6 +1932,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
19321932
{
19331933
ggml_compute_forward_im2col_back_f32(params, tensor);
19341934
} break;
1935+
case GGML_OP_CONV_2D_DW:
1936+
{
1937+
ggml_compute_forward_conv_2d_dw(params, tensor);
1938+
} break;
19351939
case GGML_OP_CONV_TRANSPOSE_2D:
19361940
{
19371941
ggml_compute_forward_conv_transpose_2d(params, tensor);
@@ -2268,6 +2272,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
22682272
} break;
22692273
case GGML_OP_IM2COL:
22702274
case GGML_OP_IM2COL_BACK:
2275+
case GGML_OP_CONV_2D_DW:
22712276
case GGML_OP_CONV_TRANSPOSE_1D:
22722277
case GGML_OP_CONV_TRANSPOSE_2D:
22732278
{

ggml/src/ggml-cpu/ops.cpp

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6064,6 +6064,178 @@ void ggml_compute_forward_conv_transpose_2d(
60646064
}
60656065
}
60666066

6067+
// ggml_compute_forward_conv_2d_dw
6068+
6069+
// Geometry and hyper-parameters of one depthwise 2D convolution,
// gathered once from the kernel/src/dst tensors and the op params.
struct ggml_conv_2d_dw_params {
    int64_t channels;   // C: one kernel slice per channel (depthwise)
    int64_t batch;      // N
    int64_t src_w;      // input width
    int64_t src_h;      // input height
    int64_t dst_w;      // output width
    int64_t dst_h;      // output height
    int64_t knl_w;      // kernel width
    int64_t knl_h;      // kernel height
    int stride_x;
    int stride_y;
    int pad_x;
    int pad_y;
    int dilation_x;
    int dilation_y;
};
6085+
6086+
// Depthwise conv for the channels-most-contiguous (CWHN) memory layout.
// Output rows (dst_h * batch of them) are split across threads; the channel
// dimension is innermost in memory and is the SIMD-vectorized loop.
static void ggml_compute_forward_conv_2d_dw_cwhn(
        const ggml_compute_params * params,
        const ggml_tensor * src,
        const ggml_tensor * kernel,
        ggml_tensor * dst,
        const ggml_conv_2d_dw_params & p) {

    const int64_t c = p.channels;
    const float * knl_data = (const float *)kernel->data;

    // distribute (dst_h * batch) output rows over the threads
    const int64_t rows_total = p.dst_h * p.batch;
    const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
    const int64_t row_start = params->ith * rows_per_thread;
    const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);

#ifdef GGML_SIMD
    // process channels in packets of GGML_F32_EPR floats; the tail
    // (c % pkg_size channels) is handled by the scalar loop below
    const int64_t pkg_size = GGML_F32_EPR;
    const int64_t pkg_count = c / pkg_size;
    const int64_t c_pkg_end = pkg_count * pkg_size;
#else
    const int64_t c_pkg_end = 0;
#endif

    for (int64_t row = row_start; row < row_end; ++row) {
        const int64_t dst_y = row % p.dst_h;
        // row / p.dst_h selects the batch element
        const float * src_data = (const float *)src->data + (row / p.dst_h) * p.src_w * p.src_h * c;
        for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) {
            float * dst_data = (float *)dst->data + (row * p.dst_w + dst_x) * c;
            // top-left source coordinate covered by the kernel window
            const int64_t src_y_base = dst_y * p.stride_y - p.pad_y;
            const int64_t src_x_base = dst_x * p.stride_x - p.pad_x;

#ifdef GGML_SIMD
            // Vectorized loop
            for (int64_t c_i = 0; c_i < c_pkg_end; c_i += pkg_size) {
                GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
                for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
                    const int64_t src_y = src_y_base + knl_y * p.dilation_y;
                    if (src_y < 0 || src_y >= p.src_h) {
                        continue; // implicit zero padding
                    }
                    for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
                        const int64_t src_x = src_x_base + knl_x * p.dilation_x;
                        if (src_x < 0 || src_x >= p.src_w) {
                            continue;
                        }
                        GGML_F32_VEC k = GGML_F32_VEC_LOAD(knl_data + (knl_y * p.knl_w + knl_x) * c + c_i);
                        GGML_F32_VEC s = GGML_F32_VEC_LOAD(src_data + (src_y * p.src_w + src_x) * c + c_i);
                        sum = GGML_F32_VEC_FMA(sum, k, s);
                    }
                }
                GGML_F32_VEC_STORE(dst_data + c_i, sum);
            }
#endif
            // Scalar loop (channel tail, or all channels without GGML_SIMD)
            for (int64_t c_i = c_pkg_end; c_i < c; ++c_i) {
                float sum = 0.0f;
                for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
                    const int64_t src_y = src_y_base + knl_y * p.dilation_y;
                    if (src_y < 0 || src_y >= p.src_h) {
                        continue;
                    }
                    for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
                        const int64_t src_x = src_x_base + knl_x * p.dilation_x;
                        if (src_x < 0 || src_x >= p.src_w) {
                            continue;
                        }
                        sum += knl_data[(knl_y * p.knl_w + knl_x) * c + c_i]
                             * src_data[(src_y * p.src_w + src_x) * c + c_i];
                    }
                }
                dst_data[c_i] = sum;
            }
        }
    }
}
6161+
6162+
static void ggml_compute_forward_conv_2d_dw_whcn(
6163+
const ggml_compute_params * params,
6164+
const ggml_tensor * src,
6165+
const ggml_tensor * kernel,
6166+
ggml_tensor * dst,
6167+
const ggml_conv_2d_dw_params & p) {
6168+
6169+
const int64_t n = p.channels * p.batch;
6170+
const int64_t per_thread = (n + params->nth - 1) / params->nth;
6171+
const int64_t start = params->ith * per_thread;
6172+
const int64_t end = MIN(start + per_thread, n);
6173+
6174+
for (int64_t i = start; i < end; ++i) {
6175+
const float * knl_data = (const float *)kernel->data + (i % p.channels) * p.knl_w * p.knl_h;
6176+
const float * src_data = (const float *)src->data + i * p.src_w * p.src_h;
6177+
float * dst_data = (float *)dst->data + i * p.dst_w * p.dst_h;
6178+
6179+
for (int64_t dst_y = 0; dst_y < p.dst_h; ++dst_y) {
6180+
for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) {
6181+
6182+
float sum = 0.0f;
6183+
for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
6184+
const int64_t src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
6185+
if (src_y < 0 || src_y >= p.src_h) {
6186+
continue;
6187+
}
6188+
for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
6189+
const int64_t src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
6190+
if (src_x < 0 || src_x >= p.src_w) {
6191+
continue;
6192+
}
6193+
sum += knl_data[knl_y * p.knl_w + knl_x]
6194+
* src_data[src_y * p.src_w + src_x];
6195+
}
6196+
}
6197+
dst_data[dst_y * p.dst_w + dst_x] = sum;
6198+
}
6199+
}
6200+
}
6201+
}
6202+
6203+
void ggml_compute_forward_conv_2d_dw(
6204+
const ggml_compute_params * params,
6205+
ggml_tensor * dst) {
6206+
6207+
const ggml_tensor * kernel = dst->src[0];
6208+
const ggml_tensor * src = dst->src[1];
6209+
ggml_conv_2d_dw_params p;
6210+
p.channels = src->ne[2];
6211+
p.batch = src->ne[3];
6212+
p.src_w = src->ne[0];
6213+
p.src_h = src->ne[1];
6214+
p.dst_w = dst->ne[0];
6215+
p.dst_h = dst->ne[1];
6216+
p.knl_w = kernel->ne[0];
6217+
p.knl_h = kernel->ne[1];
6218+
p.stride_x = dst->op_params[0];
6219+
p.stride_y = dst->op_params[1];
6220+
p.pad_x = dst->op_params[2];
6221+
p.pad_y = dst->op_params[3];
6222+
p.dilation_x = dst->op_params[4];
6223+
p.dilation_y = dst->op_params[5];
6224+
6225+
GGML_ASSERT(kernel->ne[3] == p.channels);
6226+
GGML_ASSERT(dst->ne[3] == p.batch);
6227+
6228+
if (ggml_is_contiguous(src)) {
6229+
ggml_compute_forward_conv_2d_dw_whcn(params, src, kernel, dst, p);
6230+
} else if (ggml_is_contiguous_channels(src)) {
6231+
// kernel should also have channels most contiguous in memory
6232+
GGML_ASSERT(kernel->nb[0] >= kernel->nb[2] && kernel->nb[1] >= kernel->nb[0]);
6233+
ggml_compute_forward_conv_2d_dw_cwhn(params, src, kernel, dst, p);
6234+
} else {
6235+
GGML_ABORT("non-contiguous memory layout not supported");
6236+
}
6237+
}
6238+
60676239
// ggml_compute_forward_pool_1d_sk_p0
60686240

60696241
static void ggml_compute_forward_pool_1d_sk_p0(

ggml/src/ggml-cpu/ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * p
6565
void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst);
6666
void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst);
6767
void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
68+
void ggml_compute_forward_conv_2d_dw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
6869
void ggml_compute_forward_pool_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
6970
void ggml_compute_forward_pool_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
7071
void ggml_compute_forward_pool_2d_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);

ggml/src/ggml.c

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -956,6 +956,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
956956
"CONV_TRANSPOSE_1D",
957957
"IM2COL",
958958
"IM2COL_BACK",
959+
"CONV_2D_DW",
959960
"CONV_TRANSPOSE_2D",
960961
"POOL_1D",
961962
"POOL_2D",
@@ -993,7 +994,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
993994
"OPT_STEP_ADAMW",
994995
};
995996

996-
static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
997+
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
997998

998999
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
9991000
"none",
@@ -1050,6 +1051,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
10501051
"conv_transpose_1d(x)",
10511052
"im2col(x)",
10521053
"im2col_back(x)",
1054+
"conv_2d_dw(x)",
10531055
"conv_transpose_2d(x)",
10541056
"pool_1d(x)",
10551057
"pool_2d(x)",
@@ -1087,7 +1089,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
10871089
"adamw(x)",
10881090
};
10891091

1090-
static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
1092+
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
10911093

10921094
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
10931095

@@ -1344,6 +1346,13 @@ bool ggml_is_permuted(const struct ggml_tensor * tensor) {
13441346
return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
13451347
}
13461348

1349+
bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor) {
1350+
return
1351+
tensor->nb[0] > tensor->nb[2] &&
1352+
tensor->nb[1] > tensor->nb[0] &&
1353+
tensor->nb[2] == ggml_type_size(tensor->type);
1354+
}
1355+
13471356
static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
13481357
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
13491358

@@ -4050,6 +4059,46 @@ struct ggml_tensor * ggml_conv_2d_dw(
40504059
return result;
40514060
}
40524061

4062+
// ggml_conv_2d_dw_direct
4063+
4064+
// Build a direct depthwise 2D convolution node (GGML_OP_CONV_2D_DW).
// a:   KW x KH x 1 x C  convolution kernel
// b:   W  x H  x C x N  input data
// res: W_out x H_out x C x N, in the same (WHCN or CWHN) layout as b
struct ggml_tensor * ggml_conv_2d_dw_direct(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
        struct ggml_tensor * b,
        int stride0,
        int stride1,
        int pad0,
        int pad1,
        int dilation0,
        int dilation1) {
    GGML_ASSERT(a->ne[2] == 1);        // depthwise kernel has a single plane per channel
    GGML_ASSERT(a->ne[3] == b->ne[2]); // one kernel slice per input channel
    int64_t ne[4];
    ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], stride0, pad0, dilation0);
    ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], stride1, pad1, dilation1);
    ne[2] = b->ne[2];
    ne[3] = b->ne[3];

    struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne);

    if (ggml_is_contiguous_channels(b)) {
        // Result will be permuted the same way as input (CWHN order)
        const int64_t type_size = ggml_type_size(result->type);
        GGML_ASSERT(ggml_blck_size(result->type) == 1);
        // override the default contiguous strides: channels innermost,
        // then width, then height (nb[3] from ggml_new_tensor is still valid)
        result->nb[0] = result->ne[2] * type_size;
        result->nb[1] = result->ne[0] * result->nb[0];
        result->nb[2] = type_size;
    }

    int32_t params[] = { stride0, stride1, pad0, pad1, dilation0, dilation1 };
    ggml_set_op_params(result, params, sizeof(params));

    result->op = GGML_OP_CONV_2D_DW;
    result->src[0] = a;
    result->src[1] = b;
    return result;
}
4101+
40534102
// ggml_conv_transpose_2d_p0
40544103

40554104
static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {

0 commit comments

Comments
 (0)