Skip to content

Commit 62733f2

Browse files
committed
vulkan: improve im2col performance
1 parent 19d3c82 commit 62733f2

File tree

1 file changed

+32
-18
lines changed

1 file changed

+32
-18
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,20 @@ void main() {
4040
const uint batch = gl_GlobalInvocationID.z / p.IC;
4141
const uint ic = gl_GlobalInvocationID.z % p.IC;
4242

43+
const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
44+
const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH);
45+
const int oh_s1 = int(oh) * p.s1;
46+
const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1);
47+
48+
const uint base_linear_idx = gidx * NUM_ITER;
49+
50+
const uint max_ky = ksize / p.OW;
51+
52+
uint current_kx = base_linear_idx / ksize;
53+
const uint rem = base_linear_idx - (current_kx * ksize);
54+
uint current_ky = rem / p.OW;
55+
uint current_ix = rem % p.OW;
56+
4357
A_TYPE values[NUM_ITER];
4458
uint offset_dst[NUM_ITER];
4559
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
@@ -48,36 +62,36 @@ void main() {
4862

4963
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
5064

51-
const uint i = gidx * NUM_ITER + idx;
65+
const uint linear_idx = base_linear_idx + idx;
5266

53-
const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1);
54-
const uint kx = i / ksize;
55-
const uint kd = kx * ksize;
56-
const uint ky = (i - kd) / p.OW;
57-
const uint ix = i % p.OW;
67+
if (linear_idx >= p.pelements) {
68+
continue;
69+
}
5870

59-
const uint iiw = ix * p.s0 + kx * p.d0 - p.p0;
60-
const uint iih = oh * p.s1 + ky * p.d1 - p.p1;
71+
const int iiw = int(current_ix) * p.s0 + int(current_kx) * p.d0 - p.p0;
72+
const int iih = oh_s1 + int(current_ky) * p.d1 - p.p1;
6173

62-
offset_dst[idx] =
63-
((batch * p.OH + oh) * p.OW + ix) * p.CHW +
64-
(ic * (p.KW * p.KH) + ky * p.KW + kx);
74+
offset_dst[idx] = dst_base + current_ix * p.CHW + current_ky * p.KW + current_kx;
6575

66-
if (i >= p.pelements) {
67-
continue;
76+
const bool valid = (iih >= 0 && iih < int(p.IH)) && (iiw >= 0 && iiw < int(p.IW));
77+
if (valid) {
78+
values[idx] = data_a[src_base + uint(iih) * p.IW + uint(iiw)];
6879
}
6980

70-
if (iih < p.IH && iiw < p.IW) {
71-
const uint offset_src = ic * p.offset_delta + batch * p.batch_offset;
72-
values[idx] = data_a[offset_src + iih * p.IW + iiw];
81+
if (++current_ix == p.OW) {
82+
current_ix = 0;
83+
if (++current_ky == max_ky) {
84+
current_ky = 0;
85+
current_kx++;
86+
}
7387
}
7488
}
7589

7690
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
7791

78-
const uint i = gidx * NUM_ITER + idx;
92+
const uint linear_idx = base_linear_idx + idx;
7993

80-
if (i >= p.pelements) {
94+
if (linear_idx >= p.pelements) {
8195
continue;
8296
}
8397

0 commit comments

Comments
 (0)