Skip to content

Commit 10bd491

Browse files
[ET-VK] Add PushConstantDataInfo and vector to hold push constants data in DispatchNode.
Pull Request resolved: #7223 This diff adds a new class called `PushConstantDataInfo` to the `DispatchNode` class in the Vulkan backend for Executorch. This class represents a push constant data entry, which can either be a shared pointer to a tensor's uniform data with an attribute or data with a maximum size of 16 bytes. The `write` method is also added to this class, which writes the data to a destination buffer. ghstack-source-id: 257899587 @exported-using-ghexport Differential Revision: [D66796049](https://our.internmc.facebook.com/intern/diff/D66796049/) Co-authored-by: Vivek Trivedi <[email protected]>
1 parent 6158674 commit 10bd491

File tree

2 files changed

+83
-4
lines changed

2 files changed

+83
-4
lines changed

backends/vulkan/runtime/graph/ops/DispatchNode.cpp

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,22 @@
1414

1515
namespace vkcompute {
1616

17+
uint32_t PushConstantDataInfo::write(
18+
void* dst,
19+
const uint32_t dst_offset,
20+
const uint32_t max_dst_size) const {
21+
if (tensorUniformData != nullptr) {
22+
return tensorUniformData->write_attribute(
23+
dst, dst_offset, max_dst_size, payload_.attr);
24+
}
25+
26+
VK_CHECK_COND(
27+
(dst_offset + payload_.dataSize) <= max_dst_size,
28+
"Attempting to write push constant data outside data boundary.");
29+
memcpy((uint8_t*)dst + dst_offset, payload_.data, payload_.dataSize);
30+
return payload_.dataSize;
31+
}
32+
1733
DispatchNode::DispatchNode(
1834
ComputeGraph& graph,
1935
const vkapi::ShaderInfo& shader,
@@ -23,13 +39,15 @@ DispatchNode::DispatchNode(
2339
const vkapi::ParamsBindList& params,
2440
const vkapi::SpecVarList& spec_vars,
2541
const ResizeFunction& resize_fn,
26-
const std::vector<ValueRef>& resize_args)
42+
const std::vector<ValueRef>& resize_args,
43+
const std::vector<PushConstantDataInfo>& push_constants)
2744
: ExecuteNode(resize_fn, resize_args, args, shader.kernel_name),
2845
shader_(shader),
2946
global_workgroup_size_(global_workgroup_size),
3047
local_workgroup_size_(local_workgroup_size),
3148
params_(params),
32-
spec_vars_(spec_vars) {
49+
spec_vars_(spec_vars),
50+
push_constants_(push_constants) {
3351
graph.update_descriptor_counts(shader, /*execute = */ true);
3452
}
3553

@@ -57,8 +75,22 @@ void DispatchNode::encode(ComputeGraph* graph) {
5775

5876
bind_params_to_descriptor_set(params_, descriptor_set, idx);
5977

78+
std::array<uint8_t, kMaxPushConstantSize> push_constants_data;
79+
uint32_t push_constants_offset = 0;
80+
81+
for (const auto& push_constant : push_constants_) {
82+
push_constants_offset += push_constant.write(
83+
push_constants_data.data(),
84+
push_constants_offset,
85+
kMaxPushConstantSize);
86+
}
6087
context->register_shader_dispatch(
61-
descriptor_set, pipeline_barrier, shader_, global_workgroup_size_);
88+
descriptor_set,
89+
pipeline_barrier,
90+
shader_,
91+
global_workgroup_size_,
92+
push_constants_data.data(),
93+
push_constants_offset);
6294

6395
context->report_shader_dispatch_end();
6496
}

backends/vulkan/runtime/graph/ops/DispatchNode.h

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,51 @@ namespace vkcompute {
1818

1919
class ComputeGraph;
2020

21+
constexpr uint32_t kMaxPushConstantSize = 128;
22+
/*
23+
* Represents a push constant data entry
24+
* Which is either shared pointer to a tensor's uniform data with an attribute
25+
* Or data with a maximum size of 16 bytes
26+
*/
27+
class PushConstantDataInfo {
28+
std::shared_ptr<api::vTensor::UniformData> tensorUniformData;
29+
union Payload {
30+
struct {
31+
api::vTensor::Attribute attr;
32+
};
33+
struct {
34+
uint8_t data[16];
35+
uint32_t dataSize;
36+
};
37+
};
38+
39+
Payload payload_;
40+
41+
public:
42+
explicit PushConstantDataInfo(
43+
const std::shared_ptr<api::vTensor::UniformData>& tensorUniformData,
44+
api::vTensor::Attribute attr)
45+
: tensorUniformData(tensorUniformData) {
46+
payload_.attr = attr;
47+
}
48+
49+
explicit PushConstantDataInfo(const void* data, uint32_t dataLen)
50+
: tensorUniformData(nullptr) {
51+
VK_CHECK_COND(
52+
dataLen <= 16, "Single push constant data size must be <= 16 bytes");
53+
payload_.dataSize = dataLen;
54+
memcpy(payload_.data, data, payload_.dataSize);
55+
}
56+
57+
/*
58+
* Function writes push constant data to the destination buffer
59+
*/
60+
uint32_t write(
61+
void* dst,
62+
const uint32_t dst_offset,
63+
const uint32_t max_dst_size) const;
64+
};
65+
2166
/*
2267
* Represents a single shader execution op in a ML model.
2368
*/
@@ -34,7 +79,8 @@ class DispatchNode final : public ExecuteNode {
3479
const vkapi::ParamsBindList& params,
3580
const vkapi::SpecVarList& spec_vars = {},
3681
const ResizeFunction& resize_fn = nullptr,
37-
const std::vector<ValueRef>& resize_args = {});
82+
const std::vector<ValueRef>& resize_args = {},
83+
const std::vector<PushConstantDataInfo>& push_constants = {});
3884

3985
~DispatchNode() override = default;
4086

@@ -46,6 +92,7 @@ class DispatchNode final : public ExecuteNode {
4692
const utils::uvec3 local_workgroup_size_;
4793
const vkapi::ParamsBindList params_;
4894
const vkapi::SpecVarList spec_vars_;
95+
const std::vector<PushConstantDataInfo> push_constants_;
4996

5097
public:
5198
operator bool() const {

0 commit comments

Comments
 (0)