Skip to content

Commit 1cd8c02

Browse files
committed
Update on "[ET-VK][Ops] aten.convolution (SlidingWindow)"
## The Operator `nn.Module` invocations of [`nn.Conv2d`](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d) and [`nn.ConvTranspose2d`](https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html#torch.nn.ConvTranspose2d) get compiled to `aten.convolution.default` in the Edge Dialect, which carries the signature ``` - func: convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, SymInt[] padding, int[] dilation, bool transposed, SymInt[] output_padding, int groups) -> Tensor ``` ## Summary (cases handled) We introduce support for the convolution cases covered by [ATen-VK's default SlidingWindow implementation](https://github.com/pytorch/pytorch/blob/09c72eaa3f69f90402c86a30abf4fc621298578c/aten/src/ATen/native/vulkan/ops/Convolution.cpp#L73). This is achieved by - reusing the [existing `conv2d.glsl`](https://github.com/pytorch/pytorch/blob/09c72eaa3f69f90402c86a30abf4fc621298578c/aten/src/ATen/native/vulkan/glsl/conv2d.glsl), and - [moving special weights prepacking from CPU](https://github.com/pytorch/pytorch/blob/09c72eaa3f69f90402c86a30abf4fc621298578c/aten/src/ATen/native/vulkan/ops/Convolution.cpp#L134-L235) to the GPU in `conv2d_prepack_weights.glsl`. We also include resizing support for dynamic shapes. Note that only height and width of the input can vary. ## Cases not handled The implementation is on-par with ATen-VK's SlidingWindow. This means the following cases are missing: 1. **Groups G > 1.** Largely not covered by ATen-VK. `G = in_channels` is covered by ATen-VK's Depthwise impl and will be added soon. 2. **Batch (input) N > 1.** Not covered by ATen-VK. 3. **Padding > 0 while Dilation, Kernel > 1.** Not covered by ATen-VK. ## Coming soon For our CUNET model, the first two are required and the third is useful. 1. Transpose convolution 2. Depthwise convolution (for completeness) 3. Pointwise convolution (for optimization) 4. Null bias Differential Revision: [D55346778](https://our.internmc.facebook.com/intern/diff/D55346778/) [ghstack-poisoned]
2 parents c156160 + 1f1a2c2 commit 1cd8c02

File tree

6 files changed

+25
-21
lines changed

6 files changed

+25
-21
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,12 @@ void main() {
7878
kstart.y += pos.z * params.kernel_size.y;
7979

8080
// Perform the convolution by iterating over the overlay region.
81-
vec4 sum = texelFetch(bias_in, ivec2(pos.z, 0), 0);
81+
${VEC4_T[DTYPE]} sum = texelFetch(bias_in, ivec2(pos.z, 0), 0);
8282
const int ic4 = extra_params.in_group_size / 4;
8383
for (int z4 = 0; z4 < ic4; ++z4, kstart.x += params.kernel_size.x * 4) {
8484
for (int y = start.y, ky = kstart.y; y < end.y; y += params.dilation.y, ++ky) {
8585
for (int x = start.x, kx = kstart.x; x < end.x; x += params.dilation.x, kx += 4) {
86-
const vec4 in_texel = texelFetch(image_in, ivec3(x, y, z4), 0);
86+
const ${VEC4_T[DTYPE]} in_texel = texelFetch(image_in, ivec3(x, y, z4), 0);
8787

8888
// To explain the calculation below, the contents of in_texel and the
8989
// group of 4 texels loaded from kernel_in are shown:
@@ -115,18 +115,18 @@ void main() {
115115
// | x | | A0 | | y | | A1 | | z | | A2 | | w | | A3 |
116116
// +---+ +----+ +---+ +----+ +---+ +----+ +---+ +----+
117117
//
118-
// which is what is expressed in the following calculations.
118+
// which is expressed in the following statements.
119119

120-
const vec4 ktex_0 = texelFetch(kernel_in, ivec2(kx + 0, ky), 0);
120+
const ${VEC4_T[DTYPE]} ktex_0 = texelFetch(kernel_in, ivec2(kx + 0, ky), 0);
121121
sum = fma(in_texel.xxxx, ktex_0, sum);
122122

123-
const vec4 ktex_1 = texelFetch(kernel_in, ivec2(kx + 1, ky), 0);
123+
const ${VEC4_T[DTYPE]} ktex_1 = texelFetch(kernel_in, ivec2(kx + 1, ky), 0);
124124
sum = fma(in_texel.yyyy, ktex_1, sum);
125125

126-
const vec4 ktex_2 = texelFetch(kernel_in, ivec2(kx + 2, ky), 0);
126+
const ${VEC4_T[DTYPE]} ktex_2 = texelFetch(kernel_in, ivec2(kx + 2, ky), 0);
127127
sum = fma(in_texel.zzzz, ktex_2, sum);
128128

129-
const vec4 ktex_3 = texelFetch(kernel_in, ivec2(kx + 3, ky), 0);
129+
const ${VEC4_T[DTYPE]} ktex_3 = texelFetch(kernel_in, ivec2(kx + 3, ky), 0);
130130
sum = fma(in_texel.wwww, ktex_3, sum);
131131
}
132132
}

backends/vulkan/runtime/graph/ops/glsl/conv2d.yaml

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,3 @@ conv2d:
1616
SUFFIX: float
1717
shader_variants:
1818
- NAME: conv2d
19-
20-
conv2d_prepack_weights:
21-
parameter_names_with_default_values:
22-
DTYPE: float
23-
generate_variant_forall:
24-
DTYPE:
25-
- VALUE: half
26-
SUFFIX: half
27-
- VALUE: float
28-
SUFFIX: float
29-
shader_variants:
30-
- NAME: conv2d_prepack_weights

backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.glsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4747
* rest of this comment. Refer to the code-level comments, for how we translate
4848
* it to GPU by reversing the steps.
4949
*
50-
* Consider example weight tensor of size {10,7,3,3}. The following
50+
* Consider an example weight tensor of size {10,7,3,3}. The following
5151
* transformations will be applied.
5252
*
5353
* 1. Pad the N and C dims so that both are a multiple of 4. In this case, 2
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
conv2d_prepack_weights:
8+
parameter_names_with_default_values:
9+
DTYPE: float
10+
generate_variant_forall:
11+
DTYPE:
12+
- VALUE: half
13+
SUFFIX: half
14+
- VALUE: float
15+
SUFFIX: float
16+
shader_variants:
17+
- NAME: conv2d_prepack_weights

backends/vulkan/test/utils/test_utils.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,6 @@ void record_conv2d_prepack_weights_op(
138138
api::MemoryAccessType::WRITE),
139139
src_buffer,
140140
v_dst.gpu_sizes_ubo()->buffer(),
141-
v_dst.cpu_sizes_ubo()->buffer(),
142141
original_sizes_ubo.buffer(),
143142
padded_sizes_ubo.buffer());
144143
}

0 commit comments

Comments
 (0)