Skip to content

Implement Graph node #6037

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/Logging.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ void ComputeGraph::print_readable() {
size_t node_idx = 0;
for (const std::unique_ptr<ExecuteNode>& node : execute_nodes()) {
std::cout << std::setw(6) << node_idx;
std::cout << std::setw(32) << node->shader_.kernel_name;
std::cout << std::setw(32) << node->name();

std::stringstream read_s;
for (const ArgGroup& arg_group : node->args_) {
Expand Down
66 changes: 66 additions & 0 deletions backends/vulkan/runtime/graph/ops/DispatchNode.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/DispatchNode.h>

#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/BindingUtils.h>

namespace vkcompute {

DispatchNode::DispatchNode(
ComputeGraph& graph,
const vkapi::ShaderInfo& shader,
const utils::uvec3& global_workgroup_size,
const utils::uvec3& local_workgroup_size,
const std::vector<ArgGroup>& args,
const vkapi::ParamsBindList& params,
const vkapi::SpecVarList& spec_vars,
const ResizeFunction& resize_fn,
const std::vector<ValueRef>& resize_args)
: ExecuteNode(resize_fn, resize_args, args, shader.kernel_name),
shader_(shader),
global_workgroup_size_(global_workgroup_size),
local_workgroup_size_(local_workgroup_size),
params_(params),
spec_vars_(spec_vars) {
graph.update_descriptor_counts(shader, /*execute = */ true);
}

void DispatchNode::encode(ComputeGraph* graph) {
if (!shader_) {
return;
}
api::Context* const context = graph->context();
vkapi::PipelineBarrier pipeline_barrier{};

std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock();

context->report_shader_dispatch_start(
shader_.kernel_name,
global_workgroup_size_,
local_workgroup_size_,
node_id_);

vkapi::DescriptorSet descriptor_set =
context->get_descriptor_set(shader_, local_workgroup_size_, spec_vars_);

uint32_t idx = 0;
idx = bind_values_to_descriptor_set(
graph, args_, pipeline_barrier, descriptor_set, idx);

bind_params_to_descriptor_set(params_, descriptor_set, idx);

context->register_shader_dispatch(
descriptor_set, pipeline_barrier, shader_, global_workgroup_size_);

context->report_shader_dispatch_end();
}

} // namespace vkcompute
56 changes: 56 additions & 0 deletions backends/vulkan/runtime/graph/ops/DispatchNode.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <executorch/backends/vulkan/runtime/api/api.h>

#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>

#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>

namespace vkcompute {

class ComputeGraph;

/*
* Represents a single shader execution op in a ML model.
*/
class DispatchNode final : public ExecuteNode {
friend class ComputeGraph;

public:
explicit DispatchNode(
ComputeGraph& graph,
const vkapi::ShaderInfo& shader,
const utils::uvec3& global_workgroup_size,
const utils::uvec3& local_workgroup_size,
const std::vector<ArgGroup>& args,
const vkapi::ParamsBindList& params,
const vkapi::SpecVarList& spec_vars = {},
const ResizeFunction& resize_fn = nullptr,
const std::vector<ValueRef>& resize_args = {});

~DispatchNode() override = default;

void encode(ComputeGraph* graph) override;

protected:
const vkapi::ShaderInfo shader_;
const utils::uvec3 global_workgroup_size_;
const utils::uvec3 local_workgroup_size_;
const vkapi::ParamsBindList params_;
const vkapi::SpecVarList spec_vars_;

public:
operator bool() const {
return shader_;
}
};

} // namespace vkcompute
71 changes: 6 additions & 65 deletions backends/vulkan/runtime/graph/ops/ExecuteNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,73 +8,14 @@

#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>

#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/BindingUtils.h>

namespace vkcompute {

ExecuteNode::ExecuteNode(
ComputeGraph& graph,
const vkapi::ShaderInfo& shader,
const utils::uvec3& global_workgroup_size,
const utils::uvec3& local_workgroup_size,
const std::vector<ArgGroup>& args,
const vkapi::ParamsBindList& params,
const vkapi::SpecVarList& spec_vars,
const ResizeFunction& resize_fn,
const std::vector<ValueRef>& resize_args)
: shader_(shader),
global_workgroup_size_(global_workgroup_size),
local_workgroup_size_(local_workgroup_size),
const std::vector<ValueRef>& resize_args,
const std::vector<ArgGroup>& args,
const std::string& name)
: resize_fn_(resize_fn),
resize_args_(resize_args),
args_(args),
params_(params),
spec_vars_(spec_vars),
resize_fn_(resize_fn),
resize_args_(resize_args) {
graph.update_descriptor_counts(shader, /*execute = */ true);
}

ExecuteNode::ExecuteNode(
const ResizeFunction& resize_fn,
const std::vector<ValueRef>& resize_args)
: shader_(),
global_workgroup_size_({0u, 0u, 0u}),
local_workgroup_size_({0u, 0u, 0u}),
args_(),
params_(),
spec_vars_(),
resize_fn_(resize_fn),
resize_args_(resize_args) {}

void ExecuteNode::encode(ComputeGraph* graph) {
if (!shader_) {
return;
}
api::Context* const context = graph->context();
vkapi::PipelineBarrier pipeline_barrier{};

std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock();

context->report_shader_dispatch_start(
shader_.kernel_name,
global_workgroup_size_,
local_workgroup_size_,
node_id_);

vkapi::DescriptorSet descriptor_set =
context->get_descriptor_set(shader_, local_workgroup_size_, spec_vars_);

uint32_t idx = 0;
idx = bind_values_to_descriptor_set(
graph, args_, pipeline_barrier, descriptor_set, idx);

bind_params_to_descriptor_set(params_, descriptor_set, idx);

context->register_shader_dispatch(
descriptor_set, pipeline_barrier, shader_, global_workgroup_size_);

context->report_shader_dispatch_end();
}

name_(name) {}
} // namespace vkcompute
42 changes: 15 additions & 27 deletions backends/vulkan/runtime/graph/ops/ExecuteNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ struct ArgGroup {
* encoding of the shader corresponding to the op into the command buffer of a
* ComputeGraph.
*/
class ExecuteNode final {
class ExecuteNode {
friend class ComputeGraph;

public:
Expand All @@ -48,29 +48,22 @@ class ExecuteNode final {
const std::vector<ArgGroup>&,
const std::vector<ValueRef>&)>;

explicit ExecuteNode(
ComputeGraph& graph,
const vkapi::ShaderInfo& shader,
const utils::uvec3& global_workgroup_size,
const utils::uvec3& local_workgroup_size,
const std::vector<ArgGroup>& args,
const vkapi::ParamsBindList& params,
const vkapi::SpecVarList& spec_vars = {},
const ResizeFunction& resize_fn = nullptr,
const std::vector<ValueRef>& resize_args = {});

/*
* This overload of the ExecuteNode constructor is used to register ops which
* This overload of the DispatchNode constructor is used to register ops which
* update a tensor view. No shader is dispatched, but the node still needs to
* update the view's sizes and strides after a resize.
*/
explicit ExecuteNode(
const ResizeFunction& resize_fn = nullptr,
const std::vector<ValueRef>& resize_args = {});
const std::vector<ValueRef>& resize_args = {},
const std::vector<ArgGroup>& args = {},
const std::string& name = "Graph Node");

~ExecuteNode() = default;
virtual ~ExecuteNode() = default;

void encode(ComputeGraph* graph);
virtual void encode(ComputeGraph* graph) {
(void)graph;
}

inline void trigger_resize(ComputeGraph* graph) {
if (resize_fn_ != nullptr) {
Expand All @@ -82,21 +75,16 @@ class ExecuteNode final {
node_id_ = node_id;
}

inline const std::string& name() const {
return name_;
}

protected:
uint32_t node_id_;
const vkapi::ShaderInfo shader_;
const utils::uvec3 global_workgroup_size_;
const utils::uvec3 local_workgroup_size_;
const std::vector<ArgGroup> args_;
const vkapi::ParamsBindList params_;
const vkapi::SpecVarList spec_vars_;
const ResizeFunction resize_fn_;
const std::vector<ValueRef> resize_args_;

public:
operator bool() const {
return shader_;
}
const std::vector<ArgGroup> args_;
const std::string name_;
};

} // namespace vkcompute
1 change: 1 addition & 0 deletions backends/vulkan/runtime/graph/ops/OperatorRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#pragma once

#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
#include <executorch/backends/vulkan/runtime/graph/ops/DispatchNode.h>

#include <functional>
#include <unordered_map>
Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/ops/PrepackNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
#include <executorch/backends/vulkan/runtime/graph/ops/DispatchNode.h>

#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/ops/impl/Arange.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ void add_arange_node(
kernel_name.reserve(kShaderNameReserve);
add_dtype_suffix(kernel_name, *t_out);

graph.execute_nodes().emplace_back(new ExecuteNode(
graph.execute_nodes().emplace_back(new DispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
graph.create_global_wg_size(out),
Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ void add_native_batch_norm_node(
int32_t num_texel_per_batch =
utils::div_up_4((dim_at<kChannel4D>(t_in->sizes())));

graph.execute_nodes().emplace_back(new ExecuteNode(
graph.execute_nodes().emplace_back(new DispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
graph.create_global_wg_size(out_ref),
Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ void add_binary_op_node(
kernel_name += op_name;
add_dtype_suffix(kernel_name, *t_out);

graph.execute_nodes().emplace_back(new ExecuteNode(
graph.execute_nodes().emplace_back(new DispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
graph.create_global_wg_size(out),
Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/ops/impl/Clone.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ void add_clone_node(
std::string kernel_name = "clone";
add_dtype_suffix(kernel_name, *t_out);

graph.execute_nodes().emplace_back(new ExecuteNode(
graph.execute_nodes().emplace_back(new DispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
graph.create_global_wg_size(out),
Expand Down
4 changes: 2 additions & 2 deletions backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ void add_conv2d_node(
vkapi::ShaderInfo shader = get_conv2d_shader(
graph, *t_out, /*prepack_weights = */ false, method, weight, clamp_out);

graph.execute_nodes().emplace_back(new ExecuteNode(
graph.execute_nodes().emplace_back(new DispatchNode(
graph,
shader,
create_conv2d_global_wg_size(graph, method, out),
Expand Down Expand Up @@ -464,7 +464,7 @@ void add_conv1d_node(

add_dtype_suffix(kernel_name, *t_out);

graph.execute_nodes().emplace_back(new ExecuteNode(
graph.execute_nodes().emplace_back(new DispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
global_size,
Expand Down
4 changes: 2 additions & 2 deletions backends/vulkan/runtime/graph/ops/impl/Copy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ void add_copy_offset_node(

auto shader = VK_KERNEL_FROM_STR(kernel_name);

graph.execute_nodes().emplace_back(new ExecuteNode(
graph.execute_nodes().emplace_back(new DispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
graph.create_global_wg_size(out),
Expand Down Expand Up @@ -155,7 +155,7 @@ void add_copy_channel_offset_node(

auto shader = VK_KERNEL_FROM_STR(kernel_name);

graph.execute_nodes().emplace_back(new ExecuteNode(
graph.execute_nodes().emplace_back(new DispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
global_size,
Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/ops/impl/Embedding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ void add_embedding_node(
kernel_name.reserve(kShaderNameReserve);
add_dtype_suffix(kernel_name, *t_out);

graph.execute_nodes().emplace_back(new ExecuteNode(
graph.execute_nodes().emplace_back(new DispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
graph.create_global_wg_size(out),
Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/ops/impl/Flip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ void add_flip_node(
kernel_name.reserve(kShaderNameReserve);
add_dtype_suffix(kernel_name, *t_out);

graph.execute_nodes().emplace_back(new ExecuteNode(
graph.execute_nodes().emplace_back(new DispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
graph.create_global_wg_size(out),
Expand Down
Loading
Loading