Skip to content

[5/n][ET-VK][Ops] aten.flip #5879

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/vulkan/partitioner/supported_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def __contains__(self, op):
exir_ops.edge.aten.t_copy.default,
# Indexing and lookup
exir_ops.edge.aten.embedding.default,
exir_ops.edge.aten.flip.default,
exir_ops.edge.aten.index_select.default,
exir_ops.edge.aten.select_copy.int,
exir_ops.edge.aten.slice_copy.Tensor,
Expand Down
78 changes: 78 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/flip.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}

#include "indexing_utils.h"

layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
${layout_declare_ubo(B, "ivec3", "out_limits")}
${layout_declare_ubo(B, "ivec4", "out_sizes")}
${layout_declare_ubo(B, "ivec4", "dims")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

void main() {
  // One invocation per output texel at (x, y, packed-z).
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  if (any(greaterThanEqual(pos, out_limits))) {
    return;
  }

  VEC4_T out_texel = VEC4_T(0);
  // Source texel coordinates; start as identity and adjust per flipped axis.
  uint src_x = pos.x;
  uint src_y = pos.y;
  uint src_z = pos.z;

  // Channels are packed 4-per-texel, so each batch entry spans
  // ceil(C / 4) texels along z (out_sizes is WHCN: x=W, y=H, z=C, w=N).
  int flattened_channels = int(ceil(out_sizes.z / 4.0));

  // `dims` is a WHCN flag vector: a component equal to 1 means that
  // axis is flipped (see the `dims.{x,y,z,w} == 1` checks below).
  // Width
  if (dims.x == 1) {
    src_x = out_sizes.x - 1 - pos.x;
  }
  // Height
  if (dims.y == 1) {
    src_y = out_sizes.y - 1 - pos.y;
  }
  // Batch
  if (dims.w == 1) {
    // z interleaves batch and packed channels: z = n * ceil(C/4) + c4,
    // so mirror the batch index n while keeping the channel-texel index c4.
    uint n = pos.z / flattened_channels;
    uint src_n = out_sizes.w - 1 - n;
    uint c4 = pos.z - n * flattened_channels;
    src_z = src_n * flattened_channels + c4;
  }

  uint prev_src_z = src_z;
  // Each texel lane p holds one of 4 consecutive channels; a channel flip
  // can send each lane to a different source texel, so gather lane by lane.
  for (int p = 0; p < 4; ++p) {
    uint src_p = p;

    // Channel
    if (dims.z == 1) {
      // nc = start of this batch entry's channel texels; c4 = texel index
      // within the entry; c = logical output channel for lane p.
      uint nc = (pos.z / flattened_channels) * flattened_channels;
      uint c4 = pos.z - nc;
      uint c = c4 * 4 + p;
      uint src_c = out_sizes.z - 1 - c;
      // NOTE(review): for padding lanes (c >= C when C % 4 != 0) src_c
      // underflows and the fetch is out of range; presumably those lanes
      // are ignored downstream — confirm.

      src_z = (dims.w == 1)
          ? prev_src_z - c4 + src_c / 4 // Batch and Channel
          : nc + src_c / 4; // Channel only
      src_p = src_c % 4;
    }

    VEC4_T in_texel = VEC4_T(texelFetch(t_in, ivec3(src_x, src_y, src_z), 0));
    out_texel[p] = in_texel[src_p];
  }
  imageStore(t_out, pos, out_texel);
}
13 changes: 13 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/flip.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Codegen config for the flip compute shader: generates one shader variant
# per supported dtype (texture3d storage only).
flip:
  parameter_names_with_default_values:
    DTYPE: float
    STORAGE: texture3d
  generate_variant_forall:
    DTYPE:
      - VALUE: half
      - VALUE: float
      - VALUE: int
      - VALUE: int8
      - VALUE: uint8
  shader_variants:
    - NAME: flip
95 changes: 95 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/Flip.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

// Both tensors must be channels-packed; the flip shader's texel addressing
// assumes 4 channels per texel along z.
void check_flip_args(const api::vTensor& in, const api::vTensor& out) {
  for (const api::vTensor* tensor : {&in, &out}) {
    VK_CHECK_COND(check_packed_dim_is(*tensor, WHCN::kChannelsDim));
  }
}

// Resize callback: flipping never changes shape, so the output simply
// mirrors the input's sizes.
void resize_flip_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  (void)extra_args; // unused
  const vTensorPtr self = graph->get_tensor(args[1].refs[0]);
  graph->get_tensor(args[0].refs[0])->virtual_resize(self->sizes());
}

// Translate a list of NCHW-style dim indices (possibly negative) into a
// WHCN flag vector: component 0 = width, 1 = height, 2 = channels,
// 3 = batch; 1 marks an axis to flip.
utils::ivec4 create_whcn_bitmap(
    const std::vector<int64_t>& list,
    const int64_t ndim) {
  std::vector<int64_t> flags(4, 0);
  for (const int64_t d : list) {
    // Normalize negative indices into [0, ndim), then reverse so that the
    // innermost (last) dim maps to component 0 (width).
    const int64_t normalized = ((d % ndim) + ndim) % ndim;
    flags.at(ndim - 1 - normalized) = 1;
  }
  return utils::make_ivec4(flags);
}

// Builds and enqueues the flip compute shader dispatch for `out = flip(in,
// dim_list)`. `dim_list` uses tensor (NCHW-style) dim indices; it is
// converted to a WHCN bitmap before being passed to the shader.
void add_flip_node(
    ComputeGraph& graph,
    const ValueRef in,
    const std::vector<int64_t>& dim_list,
    const ValueRef out) {
  vTensorPtr t_in = graph.get_tensor(in);
  vTensorPtr t_out = graph.get_tensor(out);
  // Both tensors must be channels-packed for the shader's texel math.
  check_flip_args(*t_in, *t_out);

  const auto dim_bitmap = create_whcn_bitmap(dim_list, t_in->dim());

  // Shader name is "flip" plus a dtype suffix matching the generated
  // variants in flip.yaml.
  std::string kernel_name("flip");
  kernel_name.reserve(kShaderNameReserve);
  add_dtype_suffix(kernel_name, *t_out);

  graph.execute_nodes().emplace_back(new ExecuteNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      graph.create_global_wg_size(out),
      graph.create_local_wg_size(out),
      // Inputs and Outputs
      {
          {out, vkapi::kWrite},
          {in, vkapi::kRead},
      },
      // Parameter buffers — order must match the UBO declaration order in
      // flip.glsl: out_limits, out_sizes, dims.
      {
          graph.logical_limits_ubo(out),
          graph.sizes_ubo(out),
          graph.create_params_buffer(dim_bitmap),
      },
      // Specialization Constants
      {},
      // Resizing Logic
      resize_flip_node));
}

// Operator entry point for aten.flip.default.
// args layout: {input tensor, list of dims to flip, output tensor}.
void flip(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  const auto dim_list = graph.get_int_list(args[1]);
  add_flip_node(graph, args[0], *dim_list, args[2]);
}

// Registers the flip handler for aten.flip.default with the Vulkan
// operator registry.
REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.flip.default, flip);
}

} // namespace vkcompute
20 changes: 20 additions & 0 deletions backends/vulkan/test/op_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -1159,3 +1159,23 @@ def get_squeeze_copy_dim_inputs():
]
)
return test_suite


@register_test_suite("aten.flip.default")
def get_flip_inputs():
Test = namedtuple("Flip", ["self", "dim"])
Test.__new__.__defaults__ = (None, 0)

test_cases = [
Test(self=[9], dim=[0]),
Test(self=[9, 9], dim=[0, 1]),
Test(self=[9, 9, 9], dim=[0, 2]),
Test(self=[9, 9, 9], dim=[0, 1, 2]),
Test(self=[9, 9, 9, 9], dim=[0]),
Test(self=[9, 9, 9, 9], dim=[0, 2, 3]),
Test(self=[9, 9, 9, 9], dim=[1, 3]),
Test(self=[9, 9, 9, 9], dim=[0, 1, 2, 3]),
]

test_suite = VkTestSuite([tuple(tc) for tc in test_cases])
return test_suite
14 changes: 14 additions & 0 deletions backends/vulkan/test/test_vulkan_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1649,6 +1649,20 @@ def forward(self, x):
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
)

def test_vulkan_backend_flip(self):
    """Lower a module that flips all four axes and compare against eager."""

    class FlipModule(torch.nn.Module):
        def forward(self, x):
            return torch.flip(x, [0, 1, 2, 3])

    # Distinct values per element so any mis-indexed flip is detected.
    sample_inputs = (torch.arange(48).reshape(2, 3, 4, 2),)

    self.lower_module_and_test_output(
        FlipModule(),
        sample_inputs,
        memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
    )

def test_vulkan_backend_conv_with_clamp(self):
class ConvWithClampModule(torch.nn.Module):
def __init__(self):
Expand Down
Loading