Skip to content

[ET-VK] Introduce generalized shaders for transfer ops and use it for select and slice #11255

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion backends/vulkan/runtime/graph/ComputeGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -492,14 +492,24 @@ vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer(
const ValueRef idx) {
if (values_.at(idx).isInt()) {
const int32_t val = extract_scalar<int32_t>(idx);
create_params_buffer(val);
return create_params_buffer(val);
} else if (values_.at(idx).isSymInt()) {
SymIntPtr symint = get_symint(idx);
return vkapi::BufferBindInfo(symint->gpu_buffer.buffer());
}
VK_THROW("Cannot create a int param buffer for the given value");
}

vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer(
const ValueRef idx,
const int32_t default_val) {
if (values_.at(idx).isNone()) {
return create_params_buffer(default_val);
} else {
return get_or_create_int_param_buffer(idx);
}
}

void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) {
get_symint(idx)->set(val);
}
Expand Down Expand Up @@ -693,6 +703,12 @@ void ComputeGraph::resize_input(
get_tensor(io_val.value)->virtual_resize(new_sizes);
}

void ComputeGraph::virtual_resize(
const ValueRef idx,
const std::vector<int64_t>& new_sizes) {
get_tensor(idx)->virtual_resize(new_sizes);
}

void ComputeGraph::propagate_resize() {
for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
node->trigger_resize(this);
Expand Down
20 changes: 20 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,19 @@ class ComputeGraph final {
std::optional<T> extract_optional_scalar(const ValueRef idx) {
if (val_is_none(idx)) {
return ::std::nullopt;
} else if (val_is_symint(idx)) {
return utils::safe_downcast<T>(read_symint(idx));
} else {
return extract_scalar<T>(idx);
}
}

template <typename T>
T extract_optional_scalar(const ValueRef idx, const T default_val) {
if (val_is_none(idx)) {
return default_val;
} else if (val_is_symint(idx)) {
return utils::safe_downcast<T>(read_symint(idx));
} else {
return extract_scalar<T>(idx);
}
Expand Down Expand Up @@ -609,6 +622,10 @@ class ComputeGraph final {
*/
vkapi::BufferBindInfo get_or_create_int_param_buffer(const ValueRef idx);

vkapi::BufferBindInfo get_or_create_int_param_buffer(
const ValueRef idx,
const int32_t default_value);

void set_symint(const ValueRef idx, const int32_t val);

int32_t read_symint(const ValueRef idx);
Expand Down Expand Up @@ -753,6 +770,9 @@ class ComputeGraph final {
//

void resize_input(const int64_t idx, const std::vector<int64_t>& new_sizes);
void virtual_resize(
const ValueRef idx,
const std::vector<int64_t>& new_sizes);
void propagate_resize();

//
Expand Down
74 changes: 74 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/select.glslh
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#ifndef SELECT_GLSLH
#define SELECT_GLSLH

/*
* Enable the fast path if a texel loaded from the input texture can be used as
* is to store to the output texture. The following conditions must be met:
*
* 1. The input and output textures have the same packed dimension.
* 2. The selected_dim must not be the packed dimension of the input.
* 3. The packed dimension of the input must "map" to the packed dimension of
* the output. This occurs if selected_dim is greater than the packed dimension
* of the input.
*/
bool can_use_fast_path() {
if (out_packed_dim != in_packed_dim) {
return false;
}
if (selected_dim <= in_packed_dim) {
return false;
}
return true;
}

/*
* Given an output tensor index, return the corresponding input tensor index for
* the select operator. This is done by "inserting" the select index at the
* selected_dim in the input tensor index.
*
* A simple example is (note all tensor index are in WHCN order):
* out_tidx = [7, 5, 9]
* selected_dim = 2
* index = 3
* in_tidx = [7, 3, 5, 9]
*
* This function assumes that the following variables are defined in the layout:
* - in_sizes
* - selected_dim
* - index
*/
ivec4 out_tidx_to_in_tidx(const ivec4 out_tidx) {
ivec4 in_tidx = ivec4(0);

int adjusted_index = index;
if (index < 0) {
adjusted_index = index + in_sizes[selected_dim];
}

// Handle different dimensions for selection
if (selected_dim == 0) {
// Select from width dimension
in_tidx = ivec4(adjusted_index, out_tidx.x, out_tidx.y, out_tidx.z);
} else if (selected_dim == 1) {
// Select from height dimension
in_tidx = ivec4(out_tidx.x, adjusted_index, out_tidx.y, out_tidx.z);
} else if (selected_dim == 2) {
// Select from channel dimension
in_tidx = ivec4(out_tidx.x, out_tidx.y, adjusted_index, out_tidx.z);
} else if (selected_dim == 3) {
// Select from batch dimension
in_tidx = ivec4(out_tidx.x, out_tidx.y, out_tidx.z, adjusted_index);
}

return in_tidx;
}

#endif // SELECT_GLSLH
52 changes: 0 additions & 52 deletions backends/vulkan/runtime/graph/ops/glsl/select_batch_4d.glsl

This file was deleted.

50 changes: 0 additions & 50 deletions backends/vulkan/runtime/graph/ops/glsl/select_channel_3d.glsl

This file was deleted.

10 changes: 0 additions & 10 deletions backends/vulkan/runtime/graph/ops/glsl/select_channel_3d.yaml

This file was deleted.

65 changes: 0 additions & 65 deletions backends/vulkan/runtime/graph/ops/glsl/select_channel_4d.glsl

This file was deleted.

62 changes: 0 additions & 62 deletions backends/vulkan/runtime/graph/ops/glsl/select_height_3d.glsl

This file was deleted.

Loading
Loading