Skip to content

[ET-VK] Introduce DynamicDispatchNode #10979

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>

#include <executorch/backends/vulkan/runtime/graph/ops/DispatchNode.h>
#include <executorch/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h>
#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
#include <executorch/backends/vulkan/runtime/graph/ops/PrepackNode.h>

Expand Down
8 changes: 4 additions & 4 deletions backends/vulkan/runtime/graph/ops/DispatchNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class ComputeGraph;
/*
* Represents a single shader execution op in a ML model.
*/
class DispatchNode final : public ExecuteNode {
class DispatchNode : public ExecuteNode {
friend class ComputeGraph;

public:
Expand All @@ -43,9 +43,9 @@ class DispatchNode final : public ExecuteNode {
void encode(ComputeGraph* graph) override;

protected:
const vkapi::ShaderInfo shader_;
const utils::uvec3 global_workgroup_size_;
const utils::WorkgroupSize local_workgroup_size_;
vkapi::ShaderInfo shader_;
utils::uvec3 global_workgroup_size_;
utils::WorkgroupSize local_workgroup_size_;
const vkapi::ParamsBindList params_;
const vkapi::SpecVarList spec_vars_;
const std::vector<PushConstantDataInfo> push_constants_;
Expand Down
49 changes: 49 additions & 0 deletions backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h>

#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

namespace vkcompute {

DynamicDispatchNode::DynamicDispatchNode(
ComputeGraph& graph,
const PickShaderFn& pick_shader_fn,
const PickGlobalFn& pick_global_wg_fn,
const PickLocalFn& pick_local_wg_fn,
const std::vector<ArgGroup>& args,
const vkapi::ParamsBindList& params,
const std::vector<PushConstantDataInfo>& push_constants,
const vkapi::SpecVarList& spec_vars,
const std::vector<ValueRef>& resize_args,
const ResizeFunction& resize_fn)
: DispatchNode(
graph,
pick_shader_fn(&graph, args, resize_args),
pick_global_wg_fn(&graph, args, resize_args),
pick_local_wg_fn(&graph, args, resize_args),
args,
params,
push_constants,
spec_vars,
resize_args,
resize_fn),
pick_shader_fn_(pick_shader_fn),
pick_global_wg_fn_(pick_global_wg_fn),
pick_local_wg_fn_(pick_local_wg_fn) {}

void DynamicDispatchNode::encode(ComputeGraph* graph) {
shader_ = pick_shader_fn_(graph, args_, resize_args_);
global_workgroup_size_ = pick_global_wg_fn_(graph, args_, resize_args_);
local_workgroup_size_ =
utils::WorkgroupSize(pick_local_wg_fn_(graph, args_, resize_args_));
DispatchNode::encode(graph);
}

} // namespace vkcompute
69 changes: 69 additions & 0 deletions backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <executorch/backends/vulkan/runtime/api/api.h>

#include <executorch/backends/vulkan/runtime/graph/containers/PushConstantData.h>
#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>

#include <executorch/backends/vulkan/runtime/graph/ops/DispatchNode.h>

namespace vkcompute {

class ComputeGraph;

/*
* Represents a single shader execution op in a ML model.
*/
class DynamicDispatchNode final : public DispatchNode {
friend class ComputeGraph;

public:
using PickShaderFn = const std::function<vkapi::ShaderInfo(
ComputeGraph*,
const std::vector<ArgGroup>&,
const std::vector<ValueRef>&)>;
using PickGlobalFn = const std::function<utils::uvec3(
ComputeGraph*,
const std::vector<ArgGroup>&,
const std::vector<ValueRef>&)>;
using PickLocalFn = const std::function<utils::uvec3(
ComputeGraph*,
const std::vector<ArgGroup>&,
const std::vector<ValueRef>&)>;

explicit DynamicDispatchNode(
ComputeGraph& graph,
const PickShaderFn& pick_shader_fn,
const PickGlobalFn& pick_global_wg_fn,
const PickLocalFn& pick_local_wg_fn,
const std::vector<ArgGroup>& args,
const vkapi::ParamsBindList& params,
const std::vector<PushConstantDataInfo>& push_constants,
const vkapi::SpecVarList& spec_vars,
const std::vector<ValueRef>& resize_args,
const ResizeFunction& resize_fn = nullptr);

~DynamicDispatchNode() override = default;

void encode(ComputeGraph* graph) override;

protected:
const PickShaderFn pick_shader_fn_;
const PickGlobalFn pick_global_wg_fn_;
const PickLocalFn pick_local_wg_fn_;

public:
operator bool() const {
return shader_;
}
};

} // namespace vkcompute
45 changes: 45 additions & 0 deletions backends/vulkan/test/glsl/dynamic_dispatch_test.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}

layout(std430) buffer;

${layout_declare_tensor(0, "w", "t_out", "float", "texture3d")}
${layout_declare_tensor(1, "r", "t_in1", "float", "texture3d")}
${layout_declare_tensor(2, "r", "t_in2", "float", "texture3d")}

layout(push_constant) uniform restrict Block {
ivec4 out_sizes;
ivec4 in1_sizes;
ivec4 in2_sizes;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);

if (any(greaterThanEqual(pos, out_sizes.xyz))) {
return;
}


vec4 out_texel = vec4(0.0);
for (int row = 0; row < in1_sizes.y; ++row) {
ivec3 in_pos = ivec3(pos.x, row, pos.z);
vec4 in1_texel = texelFetch(t_in1, in_pos, 0);
vec4 in2_texel = texelFetch(t_in2, in_pos, 0);

out_texel += in1_texel * in2_texel;
}

imageStore(t_out, pos, out_texel + ${OFFSET});
}
7 changes: 7 additions & 0 deletions backends/vulkan/test/glsl/dynamic_dispatch_test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
dynamic_dispatch_test:
parameter_names_with_default_values:
OFFSET: 2.25
shader_variants:
- NAME: dynamic_dispatch_test_var1
- NAME: dynamic_dispatch_test_var2
OFFSET: 5.5
137 changes: 137 additions & 0 deletions backends/vulkan/test/vulkan_compute_api_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3297,3 +3297,140 @@ TEST(VulkanComputeGraphOpsTest, test_to_copy) {
test_to_copy();
}
}

vkapi::ShaderInfo pick_dynamic_dispatch_shader(
ComputeGraph* graph,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& additional_args) {
const ValueRef mat1 = args[1].refs[0];

std::string kernel_name = "dynamic_dispatch_test";
if (graph->size_at<int32_t>(-2, mat1) == 1) {
kernel_name += "_var1";
} else {
kernel_name += "_var2";
}
return VK_KERNEL_FROM_STR(kernel_name);
}

utils::uvec3 pick_dynamic_dispatch_global_wg_size(
ComputeGraph* graph,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& additional_args) {
const ValueRef out = args[0].refs[0];

return graph->logical_limits_of(out);
}

utils::uvec3 pick_dynamic_dispatch_local_wg_size(
ComputeGraph* graph,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& additional_args) {
return {64, 1, 1};
}

void resize_dynamic_dispatch_node(
ComputeGraph* graph,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& additional_args) {
const ValueRef out = args[0].refs[0];
const ValueRef mat1 = args[1].refs[0];

std::vector<int64_t> out_sizes = graph->sizes_of(mat1);
out_sizes.at(out_sizes.size() - 2) = 1;

graph->get_tensor(out)->virtual_resize(out_sizes);
}

void add_dynamic_dispatch_test_node(
ComputeGraph& graph,
const ValueRef mat1,
const ValueRef mat2,
const ValueRef out) {
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
graph,
pick_dynamic_dispatch_shader,
pick_dynamic_dispatch_global_wg_size,
pick_dynamic_dispatch_local_wg_size,
// Inputs and Outputs
{{out, vkapi::kWrite}, {{mat1, mat2}, vkapi::kRead}},
// Shader params buffers
{},
// Push Constants
{graph.sizes_pc_of(out),
graph.sizes_pc_of(mat1),
graph.sizes_pc_of(mat2)},
// Specialization constants
{},
// Resize Logic
{},
resize_dynamic_dispatch_node));
}

vkcompute::ComputeGraph build_dynamic_dispatch_test_graph(int M, int N) {
using namespace vkcompute;
GraphConfig config;
ComputeGraph graph(config);

vkapi::ScalarType dtype = vkapi::kFloat;
utils::StorageType in_out_stype = utils::kTexture3D;
utils::GPUMemoryLayout memory_layout = utils::kWidthPacked;

std::vector<int64_t> mat1_size = {M, N};
std::vector<int64_t> mat2_size = {M, N};
std::vector<int64_t> out_size = {1, N};

IOValueRef mat1 =
graph.add_input_tensor(mat1_size, dtype, in_out_stype, memory_layout);
IOValueRef mat2{};

mat2.value = graph.add_tensor(mat2_size, dtype, in_out_stype, memory_layout);
mat2.staging = graph.set_input_tensor(mat2.value);

IOValueRef out;
out.value = graph.add_tensor(out_size, dtype, in_out_stype, memory_layout);

add_dynamic_dispatch_test_node(graph, mat1, mat2, out);

out.staging = graph.set_output_tensor(out.value);

return graph;
}

void test_dynamic_dispatch(int M, int N) {
ComputeGraph graph = build_dynamic_dispatch_test_graph(M, N);

graph.prepare();
graph.encode_prepack();
graph.prepack();
graph.encode_execute();

for (int i = 1; i < 4; i++) {
float val_mat1 = i;
float val_mat2 = i + 1;
// 5.3 is a hardcoded offset in the compute shader
float val_out = M * (val_mat1 * val_mat2) + 5.5;
execute_graph_and_check_output(graph, {val_mat1, val_mat2}, {val_out});
}

// Switch to GEMV mode
int new_N = N / 2;
std::vector<int64_t> new_mat1_size = {1, new_N};
std::vector<int64_t> new_mat2_size = {1, new_N};
graph.resize_input(0, new_mat1_size);
graph.resize_input(1, new_mat2_size);
graph.propagate_resize();

graph.encode_execute();

for (int i = 1; i < 4; i++) {
float val_mat1 = i;
float val_mat2 = i + 1;
float val_out = (val_mat1 * val_mat2) + 2.25;
execute_graph_and_check_output(graph, {val_mat1, val_mat2}, {val_out});
}
}

TEST(VulkanComputeGraphOpsTest, test_dynamic_dispatch_graph) {
test_dynamic_dispatch(128, 128);
}
Loading