Skip to content

Commit f95e6cb

Browse files
committed
[ET-VK][ez] Updates to DynamicDispatchNode
Pull Request resolved: #11254 ## Changes For `DynamicDispatchNode`: * Pass in global work group size to the local work group size determination function * Add additional constructor for which the shader is not dynamic * During `encode`, check that pick functions are not `nullptr` ## Motivation Oftentimes it is useful to know what the global work group size is when determining what the local group group size should be. ## Performance Impact None. ghstack-source-id: 287711100 @exported-using-ghexport Differential Revision: [D75686047](https://our.internmc.facebook.com/intern/diff/D75686047/)
1 parent 167c063 commit f95e6cb

File tree

3 files changed

+76
-10
lines changed

3 files changed

+76
-10
lines changed

backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ DynamicDispatchNode::DynamicDispatchNode(
2525
const ResizeFunction& resize_fn)
2626
: DispatchNode(
2727
graph,
28-
pick_shader_fn(&graph, args, resize_args),
29-
pick_global_wg_fn(&graph, args, resize_args),
30-
pick_local_wg_fn(&graph, args, resize_args),
28+
vkapi::ShaderInfo(),
29+
{1u, 1u, 1u},
30+
{1u, 1u, 1u},
3131
args,
3232
params,
3333
push_constants,
@@ -36,13 +36,57 @@ DynamicDispatchNode::DynamicDispatchNode(
3636
resize_fn),
3737
pick_shader_fn_(pick_shader_fn),
3838
pick_global_wg_fn_(pick_global_wg_fn),
39+
pick_local_wg_fn_(pick_local_wg_fn) {
40+
shader_ = pick_shader_fn(&graph, args, resize_args);
41+
global_workgroup_size_ =
42+
pick_global_wg_fn(&graph, shader_, args, resize_args);
43+
local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn(
44+
&graph, shader_, global_workgroup_size_, args, resize_args));
45+
}
46+
47+
DynamicDispatchNode::DynamicDispatchNode(
48+
ComputeGraph& graph,
49+
const vkapi::ShaderInfo& shader,
50+
const PickGlobalFn& pick_global_wg_fn,
51+
const PickLocalFn& pick_local_wg_fn,
52+
const std::vector<ArgGroup>& args,
53+
const vkapi::ParamsBindList& params,
54+
const std::vector<PushConstantDataInfo>& push_constants,
55+
const vkapi::SpecVarList& spec_vars,
56+
const std::vector<ValueRef>& resize_args,
57+
const ResizeFunction& resize_fn)
58+
: DispatchNode(
59+
graph,
60+
shader,
61+
pick_global_wg_fn(&graph, shader, args, resize_args),
62+
pick_local_wg_fn(
63+
&graph,
64+
shader,
65+
pick_global_wg_fn(&graph, shader, args, resize_args),
66+
args,
67+
resize_args),
68+
args,
69+
params,
70+
push_constants,
71+
spec_vars,
72+
resize_args,
73+
resize_fn),
74+
pick_shader_fn_{nullptr},
75+
pick_global_wg_fn_(pick_global_wg_fn),
3976
pick_local_wg_fn_(pick_local_wg_fn) {}
4077

4178
void DynamicDispatchNode::encode(ComputeGraph* graph) {
42-
shader_ = pick_shader_fn_(graph, args_, resize_args_);
43-
global_workgroup_size_ = pick_global_wg_fn_(graph, args_, resize_args_);
44-
local_workgroup_size_ =
45-
utils::WorkgroupSize(pick_local_wg_fn_(graph, args_, resize_args_));
79+
if (pick_shader_fn_) {
80+
shader_ = pick_shader_fn_(graph, args_, resize_args_);
81+
}
82+
if (pick_global_wg_fn_) {
83+
global_workgroup_size_ =
84+
pick_global_wg_fn_(graph, shader_, args_, resize_args_);
85+
}
86+
if (pick_local_wg_fn_) {
87+
local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn_(
88+
graph, shader_, global_workgroup_size_, args_, resize_args_));
89+
}
4690
DispatchNode::encode(graph);
4791
}
4892

backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,13 @@ class DynamicDispatchNode final : public DispatchNode {
3232
const std::vector<ValueRef>&)>;
3333
using PickGlobalFn = const std::function<utils::uvec3(
3434
ComputeGraph*,
35+
const vkapi::ShaderInfo& shader,
3536
const std::vector<ArgGroup>&,
3637
const std::vector<ValueRef>&)>;
3738
using PickLocalFn = const std::function<utils::uvec3(
3839
ComputeGraph*,
40+
const vkapi::ShaderInfo& shader,
41+
const utils::uvec3& global_workgroup_size,
3942
const std::vector<ArgGroup>&,
4043
const std::vector<ValueRef>&)>;
4144

@@ -51,6 +54,18 @@ class DynamicDispatchNode final : public DispatchNode {
5154
const std::vector<ValueRef>& resize_args,
5255
const ResizeFunction& resize_fn = nullptr);
5356

57+
explicit DynamicDispatchNode(
58+
ComputeGraph& graph,
59+
const vkapi::ShaderInfo& shader,
60+
const PickGlobalFn& pick_global_wg_fn,
61+
const PickLocalFn& pick_local_wg_fn,
62+
const std::vector<ArgGroup>& args,
63+
const vkapi::ParamsBindList& params,
64+
const std::vector<PushConstantDataInfo>& push_constants,
65+
const vkapi::SpecVarList& spec_vars,
66+
const std::vector<ValueRef>& resize_args,
67+
const ResizeFunction& resize_fn = nullptr);
68+
5469
~DynamicDispatchNode() override = default;
5570

5671
void encode(ComputeGraph* graph) override;

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <gtest/gtest.h>
1010

1111
#include <bitset>
12+
#include <iomanip>
1213
#include <utility>
1314
#include <vector>
1415

@@ -3314,17 +3315,23 @@ vkapi::ShaderInfo pick_dynamic_dispatch_shader(
33143315

33153316
utils::uvec3 pick_dynamic_dispatch_global_wg_size(
33163317
ComputeGraph* graph,
3318+
const vkapi::ShaderInfo& shader,
33173319
const std::vector<ArgGroup>& args,
3318-
const std::vector<ValueRef>& additional_args) {
3320+
const std::vector<ValueRef>& resize_args) {
3321+
(void)shader;
33193322
const ValueRef out = args[0].refs[0];
3320-
33213323
return graph->logical_limits_of(out);
33223324
}
33233325

33243326
utils::uvec3 pick_dynamic_dispatch_local_wg_size(
33253327
ComputeGraph* graph,
3328+
const vkapi::ShaderInfo& shader,
3329+
const utils::uvec3& global_workgroup_size,
33263330
const std::vector<ArgGroup>& args,
3327-
const std::vector<ValueRef>& additional_args) {
3331+
const std::vector<ValueRef>& resize_args) {
3332+
(void)graph;
3333+
(void)shader;
3334+
(void)global_workgroup_size;
33283335
return {64, 1, 1};
33293336
}
33303337

0 commit comments

Comments
 (0)