Skip to content

Commit 6aa392d

Browse files
daniandthewebmglambda
authored andcommitted
vulkan: improve im2col (ggml-org#11826)
* vulkan: improve im2col performance
1 parent 1063c4a commit 6aa392d

File tree

1 file changed

+31
-18
lines changed

1 file changed

+31
-18
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,20 @@ void main() {
4040
const uint batch = gl_GlobalInvocationID.z / p.IC;
4141
const uint ic = gl_GlobalInvocationID.z % p.IC;
4242

43+
const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
44+
const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH);
45+
const int oh_s1 = int(oh) * p.s1;
46+
const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1);
47+
48+
const uint base_linear_idx = gidx * NUM_ITER;
49+
50+
const uint max_ky = ksize / p.OW;
51+
52+
uint current_kx = base_linear_idx / ksize;
53+
const uint rem = base_linear_idx - (current_kx * ksize);
54+
uint current_ky = rem / p.OW;
55+
uint current_ix = rem % p.OW;
56+
4357
A_TYPE values[NUM_ITER];
4458
uint offset_dst[NUM_ITER];
4559
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
@@ -48,36 +62,35 @@ void main() {
4862

4963
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
5064

51-
const uint i = gidx * NUM_ITER + idx;
65+
const uint linear_idx = base_linear_idx + idx;
5266

53-
const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1);
54-
const uint kx = i / ksize;
55-
const uint kd = kx * ksize;
56-
const uint ky = (i - kd) / p.OW;
57-
const uint ix = i % p.OW;
67+
if (linear_idx >= p.pelements) {
68+
continue;
69+
}
5870

59-
const uint iiw = ix * p.s0 + kx * p.d0 - p.p0;
60-
const uint iih = oh * p.s1 + ky * p.d1 - p.p1;
71+
const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0;
72+
const uint iih = oh_s1 + current_ky * p.d1 - p.p1;
6173

62-
offset_dst[idx] =
63-
((batch * p.OH + oh) * p.OW + ix) * p.CHW +
64-
(ic * (p.KW * p.KH) + ky * p.KW + kx);
74+
offset_dst[idx] = dst_base + current_ix * p.CHW + current_ky * p.KW + current_kx;
6575

66-
if (i >= p.pelements) {
67-
continue;
76+
if ((iih < p.IH) && (iiw < p.IW)) {
77+
values[idx] = data_a[src_base + iih * p.IW + iiw];
6878
}
6979

70-
if (iih < p.IH && iiw < p.IW) {
71-
const uint offset_src = ic * p.offset_delta + batch * p.batch_offset;
72-
values[idx] = data_a[offset_src + iih * p.IW + iiw];
80+
if (++current_ix == p.OW) {
81+
current_ix = 0;
82+
if (++current_ky == max_ky) {
83+
current_ky = 0;
84+
current_kx++;
85+
}
7386
}
7487
}
7588

7689
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
7790

78-
const uint i = gidx * NUM_ITER + idx;
91+
const uint linear_idx = base_linear_idx + idx;
7992

80-
if (i >= p.pelements) {
93+
if (linear_idx >= p.pelements) {
8194
continue;
8295
}
8396

0 commit comments

Comments
 (0)