Skip to content

Commit 0f65aed

Browse files
committed
[ET-VK] Enable additional specialization constants in compute shaders
## Context

Building on top of the previous changeset in the stack, this changeset modifies the shader dispatch APIs to accept additional specialization constants for a shader.

Differential Revision: [D56225042](https://our.internmc.facebook.com/intern/diff/D56225042/)

ghstack-source-id: 222806317

Pull Request resolved: #3079
1 parent 478c8f0 commit 0f65aed

File tree

6 files changed

+96
-5
lines changed

6 files changed

+96
-5
lines changed

backends/vulkan/runtime/api/Context.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ class Context final {
205205
PipelineBarrier&,
206206
const utils::uvec3&,
207207
const utils::uvec3&,
208+
const SpecVarList&,
208209
VkFence fence_handle,
209210
Arguments&&...);
210211

@@ -494,6 +495,7 @@ inline bool Context::submit_compute_job(
494495
PipelineBarrier& pipeline_barrier,
495496
const utils::uvec3& global_work_group,
496497
const utils::uvec3& local_work_group_size,
498+
const SpecVarList& specialization,
497499
VkFence fence_handle,
498500
Arguments&&... arguments) {
499501
// If any of the provided arguments does not have memory associated with it,
@@ -537,7 +539,7 @@ inline bool Context::submit_compute_job(
537539

538540
// Factor out template parameter independent code to minimize code bloat.
539541
DescriptorSet descriptor_set =
540-
get_descriptor_set(shader, local_work_group_size);
542+
get_descriptor_set(shader, local_work_group_size, specialization);
541543

542544
detail::bind(
543545
descriptor_set,

backends/vulkan/runtime/graph/ops/ExecuteNode.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,16 @@ ExecuteNode::ExecuteNode(
2222
const std::vector<ArgGroup>& args,
2323
const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params,
2424
const ResizeFunction& resize_fn,
25-
const std::vector<ValueRef>& resize_args)
25+
const std::vector<ValueRef>& resize_args,
26+
const api::SpecVarList& spec_vars)
2627
: shader_(shader),
2728
global_workgroup_size_(global_workgroup_size),
2829
local_workgroup_size_(local_workgroup_size),
2930
args_(args),
3031
params_(params),
3132
resize_fn_(resize_fn),
32-
resize_args_(resize_args) {
33+
resize_args_(resize_args),
34+
spec_vars_(spec_vars) {
3335
graph.update_descriptor_counts(shader, /*execute = */ true);
3436
}
3537

@@ -40,7 +42,7 @@ void ExecuteNode::encode(ComputeGraph* graph) {
4042
std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock();
4143

4244
api::DescriptorSet descriptor_set =
43-
context->get_descriptor_set(shader_, local_workgroup_size_);
45+
context->get_descriptor_set(shader_, local_workgroup_size_, spec_vars_);
4446

4547
uint32_t idx = 0;
4648
idx = bind_values_to_descriptor_set(

backends/vulkan/runtime/graph/ops/ExecuteNode.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ class ExecuteNode final {
5656
const std::vector<ArgGroup>& args,
5757
const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params,
5858
const ResizeFunction& resize_fn = nullptr,
59-
const std::vector<ValueRef>& resize_args = {});
59+
const std::vector<ValueRef>& resize_args = {},
60+
const api::SpecVarList& spec_vars = {});
6061

6162
~ExecuteNode() = default;
6263

@@ -76,6 +77,7 @@ class ExecuteNode final {
7677
std::vector<std::shared_ptr<api::UniformParamsBuffer>> params_;
7778
const ResizeFunction resize_fn_;
7879
const std::vector<ValueRef> resize_args_;
80+
const api::SpecVarList spec_vars_;
7981
};
8082

8183
} // namespace vkcompute
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
$PRECISION = "highp"
12+
$DTYPE = "float"
13+
14+
#define PRECISION ${PRECISION}
15+
16+
#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}
17+
18+
#include "indexing_utils.h"
19+
20+
layout(std430) buffer;
21+
22+
// Binding 0: write-only storage buffer that this shader fills, addressed as
// an array of VEC4_T (four elements per entry). Despite the instance name
// `buffer_in`, it is only ever written by this shader.
layout(set = 0, binding = 0) buffer PRECISION restrict writeonly Buffer {
23+
VEC4_T data[];
24+
}
25+
buffer_in;
26+
27+
// Binding 1: uniform block carrying `len`, which main() compares against
// the scalar index of the first element each invocation would write.
layout(set = 0, binding = 1) uniform PRECISION restrict Params {
28+
int len;
29+
}
30+
params;
31+
32+
33+
34+
// Work-group dimensions are taken from specialization constants 0, 1 and 2.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
35+
36+
// Additional specialization constants. The defaults (scale = 1, offset = 0)
// apply when the pipeline does not override ids 3 and 4.
// NOTE(review): presumably the SpecVarList passed at dispatch is mapped to
// ids starting at 3, after the three work-group size ids; confirm against
// the pipeline creation code.
layout(constant_id = 3) const float scale = 1;
37+
layout(constant_id = 4) const float offset = 0;
38+
39+
// Each invocation writes one VEC4_T entry of the output buffer. Entry i
// covers scalar indices 4*i .. 4*i+3, and each scalar index j is written as
// scale * j + offset.
void main() {
40+
// Only the x component of the global invocation id is used.
const int i = ivec3(gl_GlobalInvocationID).x;
41+
42+
// Scalar index of the first of the four elements in entry i.
const int base = 4 * i;
43+
// Skip invocations whose first element lies at or beyond params.len.
// NOTE(review): the bound is checked on `base` only, so when len is not a
// multiple of 4 the last entry is still written in full; this assumes the
// buffer is padded to a multiple of 4 elements. TODO confirm.
if (base < params.len) {
44+
buffer_in.data[i] = scale * (VEC4_T(base) + VEC4_T(0, 1, 2, 3)) + offset;
45+
}
46+
}

backends/vulkan/test/utils/test_utils.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ void record_nchw_to_image_op(
2727
pipeline_barrier,
2828
v_dst.virtual_extents(),
2929
adaptive_work_group_size(v_dst.virtual_extents()),
30+
{},
3031
VK_NULL_HANDLE,
3132
v_dst.image(
3233
pipeline_barrier,
@@ -47,6 +48,7 @@ void record_image_to_nchw_op(
4748
pipeline_barrier,
4849
v_src.virtual_extents(),
4950
adaptive_work_group_size(v_src.virtual_extents()),
51+
{},
5052
VK_NULL_HANDLE,
5153
v_src.image(pipeline_barrier, api::PipelineStage::COMPUTE),
5254
dst_buffer,
@@ -83,6 +85,7 @@ void record_conv2d_prepack_weights_op(
8385
pipeline_barrier,
8486
v_dst.virtual_extents(),
8587
adaptive_work_group_size(v_dst.virtual_extents()),
88+
{},
8689
VK_NULL_HANDLE,
8790
v_dst.image(
8891
pipeline_barrier,
@@ -109,6 +112,7 @@ void record_binary_op(
109112
pipeline_barrier,
110113
v_dst.virtual_extents(),
111114
adaptive_work_group_size(v_dst.virtual_extents()),
115+
{},
112116
VK_NULL_HANDLE,
113117
v_dst.image(
114118
pipeline_barrier,

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,38 @@ TEST_F(VulkanComputeAPITest, spec_var_classes_test) {
9696
ASSERT_TRUE(*(reinterpret_cast<const float*>(copy_data + 6)) == 5.5f);
9797
}
9898

99+
// Dispatches the fill_buffer kernel with two extra specialization constants
// (scale and offset), then reads the buffer back and checks that element i
// equals scale * i + offset.
TEST_F(VulkanComputeAPITest, spec_var_shader_test) {
100+
size_t len = 16;
101+
api::StorageBuffer buffer(api::context(), api::kFloat, len);
102+
103+
// Values handed to the shader as specialization constants, and reused below
// to compute the expected buffer contents.
float scale = 3.0f;
104+
float offset = 1.5f;
105+
106+
// params and pipeline_barrier are local to this scope and go out of scope
// before submit_to_gpu() is called.
{
107+
api::UniformParamsBuffer params(api::context(), int32_t(len));
108+
// Number of 4-element groups needed to cover len elements (rounded up).
uint32_t len_div4 = api::utils::div_up(uint32_t(len), uint32_t(4));
109+
api::PipelineBarrier pipeline_barrier{};
110+
// Arguments after the barrier are, in order: the global work group, the
// local work group size, the specialization constant list, and the fence
// handle, followed by the resources bound to the shader.
api::context()->submit_compute_job(
111+
VK_KERNEL(fill_buffer),
112+
pipeline_barrier,
113+
{64, 1, 1},
114+
{len_div4, 1, 1},
115+
// NOTE(review): presumably SV(...) wraps a value as one SpecVarList entry;
// confirm against the SpecVar definition.
{SV(scale), SV(offset)},
116+
VK_NULL_HANDLE,
117+
buffer.buffer(),
118+
params.buffer());
119+
}
120+
121+
submit_to_gpu();
122+
123+
std::vector<float> data(len);
124+
copy_staging_to_ptr(buffer, data.data(), buffer.nbytes());
125+
126+
// Every element should hold its own index scaled and shifted by the
// specialization constants.
for (size_t i = 0; i < len; ++i) {
127+
CHECK_VALUE(data, i, scale * i + offset);
128+
}
129+
}
130+
99131
TEST_F(VulkanComputeAPITest, update_params_between_submit) {
100132
api::context()->set_cmd(/*reusable = */ true);
101133
std::vector<int64_t> sizes = {4, 4, 2};
@@ -125,6 +157,7 @@ TEST_F(VulkanComputeAPITest, update_params_between_submit) {
125157
pipeline_barrier,
126158
{4, 4, 4},
127159
{4, 4, 4},
160+
{},
128161
VK_NULL_HANDLE,
129162
a.image(
130163
pipeline_barrier,
@@ -184,6 +217,7 @@ void test_storage_buffer_type(const size_t len) {
184217
pipeline_barrier,
185218
{64, 1, 1},
186219
{len_div4, 1, 1},
220+
{},
187221
VK_NULL_HANDLE,
188222
buffer.buffer(),
189223
params.buffer());
@@ -879,6 +913,7 @@ void run_from_gpu_test(
879913
pipeline_barrier,
880914
vten.virtual_extents(),
881915
{4, 4, 4},
916+
{},
882917
VK_NULL_HANDLE,
883918
vten.image(
884919
pipeline_barrier,

0 commit comments

Comments
 (0)