Commit d1fa03b
[ET-VK][Ops] aten.convolution (SlidingWindow)
## The Operator

`nn.Module` invocations of [`nn.Conv2d`](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d) and [`nn.ConvTranspose2d`](https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html#torch.nn.ConvTranspose2d) are compiled to `aten.convolution.default` in the Edge Dialect, which carries the signature

```
- func: convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, SymInt[] padding, int[] dilation, bool transposed, SymInt[] output_padding, int groups) -> Tensor
```

## Summary (cases handled)

We introduce support for the convolution cases covered by [ATen-VK's default SlidingWindow implementation](https://github.com/pytorch/pytorch/blob/09c72eaa3f69f90402c86a30abf4fc621298578c/aten/src/ATen/native/vulkan/ops/Convolution.cpp#L73). This is achieved by

- reusing the [existing `conv2d.glsl`](https://github.com/pytorch/pytorch/blob/09c72eaa3f69f90402c86a30abf4fc621298578c/aten/src/ATen/native/vulkan/glsl/conv2d.glsl), and
- [moving the special weights prepacking from the CPU](https://github.com/pytorch/pytorch/blob/09c72eaa3f69f90402c86a30abf4fc621298578c/aten/src/ATen/native/vulkan/ops/Convolution.cpp#L134-L235) to the GPU in `conv2d_prepack_weights.glsl`.

We also include resizing support for dynamic shapes. Note that only the height and width of the input can vary.

## Cases not handled

The implementation is on par with ATen-VK's SlidingWindow. This means the following cases are missing:

1. **Groups G > 1.** Largely not covered by ATen-VK. `G = in_channels` is covered by ATen-VK's Depthwise impl and will be added soon.
2. **Batch (input) N > 1.** Not covered by ATen-VK.
3. **Padding > 0 while Dilation, Kernel > 1.** Not covered by ATen-VK.

## Coming soon

For our CUNET model, the first two items below are required and the third is useful.

1. Transpose convolution
2. Depthwise convolution (for completeness)
3. Pointwise convolution (for optimization)
4. Null bias

Differential Revision: [D55346778](https://our.internmc.facebook.com/intern/diff/D55346778/)

ghstack-source-id: 220997713
Pull Request resolved: #2812
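For reference, a minimal sketch of the lowering described above, showing where `aten.convolution.default` appears. This is illustrative only; the module, shapes, and `torch.export` / `exir.to_edge` flow are assumptions, not part of this diff.

```python
# Minimal sketch (not part of this diff): exporting an nn.Conv2d module and
# inspecting the Edge Dialect graph for aten.convolution.default.
import torch
from executorch.exir import to_edge

class ConvModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A case the SlidingWindow impl handles: groups=1, batch size 1.
        self.conv = torch.nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)

# N must be 1, and only H/W may vary dynamically, per the cases handled above.
example_inputs = (torch.randn(1, 8, 32, 32),)
edge = to_edge(torch.export.export(ConvModule().eval(), example_inputs))
# The graph should now contain exir_ops.edge.aten.convolution.default.
print(edge.exported_program().graph)
```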
1 parent 0edf83a

12 files changed: +700 -2 lines changed
backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -48,6 +48,8 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten.max_pool2d_with_indices.default,
             # Sum
             exir_ops.edge.aten.sum.dim_IntList,
+            # Convolution operators
+            exir_ops.edge.aten.convolution.default,
             # Other
             operator.getitem,
         ]
```
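Continuing the export sketch above, a hedged example of delegating the Edge graph so this partitioner can claim the new convolution node. The `to_backend` call and the no-argument `VulkanPartitioner()` construction are assumptions about the surrounding flow, not shown in this diff.

```python
# Hypothetical continuation of the export sketch above: hand the Edge graph
# to the Vulkan delegate so nodes in is_node_supported's list -- now including
# exir_ops.edge.aten.convolution.default -- are partitioned to Vulkan.
from executorch.backends.vulkan.partitioner.vulkan_partitioner import (
    VulkanPartitioner,
)

edge = edge.to_backend(VulkanPartitioner())
```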

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 1 addition & 1 deletion

```diff
@@ -172,7 +172,7 @@ class ComputeGraph final {
       const api::ScalarType dtype,
       const api::StorageType storage_type,
       const api::GPUMemoryLayout memory_layout,
-      const int64_t shared_object_idx);
+      const int64_t shared_object_idx = -1);
 
   /*
    * Add a `vTensor` value to the graph with the specified properties. The
```
backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl

Lines changed: 136 additions & 0 deletions (new file)

```glsl
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}

#include "indexing_utils.h"

layout(std430) buffer;

layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
layout(set = 0, binding = 2) uniform PRECISION sampler2D kernel_in;
layout(set = 0, binding = 3) uniform PRECISION sampler2D bias_in;

layout(set = 0, binding = 4) uniform PRECISION restrict OutExtents {
  uvec4 data;
}
out_extents;

layout(set = 0, binding = 5) uniform PRECISION restrict InExtents {
  uvec4 data;
}
in_extents;

layout(set = 0, binding = 6) uniform PRECISION restrict Params {
  ivec2 kernel_size;
  ivec2 stride;
  ivec2 padding;
  ivec2 dilation;
}
params;

// If fields are separated, SwiftShader cannot identify in_group_size.
layout(set = 0, binding = 7) uniform PRECISION restrict ExtraParams {
  ivec2 overlay_region;
  int in_group_size;
}
extra_params;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Computes a 2D convolution. Each shader invocation calculates the output at
 * a single output location.
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  if (any(greaterThanEqual(pos, out_extents.data.xyz))) {
    return;
  }

  // Compute the index of the top-left element of the overlay region. Negative
  // indices indicate that the top-left element is in a region added by padding.
  const ivec2 ipos = pos.xy * params.stride - params.padding;

  // Compute the start and end of the input indices to load. Padding is assumed
  // to be constant 0 padding, so reads from the padding region are skipped.
  const ivec2 start = max(ivec2(0), ipos);
  const ivec2 end = min(ipos + extra_params.overlay_region.xy, ivec2(in_extents.data.xy));
  // Compute the start of the kernel based on how far we are skipping ahead when
  // reading the input. Note that these are "canonical" indices.
  ivec2 kstart = (start - ipos) / params.dilation;
  // During prepacking, the weight tensor was rearranged in order to optimize
  // for data access linearity in this shader. Therefore we need to adjust the
  // canonical coordinates to the corresponding index in the rearranged weight
  // tensor. The x-coordinate is multiplied by 4 since each group of 4 channels
  // is folded into the X axis. The y-coordinate is offset based on the z-
  // coordinate because the 2D planes were stacked atop each other vertically.
  kstart.x *= 4;
  kstart.y += pos.z * params.kernel_size.y;

  // Perform the convolution by iterating over the overlay region.
  vec4 sum = texelFetch(bias_in, ivec2(pos.z, 0), 0);
  const int ic4 = extra_params.in_group_size / 4;
  for (int z4 = 0; z4 < ic4; ++z4, kstart.x += params.kernel_size.x * 4) {
    for (int y = start.y, ky = kstart.y; y < end.y; y += params.dilation.y, ++ky) {
      for (int x = start.x, kx = kstart.x; x < end.x; x += params.dilation.x, kx += 4) {
        const vec4 in_texel = texelFetch(image_in, ivec3(x, y, z4), 0);

        // To explain the calculation below, the contents of in_texel and the
        // group of 4 texels loaded from kernel_in are shown:
        //
        //   in_texel        kernel_in
        //    -x->            ---x--->
        //   +---+        +----+----+----+----+
        // ^ | w |      ^ | D0 | D1 | D2 | D3 |
        // | +---+      | +----+----+----+----+
        // | | z |      | | C0 | C1 | C2 | C3 |
        // z +---+      z +----+----+----+----+
        // | | y |      | | B0 | B1 | B2 | B3 |
        // | +---+      | +----+----+----+----+
        //   | x |        | A0 | A1 | A2 | A3 |
        //   +---+        +----+----+----+----+
        //
        // In the kernel_in graphic, cells sharing the same letter are from
        // the same batch/output channel index, and the number denotes a unique
        // channel index. To calculate the output texel, the following
        // calculation is performed:
        //
        //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
        //  | x | | D0 |   | y | | D1 |   | z | | D2 |   | w | | D3 |
        //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
        //  | x | | C0 |   | y | | C1 |   | z | | C2 |   | w | | C3 |
        //  +---+X+----+ + +---+X+----+ + +---+X+----+ + +---+X+----+
        //  | x | | B0 |   | y | | B1 |   | z | | B2 |   | w | | B3 |
        //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
        //  | x | | A0 |   | y | | A1 |   | z | | A2 |   | w | | A3 |
        //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
        //
        // which is what is expressed in the following calculations.

        const vec4 ktex_0 = texelFetch(kernel_in, ivec2(kx + 0, ky), 0);
        sum = fma(in_texel.xxxx, ktex_0, sum);

        const vec4 ktex_1 = texelFetch(kernel_in, ivec2(kx + 1, ky), 0);
        sum = fma(in_texel.yyyy, ktex_1, sum);

        const vec4 ktex_2 = texelFetch(kernel_in, ivec2(kx + 2, ky), 0);
        sum = fma(in_texel.zzzz, ktex_2, sum);

        const vec4 ktex_3 = texelFetch(kernel_in, ivec2(kx + 3, ky), 0);
        sum = fma(in_texel.wwww, ktex_3, sum);
      }
    }
  }

  imageStore(image_out, pos, sum);
}
```
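As a reading aid, here is a hedged NumPy model of what a single invocation of this shader computes: one output texel (4 output channels at one spatial position), accumulated with the same 4-way FMA pattern. The array layouts (`inp`, `kern`, `bias`) are simplified stand-ins for the texture bindings, not the real API.

```python
# Hedged NumPy model of one conv2d.glsl invocation. inp is the channels-packed
# input [Z4][H][W][4], kern the prepacked weights [KY][KX][4], bias [OC4][4].
import numpy as np

def conv2d_texel(inp, kern, bias, pos, stride, padding, dilation, ksize, ic4):
    out_x, out_y, out_z = pos                 # pos.z selects 4 output channels
    acc = bias[out_z].astype(np.float32)      # vec4 accumulator
    ipos = np.array([out_x, out_y]) * stride - padding  # overlay top-left
    for z4 in range(ic4):                     # groups of 4 input channels
        for ky in range(ksize[1]):
            for kx in range(ksize[0]):
                x = ipos[0] + kx * dilation[0]
                y = ipos[1] + ky * dilation[1]
                if not (0 <= x < inp.shape[2] and 0 <= y < inp.shape[1]):
                    continue                  # zero padding: skip the read
                in_texel = inp[z4, y, x]
                # Prepacked weight coordinates, mirroring kstart.x *= 4 and
                # kstart.y += pos.z * kernel_size.y in the shader.
                wy = out_z * ksize[1] + ky
                wx = (z4 * ksize[0] + kx) * 4
                for c in range(4):            # fma(in_texel.cccc, ktex_c, acc)
                    acc += in_texel[c] * kern[wy, wx + c]
    return acc                                # the output texel (4 channels)
```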
backends/vulkan/runtime/graph/ops/glsl/conv2d.yaml

Lines changed: 30 additions & 0 deletions (new file)

```yaml
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

conv2d:
  parameter_names_with_default_values:
    NDIM: 3
    DTYPE: float
  generate_variant_forall:
    DTYPE:
      - VALUE: half
        SUFFIX: half
      - VALUE: float
        SUFFIX: float
  shader_variants:
    - NAME: conv2d

conv2d_prepack_weights:
  parameter_names_with_default_values:
    DTYPE: float
  generate_variant_forall:
    DTYPE:
      - VALUE: half
        SUFFIX: half
      - VALUE: float
        SUFFIX: float
  shader_variants:
    - NAME: conv2d_prepack_weights
```
backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.glsl

Lines changed: 139 additions & 0 deletions (new file)

```glsl
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}

#include "indexing_utils.h"

layout(std430) buffer;

layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[2][DTYPE]} image_out;
layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer {
  ${T[DTYPE]} data[];
}
buffer_in;

// Corresponds to {1,4,9,24} in the example below.
layout(set = 0, binding = 2) uniform PRECISION restrict GpuSizes {
  ivec4 data;
}
gpu_sizes;

// Corresponds to {3,3,7,10} in the example below.
layout(set = 0, binding = 3) uniform PRECISION restrict OriginalSizes {
  ivec4 data;
}
original_sizes;

// Corresponds to {3,3,8,12} in the example below.
layout(set = 0, binding = 4) uniform PRECISION restrict AlignedSizes {
  ivec4 data;
}
padded_sizes;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Computes special prepacking for a 2D convolution. Each shader invocation
 * calculates the input buffer location to read into the desired texel. This
 * packing was originally developed on the CPU, and that approach is described
 * in the rest of this comment. Refer to the code-level comments for how we
 * translate it to the GPU by reversing the steps.
 *
 * Consider an example weight tensor of size {10,7,3,3}. The following
 * transformations will be applied.
 *
 * 1. Pad the N and C dims so that both are a multiple of 4. In this case, 2
 * batches and 1 channel of padding are added, producing a tensor of size
 * {12,8,3,3}.
 *    at::pad(x, {0,0,0,0,0,1,0,2}, "constant", 0);
 *
 * 2. Split the tensor along the C dim so that each split has 4 channels.
 *    x.reshape({12,2,4,3,3});
 *
 * 3. For each split, "fold" the C dim into the W dim. Suppose the first rows
 * at H=0 of the split have values
 *    0,1,2 | 10,11,12 | 20,21,22 | 30,31,32
 *
 * where | denotes a channel boundary. Then, the goal is to combine those rows
 * into one row with the values
 *    0, 10, 20, 30, 1, 11, 21, 31, 2, 12, 22, 32
 *
 *    x.permute({0,1,3,4,2}).reshape({12,2,3,12});
 *
 * 4. Stack the splits belonging to the same batch horizontally by swapping the
 * C and H dims.
 *    x.permute({0,2,1,3}).reshape({12,3,24});
 *
 * 5. Repeat a similar process to "fold" the N dim into the C dim. Split along
 * the N dim so that each split has 4 batches.
 *    x.reshape({3,4,3,24});
 *
 * 6. Stack the batches on each other vertically by swapping the N and C dims.
 *    x.permute({1,0,2,3}).reshape({4,9,24});
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);
  const ivec4 coord = POS_TO_COORD_CHANNELS_PACKED(pos, gpu_sizes.data);

  if (any(greaterThanEqual(coord, gpu_sizes.data))) {
    return;
  }

  // As in usual staging shaders, map from GPU texel position to normal CPU
  // buffer indices: (24,9) -> (4,9,24)
  const int base_index = COORD_TO_BUFFER_IDX(coord, gpu_sizes.data);
  const ivec4 p0 =
      base_index + ivec4(0, 1, 2, 3) * STRIDE_CHANNELS_PACKED(gpu_sizes.data);

  // Re-map the normal CPU buffer indices to special indices, through a series
  // of permutations: reshape is a no-op to the underlying indices, and permute
  // is one of the hardest math problems I've ever solved.
  //
  // Undo step 6 permute: (4,3,3,24) -> (3,4,3,24)
  // Undo step 4 permute: (12,3,2,12) -> (12,2,3,12)
  // Undo step 3 permute, part 1: (12,2,3h,3w,4) -> (12,2,3h,4,3w)
  // Undo step 3 permute, part 2: (12,2,3h,4,3w) -> (12,2,4,3h,3w)
  const ivec4 p1 = SWAP_DIMS(
      p0,
      4,
      (padded_sizes.data.w / 4),
      (padded_sizes.data.y * padded_sizes.data.z * padded_sizes.data.x));
  const ivec4 p2 = SWAP_DIMS(
      p1,
      padded_sizes.data.y,
      (padded_sizes.data.z / 4),
      (padded_sizes.data.x * 4));
  const ivec4 p3 = SWAP_DIMS(p2, padded_sizes.data.x, 4, 1);
  const ivec4 p4 = SWAP_DIMS(p3, padded_sizes.data.y, 4, padded_sizes.data.x);

  // For values in the padded region, write zero instead of buffer data.
  //
  // Undo step 1 pad: (12,8,3,3) -> (10,7,3,3)
  const ivec4 c = p4 %
      (padded_sizes.data.z * padded_sizes.data.y * padded_sizes.data.x) /
      (padded_sizes.data.y * padded_sizes.data.x);
  const ivec4 n =
      p4 / (padded_sizes.data.z * padded_sizes.data.y * padded_sizes.data.x);
  const ivec4 p5 = p4 -
      n * (padded_sizes.data.z - original_sizes.data.z) * padded_sizes.data.y *
          padded_sizes.data.x;
  const ivec4 mask = ivec4(greaterThanEqual(c, original_sizes.data.zzzz)) |
      ivec4(greaterThanEqual(n, original_sizes.data.wwww));

  ${T[DTYPE]} val_x = mix(buffer_in.data[p5.x], 0, mask.x);
  ${T[DTYPE]} val_y = mix(buffer_in.data[p5.y], 0, mask.y);
  ${T[DTYPE]} val_z = mix(buffer_in.data[p5.z], 0, mask.z);
  ${T[DTYPE]} val_w = mix(buffer_in.data[p5.w], 0, mask.w);

  ${VEC4_T[DTYPE]} texel = ${VEC4_T[DTYPE]}(val_x, val_y, val_z, val_w);

  imageStore(image_out, pos.xy, texel);
}
```
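The six CPU-side steps documented in the shader comment translate directly to PyTorch. Here is a hedged reference sketch for the example {10,7,3,3} weight; `F.pad`'s last-dim-first argument order is the only subtlety (the pad spec in the comment above was corrected to match the stated {12,8,3,3} result).

```python
# Hedged PyTorch reference for the six prepacking steps documented above,
# using the example {10,7,3,3} weight tensor. F.pad takes pairs starting
# from the last dimension: (W, W, H, H, C, C, N, N).
import torch
import torch.nn.functional as F

w = torch.arange(10 * 7 * 3 * 3, dtype=torch.float32).reshape(10, 7, 3, 3)

# 1. Pad N and C up to multiples of 4: {10,7,3,3} -> {12,8,3,3}.
w = F.pad(w, (0, 0, 0, 0, 0, 1, 0, 2), "constant", 0)
# 2. Split C into groups of 4 channels: {12,8,3,3} -> {12,2,4,3,3}.
w = w.reshape(12, 2, 4, 3, 3)
# 3. Fold each C group into the W dim: {12,2,4,3,3} -> {12,2,3,12}.
w = w.permute(0, 1, 3, 4, 2).reshape(12, 2, 3, 12)
# 4. Stack same-batch splits horizontally: {12,2,3,12} -> {12,3,24}.
w = w.permute(0, 2, 1, 3).reshape(12, 3, 24)
# 5. Split N into groups of 4 batches: {12,3,24} -> {3,4,3,24}.
w = w.reshape(3, 4, 3, 24)
# 6. Stack batch groups vertically: {3,4,3,24} -> {4,9,24}.
w = w.permute(1, 0, 2, 3).reshape(4, 9, 24)

assert w.shape == (4, 9, 24)  # matches the {1,4,9,24} GpuSizes above
```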

backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h

Lines changed: 9 additions & 0 deletions

```diff
@@ -44,3 +44,12 @@
 #define STRIDE_WIDTH_PACKED(vec) (1)
 
 #define STRIDE_HEIGHT_PACKED(vec) (vec.x)
+
+// Given a buffer (1-D) index cur, compute a new index where the corresponding
+// tensor (N-D)'s x and y dimensions are swapped. size is the number of
+// elements in the plane of the dimensions lower than x and y.
+#define SWAP_DIMS(cur, x, y, size) \
+  cur + \
+      size*( \
+          (1 - y) * ((cur % (x * y * size)) / (y * size)) + \
+          (x - 1) * ((cur % (y * size)) / size))
```
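A quick, hedged way to convince yourself the macro is right: in Python, `SWAP_DIMS` applied to flat indices should agree with an actual transpose of the x and y dims (integer division `//` stands in for GLSL's integer `/`).

```python
# Hedged sanity check of SWAP_DIMS: for flat indices into repeated
# (x, y, size) blocks, the remapped index must match a real x/y transpose.
import numpy as np

def swap_dims(cur, x, y, size):
    return cur + size * (
        (1 - y) * ((cur % (x * y * size)) // (y * size))
        + (x - 1) * ((cur % (y * size)) // size)
    )

x, y, size = 3, 4, 5
a = np.arange(2 * x * y * size)  # two stacked (x, y, size) blocks
t = a.reshape(2, x, y, size).transpose(0, 2, 1, 3).reshape(-1)
assert all(t[swap_dims(cur, x, y, size)] == a[cur] for cur in range(a.size))
```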
