Skip to content

Commit 8a6427e

Browse files
jorgep31415 and facebook-github-bot
authored and committed
aten.convolution (Transpose) (#2883)
Summary: Pull Request resolved: #2883

## Summary (cases handled)

We introduce support for the convolution cases covered by ATen-VK's transpose implementation. This is achieved by
- reusing the existing [`conv_transpose2d.glsl`](https://github.com/pytorch/pytorch/blob/09c72eaa3f69f90402c86a30abf4fc621298578c/aten/src/ATen/native/vulkan/glsl/conv_transpose2d.glsl), and
- [moving special weights prepacking from CPU](https://github.com/pytorch/pytorch/blob/09c72eaa3f69f90402c86a30abf4fc621298578c/aten/src/ATen/native/vulkan/ops/Convolution.cpp#L134-L235) to the GPU in `conv_transpose2d_prepack_weights.glsl`.

We also include resizing support for dynamic shapes. Note that only height and width of the input can vary.

## Cases not handled

The implementation is on par with ATen-VK's Transpose. This means the following cases are missing:
1. **Groups G > 1.**
2. **Batch (input) N > 1.**
3. **Dilation > 1.**

ghstack-source-id: 221721754
exported-using-ghexport
bypass-github-export-checks

Reviewed By: copyrightly, SS-JIA

Differential Revision: D55667336

fbshipit-source-id: 3b7b7c912ef947610624e2e1c5b753de393234a0
1 parent cb6ddae commit 8a6427e

12 files changed

+446
-71
lines changed
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
#version 450 core

#define PRECISION ${PRECISION}

layout(std430) buffer;

// Binding 0 is the image written by main(); bindings 1-3 are the textures it
// reads with texelFetch (input activations, packed kernel, bias).
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
layout(set = 0, binding = 2) uniform PRECISION sampler2D kernel_in;
layout(set = 0, binding = 3) uniform PRECISION sampler2D bias_in;

// Extents of image_out. main() exits early for any invocation whose position
// is >= data.xyz.
layout(set = 0, binding = 4) uniform PRECISION restrict OutExtents {
  uvec4 data;
} out_extents;

// Extents of image_in. main() uses data.xy as the upper bound of the input
// window it iterates over.
layout(set = 0, binding = 5) uniform PRECISION restrict InExtents {
  uvec4 data;
} in_extents;

// Convolution hyper-parameters. Only kernel_size, stride and padding are read
// by main(); dilation is declared but not used by this shader.
layout(set = 0, binding = 6) uniform PRECISION restrict Params {
  ivec2 kernel_size;
  ivec2 stride;
  ivec2 padding;
  ivec2 dilation;
} params;

// If fields are separated, SwiftShader cannot identify in_group_size.
layout(set = 0, binding = 7) uniform PRECISION restrict ExtraParams {
  ivec2 overlay_region;
  int in_group_size;
} extra_params;

// Workgroup dimensions are supplied through specialization constants 0, 1, 2.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * 2D transpose convolution. Every shader invocation produces the output texel
 * at its own global invocation position by accumulating, on top of the bias,
 * the products of input texels and packed kernel texels over the window of
 * input positions computed below. conv2d.glsl uses a similar approach; refer
 * to it for details.
 */
void main() {
  const ivec3 out_pos = ivec3(gl_GlobalInvocationID);

  // Invocations that fall outside the output image have nothing to write.
  if (any(greaterThanEqual(out_pos, out_extents.data.xyz))) {
    return;
  }

  // Output xy position offset by the padding.
  const ivec2 shifted = out_pos.xy + params.padding;

  // Half-open window [first, last) of input xy positions that are visited,
  // clamped to the input extents.
  const ivec2 first = max(
      ivec2(0),
      ivec2(ceil(
          (vec2(shifted) - params.kernel_size + 1) / vec2(params.stride))));
  const ivec2 last = min(
      ivec2(in_extents.data.xy),
      ivec2(floor(vec2(shifted) / vec2(params.stride))) + 1);

  const int group_size = extra_params.in_group_size;
  // The innermost loop advances kx by 4 per iteration over group_size / 4
  // iterations; this extra step is added once per input x position.
  const int kx_skip = group_size * (params.stride.x - 1);

  // Kernel texture coordinates that pair with the first visited input texel.
  const int ky_first = extra_params.overlay_region.y - 1 -
      (shifted.y - params.stride.y * first.y) +
      out_pos.z * params.kernel_size.y;
  const int kx_first = (extra_params.overlay_region.x - 1 -
      (shifted.x - params.stride.x * first.x)) * group_size;

  // The accumulator starts from the bias texel for this output z position.
  ${VEC4_T[DTYPE]} acc = texelFetch(bias_in, ivec2(out_pos.z, 0), 0);

  int ky = ky_first;
  for (int iy = first.y; iy < last.y; ++iy) {
    // kx restarts from kx_first for every input row.
    int kx = kx_first;
    for (int ix = first.x; ix < last.x; ++ix) {
      for (int z4 = 0; z4 < group_size / 4; ++z4) {
        const ${VEC4_T[DTYPE]} in_texel =
            texelFetch(image_in, ivec3(ix, iy, z4), 0);

        // Each of the four input components is paired with its own kernel
        // texel, taken from four consecutive kernel columns.
        acc = fma(in_texel.xxxx, texelFetch(kernel_in, ivec2(kx, ky), 0), acc);
        acc = fma(in_texel.yyyy, texelFetch(kernel_in, ivec2(kx + 1, ky), 0), acc);
        acc = fma(in_texel.zzzz, texelFetch(kernel_in, ivec2(kx + 2, ky), 0), acc);
        acc = fma(in_texel.wwww, texelFetch(kernel_in, ivec2(kx + 3, ky), 0), acc);

        kx += 4;
      }
      kx += kx_skip;
    }
    ky += params.stride.y;
  }

  imageStore(image_out, out_pos, acc);
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
# Codegen spec for the conv_transpose2d shader template.
conv_transpose2d:
  # Template parameters with their default values (presumably substituted for
  # the ${...} placeholders in the shader source; confirm against the codegen).
  parameter_names_with_default_values:
    NDIM: 3
    DTYPE: float
  # One variant per listed DTYPE value; SUFFIX presumably becomes part of the
  # generated shader name.
  generate_variant_forall:
    DTYPE:
      - VALUE: half
        SUFFIX: half
      - VALUE: float
        SUFFIX: float
  shader_variants:
    - NAME: conv_transpose2d
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
#version 450 core

#define PRECISION ${PRECISION}

#include "indexing_utils.h"

layout(std430) buffer;

// Binding 0 is the 2D image main() writes one texel to; binding 1 is the
// read-only buffer the values are gathered from.
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[2][DTYPE]} image_out;
layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer {
  ${T[DTYPE]} data[];
} buffer_in;

// Corresponds to {1,4,6,36} in the worked example used by the comments in
// main(). Passed to the indexing_utils.h macros and used for the bounds check.
layout(set = 0, binding = 2) uniform PRECISION restrict GpuSizes {
  ivec4 data;
} gpu_sizes;

// Corresponds to {3,3,7,10} in the same example. main() reads the components
// as x = W, y = H, z = C, w = N.
layout(set = 0, binding = 3) uniform PRECISION restrict OriginalSizes {
  ivec4 data;
} original_sizes;

// Corresponds to {8,12} in the same example. main() reads the components as
// x = Cp (padded C) and y = Np (padded N).
layout(set = 0, binding = 4) uniform PRECISION restrict PaddedSizes {
  ivec2 data;
} padded_sizes;

// Workgroup dimensions are supplied through specialization constants 0, 1, 2.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Computes special prepacking for a 2D transpose convolution. Each shader
 * invocation calculates the input buffer location to read into the desired
 * texel.
 *
 * For details, refer to conv2d_prepack_weights.glsl which uses a similar
 * approach. For transpose, there are slight differences to reflect the data
 * access pattern in the shader. First, the weight tensor is flipped along the H
 * and W dims. Second, steps 3 and 4 are slightly different so that the splits
 * are interleaved.
 *
 * Elements that fall in the padded region (c >= C or n >= N) are written as
 * zero and the input buffer is NOT read for them: their remapped index can lie
 * past the end of the unpadded tensor.
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);
  const ivec4 coord = POS_TO_COORD_CHANNELS_PACKED(pos, gpu_sizes.data);

  if (any(greaterThanEqual(coord, gpu_sizes.data))) {
    return;
  }

  // As in usual staging shaders, map from GPU texel position to normal CPU
  // buffer indices: (36,6) -> (4,6,36)
  const int base_index = COORD_TO_BUFFER_IDX(coord, gpu_sizes.data);
  const ivec4 p0 =
      base_index + ivec4(0, 1, 2, 3) * STRIDE_CHANNELS_PACKED(gpu_sizes.data);

  // Re-map the normal CPU buffer indices to special indices, through a series
  // of mappings: reshape is a no-op to the underlying indices, so we only map
  // for flip, pad, and permute.
  const int Np = padded_sizes.data.y;
  const int Cp = padded_sizes.data.x;
  const int N = original_sizes.data.w;
  const int C = original_sizes.data.z;
  const int H = original_sizes.data.y;
  const int W = original_sizes.data.x;

  // Undo step 6 permute: (4,2,3,36) -> (2,4,3,36)
  const ivec4 p1 = SWAP_ADJ_DIMS(p0, 4, (Cp / 4), (H * Np * W));

  // In the following comments, a=b=c=3.
  // Undo step 3 permute, part 1: (8,a,b,c,4) -> (8,a,c,b,4)
  const ivec4 p2 = SWAP_ADJ_DIMS(p1, W, (Np / 4), 4);
  // Undo step 3 permute, part 2: (8,a,c,b,4) -> (8,c,a,b,4)
  const ivec4 p3 = SWAP_ADJ_DIMS(p2, H, (Np / 4), (W * 4));
  // Undo step 3 permute, part 3: (8,c,a,b,4) -> (8,c,a,4,b)
  const ivec4 p4 = SWAP_ADJ_DIMS(p3, W, 4, 1);
  // Undo step 3 permute, part 4: (8,c,a,4,b) -> (8,c,4,a,b)
  const ivec4 p5 = SWAP_ADJ_DIMS(p4, H, 4, W);

  // Undo step 0 permute: (8,12,3,3) -> (12,8,3,3)
  const ivec4 p6 = SWAP_ADJ_DIMS(p5, Cp, Np, (W * H));

  // Undo step 0 flip: (2,3)
  const ivec4 w = p6 % W;
  const ivec4 h = p6 % (H * W) / W;
  const ivec4 p7 = p6 + W - 1 - 2 * w + W * (H - 1 - 2 * h);

  // Undo step 1 pad: (12,8,3,3) -> (10,7,3,3)
  // For values in the padded region, write zero instead of buffer data.
  const ivec4 c = p7 % (Cp * H * W) / (H * W);
  const ivec4 n = p7 / (Cp * H * W);
  const ivec4 p8 = p7 - n * (Cp - C) * H * W;
  // Each component is 1 when that element lies in the padded region, else 0.
  const ivec4 mask = ivec4(greaterThanEqual(c, ivec4(C))) |
      ivec4(greaterThanEqual(n, ivec4(N)));

  // Select with a ternary rather than mix(): with an integer blend factor,
  // mix() is the arithmetic blend x * (1 - a) + y * a, which always performs
  // the buffer read (p8 may be out of range for padded elements) and turns a
  // NaN/Inf read into NaN instead of 0. The ternary evaluates only the chosen
  // operand, so padded elements never touch the buffer.
  const ${T[DTYPE]} zero = ${T[DTYPE]}(0);
  const ${T[DTYPE]} val_x = (mask.x == 0) ? buffer_in.data[p8.x] : zero;
  const ${T[DTYPE]} val_y = (mask.y == 0) ? buffer_in.data[p8.y] : zero;
  const ${T[DTYPE]} val_z = (mask.z == 0) ? buffer_in.data[p8.z] : zero;
  const ${T[DTYPE]} val_w = (mask.w == 0) ? buffer_in.data[p8.w] : zero;

  ${VEC4_T[DTYPE]} texel = ${VEC4_T[DTYPE]}(val_x, val_y, val_z, val_w);

  imageStore(image_out, pos.xy, texel);
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
# Codegen spec for the conv_transpose2d_prepack_weights shader template.
conv_transpose2d_prepack_weights:
  # Template parameters with their default values (presumably substituted for
  # the ${...} placeholders in the shader source; confirm against the codegen).
  parameter_names_with_default_values:
    # NOTE(review): conv_transpose2d_prepack_weights.glsl as shown uses
    # IMAGE_T[2][DTYPE] and does not reference NDIM; confirm whether this
    # entry is still needed.
    NDIM: 3
    DTYPE: float
  # One variant per listed DTYPE value; SUFFIX presumably becomes part of the
  # generated shader name.
  generate_variant_forall:
    DTYPE:
      - VALUE: half
        SUFFIX: half
      - VALUE: float
        SUFFIX: float
  shader_variants:
    - NAME: conv_transpose2d_prepack_weights

0 commit comments

Comments
 (0)