Skip to content

vulkan: improve im2col #11826

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Feb 28, 2025
49 changes: 31 additions & 18 deletions ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,20 @@ void main() {
const uint batch = gl_GlobalInvocationID.z / p.IC;
const uint ic = gl_GlobalInvocationID.z % p.IC;

const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH);
const int oh_s1 = int(oh) * p.s1;
const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1);

const uint base_linear_idx = gidx * NUM_ITER;

const uint max_ky = ksize / p.OW;

uint current_kx = base_linear_idx / ksize;
const uint rem = base_linear_idx - (current_kx * ksize);
uint current_ky = rem / p.OW;
uint current_ix = rem % p.OW;

A_TYPE values[NUM_ITER];
uint offset_dst[NUM_ITER];
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
Expand All @@ -48,36 +62,35 @@ void main() {

[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {

const uint i = gidx * NUM_ITER + idx;
const uint linear_idx = base_linear_idx + idx;

const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1);
const uint kx = i / ksize;
const uint kd = kx * ksize;
const uint ky = (i - kd) / p.OW;
const uint ix = i % p.OW;
if (linear_idx >= p.pelements) {
continue;
}

const uint iiw = ix * p.s0 + kx * p.d0 - p.p0;
const uint iih = oh * p.s1 + ky * p.d1 - p.p1;
const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0;
const uint iih = oh_s1 + current_ky * p.d1 - p.p1;

offset_dst[idx] =
((batch * p.OH + oh) * p.OW + ix) * p.CHW +
(ic * (p.KW * p.KH) + ky * p.KW + kx);
offset_dst[idx] = dst_base + current_ix * p.CHW + current_ky * p.KW + current_kx;

if (i >= p.pelements) {
continue;
if ((iih < p.IH) && (iiw < p.IW)) {
values[idx] = data_a[src_base + iih * p.IW + iiw];
}

if (iih < p.IH && iiw < p.IW) {
const uint offset_src = ic * p.offset_delta + batch * p.batch_offset;
values[idx] = data_a[offset_src + iih * p.IW + iiw];
if (++current_ix == p.OW) {
current_ix = 0;
if (++current_ky == max_ky) {
current_ky = 0;
current_kx++;
}
}
}

[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {

const uint i = gidx * NUM_ITER + idx;
const uint linear_idx = base_linear_idx + idx;

if (i >= p.pelements) {
if (linear_idx >= p.pelements) {
continue;
}

Expand Down
Loading