Commit 64facea
[6/n][ET-VK][Ops] aten.flip
Pull Request resolved: #5879

Port from LiteInterpreter's [`flip.glsl`](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/vulkan/glsl/flip.glsl).

```
- func: flip(Tensor self, int[] dims) -> Tensor
```

Will use this to verify the AHB image is interpreted correctly.

Differential Revision: [D63843843](https://our.internmc.facebook.com/intern/diff/D63843843/)

ghstack-source-id: 246607070
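For reference, `torch.flip` reverses element order along each listed dim; a minimal eager-mode example (illustrative, not part of this commit):

```python
import torch

x = torch.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

# flip(Tensor self, int[] dims) -> Tensor
torch.flip(x, [0])     # rows reversed:          [[3, 4, 5], [0, 1, 2]]
torch.flip(x, [0, 1])  # rows and cols reversed: [[5, 4, 3], [2, 1, 0]]
```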
1 parent 71ba888 commit 64facea

File tree

6 files changed: +221 -0 lines changed

backends/vulkan/partitioner/supported_ops.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -101,6 +101,7 @@ def __contains__(self, op):
     exir_ops.edge.aten.t_copy.default,
     # Indexing and lookup
     exir_ops.edge.aten.embedding.default,
+    exir_ops.edge.aten.flip.default,
     exir_ops.edge.aten.index_select.default,
     exir_ops.edge.aten.select_copy.int,
     exir_ops.edge.aten.slice_copy.Tensor,
```
backends/vulkan/runtime/graph/ops/glsl/flip.glsl

Lines changed: 78 additions & 0 deletions

```glsl
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}

#include "indexing_utils.h"

layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
${layout_declare_ubo(B, "ivec3", "out_limits")}
${layout_declare_ubo(B, "ivec4", "out_sizes")}
${layout_declare_ubo(B, "ivec4", "dims")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  if (any(greaterThanEqual(pos, out_limits))) {
    return;
  }

  VEC4_T out_texel = VEC4_T(0);
  uint src_x = pos.x;
  uint src_y = pos.y;
  uint src_z = pos.z;

  int flattened_channels = int(ceil(out_sizes.z / 4.0));

  // Width
  if (dims.x == 1) {
    src_x = out_sizes.x - 1 - pos.x;
  }
  // Height
  if (dims.y == 1) {
    src_y = out_sizes.y - 1 - pos.y;
  }
  // Batch
  if (dims.w == 1) {
    uint n = pos.z / flattened_channels;
    uint src_n = out_sizes.w - 1 - n;
    uint c4 = pos.z - n * flattened_channels;
    src_z = src_n * flattened_channels + c4;
  }

  uint prev_src_z = src_z;
  for (int p = 0; p < 4; ++p) {
    uint src_p = p;

    // Channel
    if (dims.z == 1) {
      uint nc = (pos.z / flattened_channels) * flattened_channels;
      uint c4 = pos.z - nc;
      uint c = c4 * 4 + p;
      uint src_c = out_sizes.z - 1 - c;

      src_z = (dims.w == 1)
          ? prev_src_z - c4 + src_c / 4 // Batch and Channel
          : nc + src_c / 4; // Channel only
      src_p = src_c % 4;
    }

    VEC4_T in_texel = VEC4_T(texelFetch(t_in, ivec3(src_x, src_y, src_z), 0));
    out_texel[p] = in_texel[src_p];
  }
  imageStore(t_out, pos, out_texel);
}
```
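Width, height, and batch flips are plain mirror arithmetic on texel coordinates, but the channel flip must account for the channels-packed layout: each texel along z packs four consecutive channels of one batch entry, so a flipped channel generally lands in a different texel and lane. A standalone Python sketch of the same per-texel index math (my illustration of the shader above, not part of the commit; `flip_src_index` is a hypothetical name):

```python
import math

def flip_src_index(pos, out_sizes, dims):
    """For one output texel at pos = (x, y, z), return the four
    (src_x, src_y, src_z, src_p) source lookups the shader performs.

    out_sizes -- (W, H, C, N) in WHCN order
    dims      -- (w, h, c, n) flip bitmap, 1 = flip that dim
    """
    W, H, C, N = out_sizes
    flattened_channels = math.ceil(C / 4)  # z-texels per batch entry

    src_x = W - 1 - pos[0] if dims[0] else pos[0]  # Width
    src_y = H - 1 - pos[1] if dims[1] else pos[1]  # Height
    src_z = pos[2]
    if dims[3]:  # Batch: mirror the batch index, keep the c4 offset
        n = pos[2] // flattened_channels
        c4 = pos[2] - n * flattened_channels
        src_z = (N - 1 - n) * flattened_channels + c4

    lookups = []
    prev_src_z = src_z
    for p in range(4):  # four channels packed per texel
        src_p = p
        if dims[2]:  # Channel: unpack to c, mirror, repack to (texel, lane)
            nc = (pos[2] // flattened_channels) * flattened_channels
            c4 = pos[2] - nc
            c = c4 * 4 + p
            src_c = C - 1 - c
            src_z = (prev_src_z - c4 + src_c // 4) if dims[3] else (nc + src_c // 4)
            src_p = src_c % 4
        lookups.append((src_x, src_y, src_z, src_p))
    return lookups
```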
backends/vulkan/runtime/graph/ops/glsl/flip.yaml

Lines changed: 13 additions & 0 deletions

```yaml
flip:
  parameter_names_with_default_values:
    DTYPE: float
    STORAGE: texture3d
  generate_variant_forall:
    DTYPE:
      - VALUE: half
      - VALUE: float
      - VALUE: int
      - VALUE: int8
      - VALUE: uint8
  shader_variants:
    - NAME: flip
```
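For context, this YAML drives the shader codegen: `generate_variant_forall` stamps out one compiled variant of the `flip` template per listed DTYPE, which `add_dtype_suffix` in Flip.cpp (below) then selects by name at runtime. A rough Python sketch of the expansion step (hypothetical illustration, not the actual codegen):

```python
defaults = {"DTYPE": "float", "STORAGE": "texture3d"}
dtypes = ["half", "float", "int", "int8", "uint8"]

# Hypothetical expansion: one shader instance per DTYPE,
# each overriding the template's default parameters.
variants = [{**defaults, "DTYPE": d} for d in dtypes]
for v in variants:
    print("flip variant:", v["DTYPE"], v["STORAGE"])
```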
backends/vulkan/runtime/graph/ops/impl/Flip.cpp

Lines changed: 95 additions & 0 deletions

```cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

void check_flip_args(const api::vTensor& in, const api::vTensor& out) {
  VK_CHECK_COND(check_packed_dim_is(in, WHCN::kChannelsDim));
  VK_CHECK_COND(check_packed_dim_is(out, WHCN::kChannelsDim));
}

void resize_flip_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  (void)extra_args;
  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
  vTensorPtr in = graph->get_tensor(args[1].refs[0]);

  out->virtual_resize(in->sizes());
}

utils::ivec4 create_whcn_bitmap(
    const std::vector<int64_t>& list,
    const int64_t ndim) {
  std::vector<int64_t> bm(4, 0);
  for (const auto e : list) {
    auto x = (e % ndim + ndim) % ndim; // normalize
    x = ndim - 1 - x; // reverse
    bm.at(x) = 1;
  }
  return utils::make_ivec4(bm);
}

void add_flip_node(
    ComputeGraph& graph,
    const ValueRef in,
    const std::vector<int64_t>& dim_list,
    const ValueRef out) {
  vTensorPtr t_in = graph.get_tensor(in);
  vTensorPtr t_out = graph.get_tensor(out);
  check_flip_args(*t_in, *t_out);

  const auto dim_bitmap = create_whcn_bitmap(dim_list, t_in->dim());

  std::string kernel_name("flip");
  kernel_name.reserve(kShaderNameReserve);
  add_dtype_suffix(kernel_name, *t_out);

  graph.execute_nodes().emplace_back(new ExecuteNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      graph.create_global_wg_size(out),
      graph.create_local_wg_size(out),
      // Inputs and Outputs
      {
          {out, vkapi::MemoryAccessType::WRITE},
          {in, vkapi::MemoryAccessType::READ},
      },
      // Parameter buffers
      {
          graph.logical_limits_ubo(out),
          graph.sizes_ubo(out),
          graph.create_params_buffer(dim_bitmap),
      },
      // Specialization Constants
      {},
      // Resizing Logic
      resize_flip_node));
}

void flip(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  ValueRef in = args[0];
  auto dims = graph.get_int_list(args[1]);
  ValueRef out = args[2];

  add_flip_node(graph, in, *dims, out);
}

REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.flip.default, flip);
}

} // namespace vkcompute
```
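The subtle piece above is `create_whcn_bitmap`: PyTorch dim indices are outermost-first (NCHW for a 4-D tensor) and may be negative, while the shader's `dims` UBO is a WHCN bitmap, so each index is normalized and then reversed. A quick Python equivalent for sanity-checking the mapping (illustrative only, not part of the commit):

```python
def create_whcn_bitmap(dim_list, ndim):
    bm = [0, 0, 0, 0]  # (w, h, c, n)
    for e in dim_list:
        x = (e % ndim + ndim) % ndim  # normalize negative dims (C++ idiom)
        x = ndim - 1 - x              # reverse: innermost dim maps to W
        bm[x] = 1
    return bm

# For a 4-D NCHW tensor: dim 0 is N (batch), dim 3 / -1 is W (width).
assert create_whcn_bitmap([0], 4) == [0, 0, 0, 1]     # flip batch
assert create_whcn_bitmap([-1], 4) == [1, 0, 0, 0]    # flip width
assert create_whcn_bitmap([1, 3], 4) == [1, 0, 1, 0]  # flip channel and width
```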

backends/vulkan/test/op_tests/cases.py

Lines changed: 20 additions & 0 deletions
```diff
@@ -1159,3 +1159,23 @@ def get_squeeze_copy_dim_inputs():
         ]
     )
     return test_suite
+
+
+@register_test_suite("aten.flip.default")
+def get_flip_inputs():
+    Test = namedtuple("VkIndexSelectTest", ["self", "dim"])
+    Test.__new__.__defaults__ = (None, 0)
+
+    test_cases = [
+        Test(self=[9], dim=[0]),
+        Test(self=[9, 9], dim=[0, 1]),
+        Test(self=[9, 9, 9], dim=[0, 2]),
+        Test(self=[9, 9, 9], dim=[0, 1, 2]),
+        Test(self=[9, 9, 9, 9], dim=[0]),
+        Test(self=[9, 9, 9, 9], dim=[0, 2, 3]),
+        Test(self=[9, 9, 9, 9], dim=[1, 3]),
+        Test(self=[9, 9, 9, 9], dim=[0, 1, 2, 3]),
+    ]
+
+    test_suite = VkTestSuite([tuple(tc) for tc in test_cases])
+    return test_suite
```

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 14 additions & 0 deletions
```diff
@@ -1649,6 +1649,20 @@ def forward(self, x):
             memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
         )
 
+    def test_vulkan_backend_flip(self):
+        class FlipModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.flip(x, [0, 1, 2, 3])
+
+        self.lower_module_and_test_output(
+            FlipModule(),
+            (torch.arange(48).reshape(2, 3, 4, 2),),
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
+
     def test_vulkan_backend_conv_with_clamp(self):
         class ConvWithClampModule(torch.nn.Module):
             def __init__(self):
```