Skip to content

[ET-VK][6/n] aten.view_copy #3129

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 33 additions & 3 deletions backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,21 @@

// Ceiling division by 4: the number of 4-element groups needed to hold x
// elements. The argument is parenthesized so that an expression argument
// (e.g. `a << 1` or `a | b`) is evaluated as a whole before 3 is added.
#define divup4(x) (((x) + 3) / 4)

#define to_buffer_i(idx, sizes) \
idx.x + idx.y* sizes.x + idx.z* sizes.y* sizes.x + \
idx.w* sizes.z* sizes.y* sizes.x;
// Input: idx is an ivec4 user-level coordinate, sizes is the tensor shape.
// Output: the index of that element in a contiguous NCHW buffer, i.e.
//   x + y * sizes.x + z * (sizes.x * sizes.y) + w * (sizes.x * sizes.y * sizes.z)
// The expansion is wrapped in parentheses and has no trailing semicolon, so it
// can be used inside a larger expression.
#define to_buffer_i(idx, sizes) \
(idx.x + idx.y * sizes.x + idx.z * sizes.y * sizes.x + \
idx.w * sizes.z * sizes.y * sizes.x)

// Inverse of to_buffer_i
// Input: buffer_idx in the continuous nchw-buffer, sizes is the tensor shape
// Output: ivec4 user-level coorindate
#define from_buffer_i(buf_i, sizes) \
ivec4( \
buf_i % sizes.x, \
(buf_i / (sizes.x)) % sizes.y, \
(buf_i / (sizes.x * sizes.y)) % sizes.z, \
(buf_i / (sizes.x * sizes.y * sizes.z)))

#define get_packed_dim_C_packed(vec) vec.z
#define get_packed_dim_W_packed(vec) vec.x
Expand All @@ -20,6 +32,8 @@
#define get_packed_stride_W_packed(vec) (1)
#define get_packed_stride_H_packed(vec) (vec.x)

// Input: pos is a texture position, sizes is a pack-aligned size.
// Output: a user-level (w, h, c, n) coordinate.
// x and y are taken from pos unchanged. pos.z * 4 is treated as a combined
// channel/batch offset: its remainder modulo sizes.z becomes the c component
// and its quotient by sizes.z becomes the n component.
#define to_tensor_idx_C_packed(pos, sizes) \
ivec4(pos.x, pos.y, (pos.z * 4) % sizes.z, (pos.z * 4) / sizes.z)

Expand All @@ -29,6 +43,9 @@
#define to_tensor_idx_H_packed(pos, sizes) \
ivec4(pos.x, (pos.y * 4), pos.z % sizes.z, pos.z / sizes.z)

// Input: idx is a user-level (w, h, c, n) coordinate, sizes is a pack-aligned
// size.
// Output: the texture position (ivec3) of the texel that contains idx.
// x and y are taken from idx unchanged; the z position is the combined
// channel/batch offset (idx.z + idx.w * sizes.z) divided by 4.
#define to_texture_pos_C_packed(idx, sizes) \
ivec3(idx.x, idx.y, (idx.z + idx.w * sizes.z) / 4)

Expand All @@ -38,6 +55,19 @@
#define to_texture_pos_H_packed(idx, sizes) \
ivec3(idx.x, idx.y / 4, (idx.z + idx.w * sizes.z))

// Input: idx is a user-level (w, h, c, n) coordinate, sizes is a pack-aligned
// size.
// Output: ivec4 whose xyz is the texture position of the texel containing idx
// and whose w is the index (0..3) of the element within that texel.

// C_packed: the texel z position is (idx.z + idx.w * sizes.z) / 4 and the
// element index is idx.z % 4.
#define to_texture_pos_elem_C_packed(idx, sizes) \
ivec4(idx.x, idx.y, (idx.z + idx.w * sizes.z) / 4, idx.z % 4)

// W_packed: the texel x position is idx.x / 4 and the element index is
// idx.x % 4.
#define to_texture_pos_elem_W_packed(idx, sizes) \
ivec4(idx.x / 4, idx.y, (idx.z + idx.w * sizes.z), idx.x % 4)

// H_packed: the texel y position is idx.y / 4 and the element index is
// idx.y % 4.
#define to_texture_pos_elem_H_packed(idx, sizes) \
ivec4(idx.x, idx.y / 4, (idx.z + idx.w * sizes.z), idx.y % 4)

// Given a buffer(1-D) index cur, compute a new index where the corresponding
// tensor(N-D)'s adjacent dimensions are swapped. The parameters x,y and plane
// describe sizes. As an example, let's say we want to swap dimensions 0,1 for a
Expand Down
76 changes: 76 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/view.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_type(DTYPE)}

layout(std430) buffer;

#include "indexing_utils.h"

layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;

// VEC4_T is already defined before the indexing_utils.h include, so it is not
// defined a second time here.

// Bind the generic helper names used in main() to the variants matching this
// shader's packing; ${PACKING} is a template placeholder filled in per variant.
#define to_tensor_idx to_tensor_idx_${PACKING}
#define to_texture_pos_elem to_texture_pos_elem_${PACKING}
#define get_packed_stride get_packed_stride_${PACKING}

layout(set = 0, binding = 2) uniform PRECISION restrict OutGpuSizes {
uvec4 out_gpu_sizes;
};

layout(set = 0, binding = 3) uniform PRECISION restrict OutCpuSizes {
uvec4 out_cpu_sizes;
};

layout(set = 0, binding = 4) uniform PRECISION restrict InGpuSizes {
uvec4 in_gpu_sizes;
};

layout(set = 0, binding = 5) uniform PRECISION restrict InCpuSizes {
uvec4 in_cpu_sizes;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;


/*
 * Computes one output texel per invocation.
 *
 * The output texture position is converted to a user-level tensor coordinate,
 * then to an index in a virtual contiguous NCHW buffer. Each of the 4 elements
 * of the output texel is then looked up separately in the input texture via
 * that buffer index, and the assembled texel is written to image_out.
 */
void main() {
  const ivec3 out_pos = ivec3(gl_GlobalInvocationID);
  const ivec4 out_tensor_idx = to_tensor_idx(out_pos, out_gpu_sizes);

  // Skip invocations that fall outside the output tensor. A coordinate is out
  // of range as soon as ANY component reaches its size; all() would only
  // reject positions that exceed every dimension simultaneously and would let
  // most out-of-range invocations through.
  if (any(greaterThanEqual(out_tensor_idx, out_gpu_sizes))) {
    return;
  }

  // Assume there is a virtual contiguous buffer in NCHW format. From the
  // output position, first calculate the index in the virtual buffer, and
  // then calculate the input position from that index.

  const uint base_index = to_buffer_i(out_tensor_idx, out_cpu_sizes);
  // The 4 elements of a texel are consecutive along the packed dimension, so
  // their buffer indices are spaced by the packed dimension's stride.
  const uvec4 buf_indices =
      base_index + ivec4(0, 1, 2, 3) * get_packed_stride(out_cpu_sizes);

  VEC4_T value;
  // Need to look up the 4 values in the output texel separately.
  for (int i = 0; i < 4; i++) {
    // Buffer index -> user-level coordinate in the input tensor.
    ivec4 user_coor = from_buffer_i(buf_indices[i], in_cpu_sizes);

    // xyz: input texel position, w: element index inside that texel.
    ivec4 in_pos_elem = to_texture_pos_elem(user_coor, in_gpu_sizes);

    VEC4_T intex = VEC4_T(texelFetch(image_in, in_pos_elem.xyz, 0));

    value[i] = intex[in_pos_elem.w];
  }

  imageStore(image_out, out_pos, value);
}
14 changes: 14 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/view.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
view:
parameter_names_with_default_values:
DTYPE: float
NDIM: 3
generate_variant_forall:
DTYPE:
- VALUE: half
- VALUE: float
PACKING:
- VALUE: C_packed
- VALUE: W_packed
- VALUE: H_packed
shader_variants:
- NAME: view
51 changes: 51 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/View.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

/// Appends an ExecuteNode to `graph` that runs the "view" shader, reading the
/// tensor referenced by `in` and writing the tensor referenced by `out`.
///
/// The shader variant is chosen from the output tensor's dtype and memory
/// layout, and the dispatch is sized by the output tensor's extents.
void add_view_node(ComputeGraph& graph, ValueRef in, ValueRef out) {
  vTensorPtr src = graph.get_tensor(in);
  vTensorPtr dst = graph.get_tensor(out);

  // Shader name: "view" followed by the dtype and memory layout suffixes.
  std::string shader_name("view");
  shader_name.reserve(kShaderNameReserve);
  add_dtype_suffix(shader_name, *dst);
  add_memory_layout_suffix(shader_name, *dst);

  api::utils::uvec3 global_wg = dst->extents();
  api::utils::uvec3 local_wg = adaptive_work_group_size(global_wg);

  graph.execute_nodes().emplace_back(new ExecuteNode(
      graph,
      VK_KERNEL_FROM_STR(shader_name),
      global_wg,
      local_wg,
      // Resource order follows the shader: output image first, then input.
      {{out, api::MemoryAccessType::WRITE}, {in, api::MemoryAccessType::READ}},
      // Uniform buffers, in the order passed to the shader:
      // output gpu sizes, output cpu sizes, input gpu sizes, input cpu sizes.
      {dst->gpu_sizes_ubo(),
       dst->cpu_sizes_ubo(),
       src->gpu_sizes_ubo(),
       src->cpu_sizes_ubo()}));
}

/// Operator entry point for view: args[0] is the input tensor and args[2] is
/// the output tensor.
void view(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  // args[1] (the requested size) is intentionally ignored: the output
  // tensor's sizes have already been determined during compilation.
  const ValueRef in = args[0];
  const ValueRef out = args[2];
  add_view_node(graph, in, out);
}

// Register view() as the implementation of aten.view_copy.default.
REGISTER_OPERATORS {
VK_REGISTER_OP(aten.view_copy.default, view);
}

} // namespace vkcompute
28 changes: 28 additions & 0 deletions backends/vulkan/test/op_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,33 @@ def get_permute_inputs():
return test_suite


def get_view_inputs():
    """Build the test suite for the view op.

    Each case is a pair of (input tensor shape, target size list); the suite is
    run for the width-, height- and channels-packed layouts.
    """
    # (input shape, requested view size). NOTE(review): a -1 entry is
    # presumably the dimension left for the op to infer - confirm.
    view_cases = [
        ((3, 4, 5), [1, 1, -1]),
        ((3, 4, 5), [1, -1, 1]),
        ((3, 4, 5), [-1, 1, 1]),
        ((8, 7, 2, 3), [4, 3, 7, 4]),
        ((8, 7, 2, 3), [7, -1, 2, 1]),
        ((8, 7, 2, 3), [1, 1, 1, -1]),
        ((8, 7, 2, 3), [-1]),
        ((2, 3, 3, 7), [2, -1, 1, 1]),
        ((3, 5, 2, 7), [7, -1, 2, 1]),
        ((2, 2, 8, 6), [2, 6, -1, 1]),
        ((2, 2, 8, 6), [6, -1, 1]),
        ((S1, S2, S1, S2), [S2, -1, 1, S1]),
        ((S1, S2, S1, S2), [S1, 1, -1, S2]),
        ((S1, S2, S1, S2), [-1, 1, S1, S2]),
    ]
    packed_layouts = [
        "api::kWidthPacked",
        "api::kHeightPacked",
        "api::kChannelsPacked",
    ]

    test_suite = VkTestSuite(view_cases)
    test_suite.layouts = packed_layouts
    return test_suite


test_suites = {
"aten.add.Tensor": get_binary_elementwise_inputs(),
"aten.sub.Tensor": get_binary_elementwise_inputs(),
Expand All @@ -208,4 +235,5 @@ def get_permute_inputs():
"aten.select_copy.int": get_select_int_inputs(),
"aten.permute.default": get_permute_inputs(),
"aten.permute_copy.default": get_permute_inputs(),
"aten.view_copy.default": get_view_inputs(),
}
8 changes: 7 additions & 1 deletion backends/vulkan/test/op_tests/utils/codegen_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,16 @@ def gen_case_name(self, inputs: List[Any], prepack: bool = False) -> str:
for size in arg_sizes_or_val:
name_str += str(size) + "x"
name_str = name_str[:-1]
# A minus sign is not a valid character in a test case name; replace it with "n".
name_str = name_str.replace("-", "n")

elif isinstance(arg_sizes_or_val, list):
for size in arg_sizes_or_val:
name_str += str(size) + "c"
name_str = name_str[:-1]
# A minus sign is not a valid character in a test case name; replace it with "n".
name_str = name_str.replace("-", "n")

else:
name_str += str(arg_sizes_or_val).replace(".", "p")
return name_str
Expand Down Expand Up @@ -234,7 +240,7 @@ def generate_suite_cpp(self) -> str:
// from_blob doesn't take ownership of data. Hence must create a copy as
// "values" will go out of scope.
return at::from_blob(values.data(), sizes, dtype).detach().clone();
return at::from_blob(values.data(), sizes, at::kFloat).toType(dtype).detach().clone();
}}
{test_suites_cpp}
Expand Down