Skip to content

Commit 7ce2faa

Browse files
committed
Update on "[ET-VK] Introduce DynamicDispatchNode"
## Context The `DynamicDispatchNode` class in introduced in this diff to allow for shader re-selection upon input resize. See the previous diff in the stack for more context on why this functionality is needed. Differential Revision: [D75013780](https://our.internmc.facebook.com/intern/diff/D75013780/) [ghstack-poisoned]
2 parents e4fd85d + 6b631b0 commit 7ce2faa

File tree

4 files changed

+9
-13
lines changed

4 files changed

+9
-13
lines changed

backends/vulkan/runtime/graph/ops/DispatchNode.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212

1313
#include <executorch/backends/vulkan/runtime/graph/ops/utils/BindingUtils.h>
1414

15-
#include <iostream>
16-
1715
namespace vkcompute {
1816

1917
DispatchNode::DispatchNode(
@@ -41,8 +39,6 @@ void DispatchNode::encode(ComputeGraph* graph) {
4139
if (!shader_) {
4240
return;
4341
}
44-
std::cout << "dynamically dispatching... " << shader_.kernel_name
45-
<< std::endl;
4642
api::Context* const context = graph->context();
4743
vkapi::PipelineBarrier pipeline_barrier{};
4844

backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@ DynamicDispatchNode::DynamicDispatchNode(
3434
pick_local_wg_fn(&graph, args, resize_args),
3535
args,
3636
params,
37+
push_constants,
3738
spec_vars,
38-
resize_fn,
3939
resize_args,
40-
push_constants),
40+
resize_fn),
4141
pick_shader_fn_(pick_shader_fn),
4242
pick_global_wg_fn_(pick_global_wg_fn),
4343
pick_local_wg_fn_(pick_local_wg_fn) {}

backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,15 @@ class DynamicDispatchNode final : public DispatchNode {
2929
using PickShaderFn = const std::function<vkapi::ShaderInfo(
3030
ComputeGraph*,
3131
const std::vector<ArgGroup>&,
32-
const std::vector<ValueRef>)>;
32+
const std::vector<ValueRef>&)>;
3333
using PickGlobalFn = const std::function<utils::uvec3(
3434
ComputeGraph*,
3535
const std::vector<ArgGroup>&,
36-
const std::vector<ValueRef>)>;
36+
const std::vector<ValueRef>&)>;
3737
using PickLocalFn = const std::function<utils::uvec3(
3838
ComputeGraph*,
3939
const std::vector<ArgGroup>&,
40-
const std::vector<ValueRef>)>;
40+
const std::vector<ValueRef>&)>;
4141

4242
explicit DynamicDispatchNode(
4343
ComputeGraph& graph,

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3301,7 +3301,7 @@ TEST(VulkanComputeGraphOpsTest, test_to_copy) {
33013301
vkapi::ShaderInfo pick_dynamic_dispatch_shader(
33023302
ComputeGraph* graph,
33033303
const std::vector<ArgGroup>& args,
3304-
const std::vector<ValueRef> additional_args) {
3304+
const std::vector<ValueRef>& additional_args) {
33053305
const ValueRef mat1 = args[1].refs[0];
33063306

33073307
std::string kernel_name = "dynamic_dispatch_test";
@@ -3316,7 +3316,7 @@ vkapi::ShaderInfo pick_dynamic_dispatch_shader(
33163316
utils::uvec3 pick_dynamic_dispatch_global_wg_size(
33173317
ComputeGraph* graph,
33183318
const std::vector<ArgGroup>& args,
3319-
const std::vector<ValueRef> additional_args) {
3319+
const std::vector<ValueRef>& additional_args) {
33203320
const ValueRef out = args[0].refs[0];
33213321

33223322
return graph->logical_limits_of(out);
@@ -3325,14 +3325,14 @@ utils::uvec3 pick_dynamic_dispatch_global_wg_size(
33253325
utils::uvec3 pick_dynamic_dispatch_local_wg_size(
33263326
ComputeGraph* graph,
33273327
const std::vector<ArgGroup>& args,
3328-
const std::vector<ValueRef> additional_args) {
3328+
const std::vector<ValueRef>& additional_args) {
33293329
return {64, 1, 1};
33303330
}
33313331

33323332
void resize_dynamic_dispatch_node(
33333333
ComputeGraph* graph,
33343334
const std::vector<ArgGroup>& args,
3335-
const std::vector<ValueRef> additional_args) {
3335+
const std::vector<ValueRef>& additional_args) {
33363336
const ValueRef out = args[0].refs[0];
33373337
const ValueRef mat1 = args[1].refs[0];
33383338

0 commit comments

Comments
 (0)