Skip to content

[ET-VK][Ops] aten.convolution (Pointwise) #2886

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 153 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}

#include "indexing_utils.h"

// All buffer-backed blocks in this shader use std430 layout.
layout(std430) buffer;

// Output feature map image. Image format and image type are substituted by the
// shader codegen (${IMAGE_FORMAT[DTYPE]} / ${IMAGE_T[NDIM][DTYPE]}).
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
// Input feature map, sampled as a 3D texture.
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
// Prepacked convolution weights (see the prepacking note in main()).
layout(set = 0, binding = 2) uniform PRECISION sampler2D kernel_in;
// Per-output-channel bias values.
layout(set = 0, binding = 3) uniform PRECISION sampler2D bias_in;

// Texel extents (x, y, z, _) of the output image.
layout(set = 0, binding = 4) uniform PRECISION restrict OutExtents {
uvec4 data;
}
out_extents;

// Texel extents (x, y, z, _) of the input image.
layout(set = 0, binding = 5) uniform PRECISION restrict InExtents {
uvec4 data;
}
in_extents;

// Standard convolution parameters; kernel_size is expected to be 1x1 for this
// pointwise variant.
layout(set = 0, binding = 6) uniform PRECISION restrict Params {
ivec2 kernel_size;
ivec2 stride;
ivec2 padding;
ivec2 dilation;
}
params;

// If fields are separated, SwiftShader cannot identify in_group_size.
layout(set = 0, binding = 7) uniform PRECISION restrict ExtraParams {
ivec2 overlay_region;
int in_group_size;
}
extra_params;

// Workgroup size is supplied via specialization constants at pipeline creation.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Computes a 2D pointwise convolution of a TILE_SIZE x TILE_SIZE output tile.
 * Calculating an output tile for pointwise convolution is more efficient
 * because the kernel size is only 1x1, making it easier to re-use loaded
 * texels from kernel_in.
 */
void main() {
  const ivec3 gpos = ivec3(gl_GlobalInvocationID);

  // Output positions covered by this invocation. For TILE_SIZE = 2:
  // +--------+--------+
  // | pos[0] | pos[1] |
  // +--------+--------+
  // | pos[2] | pos[3] |
  // +--------+--------+
  // Use ${TILE_SIZE} consistently (instead of a hardcoded 2) so that changing
  // TILE_SIZE in the yaml config keeps the tile addressing correct.
  ivec3 pos[${TILE_SIZE * TILE_SIZE}];
  for (int y = 0, i = 0; y < ${TILE_SIZE}; ++y) {
    for (int x = 0; x < ${TILE_SIZE}; ++x) {
      pos[i] = ivec3(
          gpos.x * ${TILE_SIZE} + x, gpos.y * ${TILE_SIZE} + y, gpos.z);
      i++;
    }
  }

  // If the top left position is out of bounds, then this invocation will have
  // no work to do.
  if (any(greaterThanEqual(pos[0], out_extents.data.xyz))) {
    return;
  }

  // Compute the index of the input texture that needs to be loaded for each
  // output position. Note that negative indices can be produced indicating that
  // the top-left element is in a region added by padding.
  ivec2 ipos[${TILE_SIZE * TILE_SIZE}];
  for (int i = 0; i < ${TILE_SIZE * TILE_SIZE}; ++i) {
    ipos[i] = pos[i].xy * params.stride - params.padding;
  }

  // Seed every accumulator in the tile with the bias texel for this group of
  // output channels (indexed by gpos.z).
  vec4 sum[${TILE_SIZE * TILE_SIZE}];
  sum[0] = texelFetch(bias_in, ivec2(gpos.z, 0), 0);
  for (int i = 1; i < ${TILE_SIZE * TILE_SIZE}; ++i) {
    sum[i] = sum[0];
  }

  // Since the kernel is 1x1, we only have to loop over the depth dimension.
  for (int z = 0, z4 = 0; z < extra_params.in_group_size; z += 4, ++z4) {
    // During prepacking, the weight tensor has been permuted so that the
    // channel (IC) dim is along the x-axis, and the batch (OC) dim is along
    // the z-axis.
    vec4 in_tex[${TILE_SIZE * TILE_SIZE}];
    const vec4 ktex_0 = texelFetch(kernel_in, ivec2(z + 0, gpos.z), 0);
    const vec4 ktex_1 = texelFetch(kernel_in, ivec2(z + 1, gpos.z), 0);
    const vec4 ktex_2 = texelFetch(kernel_in, ivec2(z + 2, gpos.z), 0);
    const vec4 ktex_3 = texelFetch(kernel_in, ivec2(z + 3, gpos.z), 0);

    for (int i = 0; i < ${TILE_SIZE * TILE_SIZE}; ++i) {
      in_tex[i] = texelFetch(image_in, ivec3(ipos[i], z4), 0);
    }

    for (int i = 0; i < ${TILE_SIZE * TILE_SIZE}; ++i) {
      // For 2x2 tile size algorithm works as follows.
      // To explain the calculations below, the contents of one in_tex and the
      // group of 4 texels loaded from kernel_in are shown:
      //
      //   in_tex               kernel_in
      //    -x->                   ---x--->
      //   +---+              +----+----+----+----+
      // ^ | w |           ^  | D0 | D1 | D2 | D3 |
      // | +---+           |  +----+----+----+----+
      // | | z |           |  | C0 | C1 | C2 | C3 |
      // z +---+           z  +----+----+----+----+
      // | | y |           |  | B0 | B1 | B2 | B3 |
      // | +---+           |  +----+----+----+----+
      //   | x |              | A0 | A1 | A2 | A3 |
      //   +---+              +----+----+----+----+
      //
      // In the kernel_in graphic, cells sharing the same letter are from
      // the same batch/output channel index, and the number denotes a unique
      // channel index. To calculate the output texel, the following
      // calculation is performed:
      //
      //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
      //  | x | | D0 |   | y | | D1 |   | z | | D2 |   | w | | D3 |
      //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
      //  | x | | C0 |   | y | | C1 |   | z | | C2 |   | w | | C3 |
      //  +---+X+----+ + +---+X+----+ + +---+X+----+ + +---+X+----+
      //  | x | | B0 |   | y | | B1 |   | z | | B2 |   | w | | B3 |
      //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
      //  | x | | A0 |   | y | | A1 |   | z | | A2 |   | w | | A3 |
      //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
      //
      // which is what is expressed in the following calculations. This is done
      // for each output position.
      sum[i] = fma(in_tex[i].xxxx, ktex_0, sum[i]);
      sum[i] = fma(in_tex[i].yyyy, ktex_1, sum[i]);
      sum[i] = fma(in_tex[i].zzzz, ktex_2, sum[i]);
      sum[i] = fma(in_tex[i].wwww, ktex_3, sum[i]);
    }
  }

  // Write back only the tile elements that actually fall inside the output
  // extents; edge tiles may extend past the output bounds.
  for (int i = 0; i < ${TILE_SIZE * TILE_SIZE}; ++i) {
    if (all(lessThan(pos[i], out_extents.data.xyz))) {
      imageStore(image_out, pos[i], sum[i]);
    }
  }
}
19 changes: 19 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Codegen config for the pointwise conv2d shader. Indentation restored: YAML
# nesting is semantically required and was lost in the source capture.
conv2d_pw:
  parameter_names_with_default_values:
    NDIM: 3
    DTYPE: float
    # Edge length of the square output tile each invocation computes.
    TILE_SIZE: 2
  generate_variant_forall:
    DTYPE:
      - VALUE: half
        SUFFIX: half
      - VALUE: float
        SUFFIX: float
  shader_variants:
    - NAME: conv2d_pw
13 changes: 13 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ ValueRef prepack_biases(ComputeGraph& graph, const ValueRef vref) {

// Strategy used to implement a 2D convolution; selects which compute shader
// is dispatched and how the weights are prepacked.
enum class Conv2dMethod : uint8_t {
  Depthwise,
  // Selected when the kernel's spatial dims are both 1 (1x1 convolution).
  Pointwise,
  SlidingWindow,
  Transposed,
};
Expand All @@ -106,6 +107,13 @@ api::ShaderInfo get_conv2d_shader(
}
}
break;
case Conv2dMethod::Pointwise:
if (prepack_weights) {
kernel_name << "conv2d";
} else {
kernel_name << "conv2d_pw";
}
break;
case Conv2dMethod::SlidingWindow:
kernel_name << "conv2d";
break;
Expand Down Expand Up @@ -136,6 +144,7 @@ std::vector<int64_t> get_final_sizes(
case Conv2dMethod::Depthwise:
return std::vector<int64_t>{
4, batch_padded * channels / 4, height * width};
case Conv2dMethod::Pointwise:
case Conv2dMethod::SlidingWindow:
return std::vector<int64_t>{
4, batch_padded * height / 4, channels_padded * width};
Expand All @@ -156,6 +165,7 @@ std::vector<int64_t> get_padded_sizes(
switch (method) {
case Conv2dMethod::Depthwise:
return std::vector<int64_t>{-1, batch_padded};
case Conv2dMethod::Pointwise:
case Conv2dMethod::SlidingWindow:
case Conv2dMethod::Transposed:
return std::vector<int64_t>{batch_padded, channels_padded};
Expand Down Expand Up @@ -265,6 +275,9 @@ Conv2dMethod get_conv2d_method(
if (transposed) {
return Conv2dMethod::Transposed;
}
if (weight_sizes.at(2) == 1 && weight_sizes.at(3) == 1) {
return Conv2dMethod::Pointwise;
}
return Conv2dMethod::SlidingWindow;
}

Expand Down
25 changes: 25 additions & 0 deletions backends/vulkan/test/test_vulkan_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,3 +576,28 @@ def forward(self, x):
sample_inputs,
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
)

def test_vulkan_backend_conv2d_pw(self):
    """Lower a 1x1-kernel (pointwise) Conv2d module and verify its output."""

    class PointwiseConvModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # kernel_size=1 routes this through the pointwise conv2d path.
            self.conv = torch.nn.Conv2d(
                in_channels=8,
                out_channels=8,
                kernel_size=1,
                padding=1,
                groups=1,
                bias=True,
            )

        def forward(self, x):
            return self.conv(x)

    module = PointwiseConvModule()
    inputs = (torch.randn(size=(1, 8, 72, 96), dtype=torch.float32),)

    self.lower_module_and_test_output(
        module,
        inputs,
        memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
    )