Skip to content

Commit 9e1861b

Browse files
committed
[ET-VK][Ops] aten.convolution (Pointwise)
We port an optimization from ATen-VK for specific weight sizes: [`conv2d_pw.glsl`](https://github.com/pytorch/pytorch/blob/09c72eaa3f69f90402c86a30abf4fc621298578c/aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl) Differential Revision: [D55814587](https://our.internmc.facebook.com/intern/diff/D55814587/) ghstack-source-id: 221526241 Pull Request resolved: #2886
1 parent b6bbee6 commit 9e1861b

File tree

4 files changed

+210
-0
lines changed

4 files changed

+210
-0
lines changed
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#include "indexing_utils.h"
14+
15+
layout(std430) buffer;
16+
17+
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
18+
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
19+
layout(set = 0, binding = 2) uniform PRECISION sampler2D kernel_in;
20+
layout(set = 0, binding = 3) uniform PRECISION sampler2D bias_in;
21+
22+
layout(set = 0, binding = 4) uniform PRECISION restrict OutExtents {
23+
uvec4 data;
24+
}
25+
out_extents;
26+
27+
layout(set = 0, binding = 5) uniform PRECISION restrict InExtents {
28+
uvec4 data;
29+
}
30+
in_extents;
31+
32+
layout(set = 0, binding = 6) uniform PRECISION restrict Params {
33+
ivec2 kernel_size;
34+
ivec2 stride;
35+
ivec2 padding;
36+
ivec2 dilation;
37+
}
38+
params;
39+
40+
// If fields are separated, SwiftShader cannot identify in_group_size.
41+
layout(set = 0, binding = 7) uniform PRECISION restrict ExtraParams {
42+
ivec2 overlay_region;
43+
int in_group_size;
44+
}
45+
extra_params;
46+
47+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
48+
49+
/*
 * Computes a 2D pointwise convolution of an NxN output tile. Calculating an
 * output tile for pointwise convolution is more efficient because the kernel
 * size is only 1x1, making it easier to re-use loaded texels from kernel_in.
 */
void main() {
  const ivec3 gpos = ivec3(gl_GlobalInvocationID);

  // Output positions for TILE_SIZE = 2
  // +--------+--------+
  // | pos[0] | pos[1] |
  // +--------+--------+
  // | pos[2] | pos[3] |
  // +--------+--------+
  ivec3 pos[${TILE_SIZE * TILE_SIZE}];
  // NOTE: the loop bounds and x scaling use ${TILE_SIZE} (not a literal 2) so
  // the shader stays correct for any templated tile size.
  for (int y = 0, i = 0; y < ${TILE_SIZE}; ++y) {
    for (int x = 0; x < ${TILE_SIZE}; ++x) {
      pos[i] = ivec3(
          gpos.x * ${TILE_SIZE} + x, gpos.y * ${TILE_SIZE} + y, gpos.z);
      i++;
    }
  }

  // If the top left position is out of bounds, then this invocation will have
  // no work to do.
  if (any(greaterThanEqual(pos[0], out_extents.data.xyz))) {
    return;
  }

  // Compute the index of the input texture that needs to be loaded for each
  // output position. Note that negative indices can be produced indicating that
  // the top-left element is in a region added by padding.
  ivec2 ipos[${TILE_SIZE * TILE_SIZE}];
  for (int i = 0; i < ${TILE_SIZE * TILE_SIZE}; ++i) {
    ipos[i] = pos[i].xy * params.stride - params.padding;
  }

  // Initialize every output texel of the tile with this output channel's bias.
  vec4 sum[${TILE_SIZE * TILE_SIZE}];
  sum[0] = texelFetch(bias_in, ivec2(gpos.z, 0), 0);
  for (int i = 1; i < ${TILE_SIZE * TILE_SIZE}; ++i) {
    sum[i] = sum[0];
  }

  // Since the kernel is 1x1, we only have to loop over the depth dimension.
  for (int z = 0, z4 = 0; z < extra_params.in_group_size; z += 4, ++z4) {
    // During prepacking, the weight tensor has been permuted so that the
    // channel (IC) dim is along the x-axis, and the batch (OC) dim is along
    // the z-axis.
    vec4 in_tex[${TILE_SIZE * TILE_SIZE}];
    const vec4 ktex_0 = texelFetch(kernel_in, ivec2(z + 0, gpos.z), 0);
    const vec4 ktex_1 = texelFetch(kernel_in, ivec2(z + 1, gpos.z), 0);
    const vec4 ktex_2 = texelFetch(kernel_in, ivec2(z + 2, gpos.z), 0);
    const vec4 ktex_3 = texelFetch(kernel_in, ivec2(z + 3, gpos.z), 0);

    for (int i = 0; i < ${TILE_SIZE * TILE_SIZE}; ++i) {
      in_tex[i] = texelFetch(image_in, ivec3(ipos[i], z4), 0);
    }

    for (int i = 0; i < ${TILE_SIZE * TILE_SIZE}; ++i) {
      // For the 2x2 tile size, the algorithm works as follows.
      // To explain the calculations below, the contents of one in_tex and the
      // group of 4 texels loaded from kernel_in are shown:
      //
      //   in_tex               kernel_in
      //    -x->                   ---x--->
      //   +---+              +----+----+----+----+
      // ^ | w |           ^  | D0 | D1 | D2 | D3 |
      // | +---+           |  +----+----+----+----+
      // | | z |           |  | C0 | C1 | C2 | C3 |
      // z +---+           z  +----+----+----+----+
      // | | y |           |  | B0 | B1 | B2 | B3 |
      // | +---+           |  +----+----+----+----+
      //   | x |              | A0 | A1 | A2 | A3 |
      //   +---+              +----+----+----+----+
      //
      // In the kernel_in graphic, cells sharing the same letter are from
      // the same batch/output channel index, and the number denotes a unique
      // channel index. To calculate the output texel, the following
      // calculation is performed:
      //
      //  +---+ +----+  +---+ +----+  +---+ +----+  +---+ +----+
      //  | x | | D0 |  | y | | D1 |  | z | | D2 |  | w | | D3 |
      //  +---+ +----+  +---+ +----+  +---+ +----+  +---+ +----+
      //  | x | | C0 |  | y | | C1 |  | z | | C2 |  | w | | C3 |
      //  +---+X+----+ ++---+X+----+ ++---+X+----+ ++---+X+----+
      //  | x | | B0 |  | y | | B1 |  | z | | B2 |  | w | | B3 |
      //  +---+ +----+  +---+ +----+  +---+ +----+  +---+ +----+
      //  | x | | A0 |  | y | | A1 |  | z | | A2 |  | w | | A3 |
      //  +---+ +----+  +---+ +----+  +---+ +----+  +---+ +----+
      //
      // which is what is expressed in the following calculations. This is done
      // for each output position.
      sum[i] = fma(in_tex[i].xxxx, ktex_0, sum[i]);
      sum[i] = fma(in_tex[i].yyyy, ktex_1, sum[i]);
      sum[i] = fma(in_tex[i].zzzz, ktex_2, sum[i]);
      sum[i] = fma(in_tex[i].wwww, ktex_3, sum[i]);
    }
  }

  // Store only the tile positions that actually fall inside the output; the
  // bottom/right edges of the tile may be out of bounds.
  for (int i = 0; i < ${TILE_SIZE * TILE_SIZE}; ++i) {
    if (all(lessThan(pos[i], out_extents.data.xyz))) {
      imageStore(image_out, pos[i], sum[i]);
    }
  }
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Codegen spec for the pointwise (1x1 kernel) conv2d compute shader.
conv2d_pw:
  parameter_names_with_default_values:
    NDIM: 3
    DTYPE: float
    # Edge length of the square output tile each invocation computes.
    TILE_SIZE: 2
  generate_variant_forall:
    DTYPE:
      - VALUE: half
        SUFFIX: half
      - VALUE: float
        SUFFIX: float
  shader_variants:
    - NAME: conv2d_pw

backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ ValueRef prepack_biases(ComputeGraph& graph, const ValueRef vref) {
8282

8383
// Dispatch strategy for a 2D convolution, selected from the weight shape and
// transposed flag; determines which shader variant and weight prepacking
// layout are used.
enum class Conv2dMethod : uint8_t {
  Depthwise,
  // 1x1 kernel; uses the tiled conv2d_pw shader.
  Pointwise,
  SlidingWindow,
  Transposed,
};
@@ -106,6 +107,13 @@ api::ShaderInfo get_conv2d_shader(
106107
}
107108
}
108109
break;
110+
case Conv2dMethod::Pointwise:
111+
if (prepack_weights) {
112+
kernel_name << "conv2d";
113+
} else {
114+
kernel_name << "conv2d_pw";
115+
}
116+
break;
109117
case Conv2dMethod::SlidingWindow:
110118
kernel_name << "conv2d";
111119
break;
@@ -136,6 +144,7 @@ std::vector<int64_t> get_final_sizes(
136144
case Conv2dMethod::Depthwise:
137145
return std::vector<int64_t>{
138146
4, batch_padded * channels / 4, height * width};
147+
case Conv2dMethod::Pointwise:
139148
case Conv2dMethod::SlidingWindow:
140149
return std::vector<int64_t>{
141150
4, batch_padded * height / 4, channels_padded * width};
@@ -156,6 +165,7 @@ std::vector<int64_t> get_padded_sizes(
156165
switch (method) {
157166
case Conv2dMethod::Depthwise:
158167
return std::vector<int64_t>{-1, batch_padded};
168+
case Conv2dMethod::Pointwise:
159169
case Conv2dMethod::SlidingWindow:
160170
case Conv2dMethod::Transposed:
161171
return std::vector<int64_t>{batch_padded, channels_padded};
@@ -265,6 +275,9 @@ Conv2dMethod get_conv2d_method(
265275
if (transposed) {
266276
return Conv2dMethod::Transposed;
267277
}
278+
if (weight_sizes.at(2) == 1 && weight_sizes.at(3) == 1) {
279+
return Conv2dMethod::Pointwise;
280+
}
268281
return Conv2dMethod::SlidingWindow;
269282
}
270283

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,3 +576,28 @@ def forward(self, x):
576576
sample_inputs,
577577
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
578578
)
579+
580+
def test_vulkan_backend_conv2d_pw(self):
    """Lower a 1x1 (pointwise) Conv2d and verify the Vulkan delegate output."""

    class PointwiseConvModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # A 1x1 kernel exercises the pointwise convolution shader path.
            self.conv = torch.nn.Conv2d(
                in_channels=8,
                out_channels=8,
                kernel_size=1,
                padding=1,
                groups=1,
                bias=True,
            )

        def forward(self, x):
            return self.conv(x)

    module = PointwiseConvModule()
    inputs = (torch.randn(size=(1, 8, 72, 96), dtype=torch.float32),)

    self.lower_module_and_test_output(
        module,
        inputs,
        memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
    )

0 commit comments

Comments
 (0)