Skip to content

Commit 78cb141

Browse files
SS-JIA authored and facebook-github-bot committed
Enable additional specialization constants in compute shaders (#3079)
Summary:
Pull Request resolved: #3079

## Context

Building on top of the previous changeset in the stack, this changeset modifies shader dispatch APIs to accept additional specialization constants for a shader.

ghstack-source-id: 222903463

Reviewed By: copyrightly, jorgep31415

Differential Revision: D56225042

fbshipit-source-id: 154c51f927116e4a658f224794ec354151398a8a
1 parent 0815c2b commit 78cb141

File tree

6 files changed

+105
-6
lines changed

6 files changed

+105
-6
lines changed

backends/vulkan/runtime/api/Context.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ class Context final {
205205
PipelineBarrier&,
206206
const utils::uvec3&,
207207
const utils::uvec3&,
208+
const SpecVarList&,
208209
VkFence fence_handle,
209210
Arguments&&...);
210211

@@ -494,6 +495,7 @@ inline bool Context::submit_compute_job(
494495
PipelineBarrier& pipeline_barrier,
495496
const utils::uvec3& global_work_group,
496497
const utils::uvec3& local_work_group_size,
498+
const SpecVarList& specialization_constants,
497499
VkFence fence_handle,
498500
Arguments&&... arguments) {
499501
// If any of the provided arguments does not have memory associated with it,
@@ -536,8 +538,8 @@ inline bool Context::submit_compute_job(
536538
#endif /* USE_VULKAN_GPU_DIAGNOSTICS */
537539

538540
// Factor out template parameter independent code to minimize code bloat.
539-
DescriptorSet descriptor_set =
540-
get_descriptor_set(shader, local_work_group_size);
541+
DescriptorSet descriptor_set = get_descriptor_set(
542+
shader, local_work_group_size, specialization_constants);
541543

542544
detail::bind(
543545
descriptor_set,

backends/vulkan/runtime/graph/ops/ExecuteNode.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,16 @@ ExecuteNode::ExecuteNode(
2222
const std::vector<ArgGroup>& args,
2323
const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params,
2424
const ResizeFunction& resize_fn,
25-
const std::vector<ValueRef>& resize_args)
25+
const std::vector<ValueRef>& resize_args,
26+
const api::SpecVarList& spec_vars)
2627
: shader_(shader),
2728
global_workgroup_size_(global_workgroup_size),
2829
local_workgroup_size_(local_workgroup_size),
2930
args_(args),
3031
params_(params),
3132
resize_fn_(resize_fn),
32-
resize_args_(resize_args) {
33+
resize_args_(resize_args),
34+
spec_vars_(spec_vars) {
3335
graph.update_descriptor_counts(shader, /*execute = */ true);
3436
}
3537

@@ -40,7 +42,7 @@ void ExecuteNode::encode(ComputeGraph* graph) {
4042
std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock();
4143

4244
api::DescriptorSet descriptor_set =
43-
context->get_descriptor_set(shader_, local_workgroup_size_);
45+
context->get_descriptor_set(shader_, local_workgroup_size_, spec_vars_);
4446

4547
uint32_t idx = 0;
4648
idx = bind_values_to_descriptor_set(

backends/vulkan/runtime/graph/ops/ExecuteNode.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ class ExecuteNode final {
5656
const std::vector<ArgGroup>& args,
5757
const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params,
5858
const ResizeFunction& resize_fn = nullptr,
59-
const std::vector<ValueRef>& resize_args = {});
59+
const std::vector<ValueRef>& resize_args = {},
60+
const api::SpecVarList& spec_vars = {});
6061

6162
~ExecuteNode() = default;
6263

@@ -76,6 +77,7 @@ class ExecuteNode final {
7677
std::vector<std::shared_ptr<api::UniformParamsBuffer>> params_;
7778
const ResizeFunction resize_fn_;
7879
const std::vector<ValueRef> resize_args_;
80+
const api::SpecVarList spec_vars_;
7981
};
8082

8183
} // namespace vkcompute
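(File path missing from this extraction — presumably the new GLSL source for the `fill_buffer` test kernel; confirm against the commit.)
Lines changed: 46 additions & 0 deletions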
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
$PRECISION = "highp"
12+
$DTYPE = "float"
13+
14+
#define PRECISION ${PRECISION}
15+
16+
#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}
17+
18+
#include "indexing_utils.h"
19+
20+
layout(std430) buffer;
21+
22+
layout(set = 0, binding = 0) buffer PRECISION restrict writeonly Buffer {
23+
VEC4_T data[];
24+
}
25+
buffer_in;
26+
27+
layout(set = 0, binding = 1) uniform PRECISION restrict Params {
28+
int len;
29+
}
30+
params;
31+
32+
33+
34+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
35+
36+
layout(constant_id = 3) const float scale = 1;
37+
layout(constant_id = 4) const float offset = 0;
38+
39+
void main() {
40+
const int i = ivec3(gl_GlobalInvocationID).x;
41+
42+
const int base = 4 * i;
43+
if (base < params.len) {
44+
buffer_in.data[i] = scale * (VEC4_T(base) + VEC4_T(0, 1, 2, 3)) + offset;
45+
}
46+
}

backends/vulkan/test/utils/test_utils.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,14 @@ void record_nchw_to_image_op(
2121
api::VulkanBuffer& src_buffer,
2222
vTensor& v_dst) {
2323
api::PipelineBarrier pipeline_barrier{};
24+
api::SpecVarList specialization_constants = {};
2425

2526
context->submit_compute_job(
2627
get_nchw_to_image_shader(v_dst),
2728
pipeline_barrier,
2829
v_dst.virtual_extents(),
2930
adaptive_work_group_size(v_dst.virtual_extents()),
31+
specialization_constants,
3032
VK_NULL_HANDLE,
3133
v_dst.image(
3234
pipeline_barrier,
@@ -42,11 +44,14 @@ void record_image_to_nchw_op(
4244
vTensor& v_src,
4345
api::VulkanBuffer& dst_buffer) {
4446
api::PipelineBarrier pipeline_barrier{};
47+
api::SpecVarList specialization_constants = {};
48+
4549
context->submit_compute_job(
4650
get_image_to_nchw_shader(v_src),
4751
pipeline_barrier,
4852
v_src.virtual_extents(),
4953
adaptive_work_group_size(v_src.virtual_extents()),
54+
specialization_constants,
5055
VK_NULL_HANDLE,
5156
v_src.image(pipeline_barrier, api::PipelineStage::COMPUTE),
5257
dst_buffer,
@@ -78,11 +83,13 @@ void record_conv2d_prepack_weights_op(
7883
api::UniformParamsBuffer padded_sizes_ubo(
7984
context, api::utils::make_ivec2(padded_sizes, /*reverse = */ true));
8085

86+
api::SpecVarList specialization_constants = {};
8187
context->submit_compute_job(
8288
shader,
8389
pipeline_barrier,
8490
v_dst.virtual_extents(),
8591
adaptive_work_group_size(v_dst.virtual_extents()),
92+
specialization_constants,
8693
VK_NULL_HANDLE,
8794
v_dst.image(
8895
pipeline_barrier,
@@ -104,11 +111,13 @@ void record_binary_op(
104111
add_dtype_suffix(kernel_name, v_dst);
105112

106113
api::PipelineBarrier pipeline_barrier{};
114+
api::SpecVarList specialization_constants = {};
107115
context->submit_compute_job(
108116
VK_KERNEL_FROM_STR(kernel_name),
109117
pipeline_barrier,
110118
v_dst.virtual_extents(),
111119
adaptive_work_group_size(v_dst.virtual_extents()),
120+
specialization_constants,
112121
VK_NULL_HANDLE,
113122
v_dst.image(
114123
pipeline_barrier,

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,38 @@ TEST_F(VulkanComputeAPITest, spec_var_classes_test) {
9797
ASSERT_TRUE(*(reinterpret_cast<const float*>(copy_data + 6)) == 5.5f);
9898
}
9999

100+
TEST_F(VulkanComputeAPITest, spec_var_shader_test) {
101+
size_t len = 16;
102+
api::StorageBuffer buffer(api::context(), api::kFloat, len);
103+
104+
float scale = 3.0f;
105+
float offset = 1.5f;
106+
107+
{
108+
api::UniformParamsBuffer params(api::context(), int32_t(len));
109+
uint32_t len_div4 = api::utils::div_up(uint32_t(len), uint32_t(4));
110+
api::PipelineBarrier pipeline_barrier{};
111+
api::context()->submit_compute_job(
112+
VK_KERNEL(fill_buffer),
113+
pipeline_barrier,
114+
{64, 1, 1},
115+
{len_div4, 1, 1},
116+
{SV(scale), SV(offset)},
117+
VK_NULL_HANDLE,
118+
buffer.buffer(),
119+
params.buffer());
120+
}
121+
122+
submit_to_gpu();
123+
124+
std::vector<float> data(len);
125+
copy_staging_to_ptr(buffer, data.data(), buffer.nbytes());
126+
127+
for (size_t i = 0; i < len; ++i) {
128+
CHECK_VALUE(data, i, scale * i + offset);
129+
}
130+
}
131+
100132
TEST_F(VulkanComputeAPITest, update_params_between_submit) {
101133
api::context()->set_cmd(/*reusable = */ true);
102134
std::vector<int64_t> sizes = {4, 4, 2};
@@ -121,11 +153,13 @@ TEST_F(VulkanComputeAPITest, update_params_between_submit) {
121153

122154
{
123155
api::PipelineBarrier pipeline_barrier{};
156+
api::SpecVarList specialization_constants = {};
124157
api::context()->submit_compute_job(
125158
VK_KERNEL_FROM_STR(kernel_name),
126159
pipeline_barrier,
127160
{4, 4, 4},
128161
{4, 4, 4},
162+
specialization_constants,
129163
VK_NULL_HANDLE,
130164
a.image(
131165
pipeline_barrier,
@@ -180,11 +214,13 @@ void test_storage_buffer_type(const size_t len) {
180214
{
181215
uint32_t len_div4 = api::utils::div_up(uint32_t(len), uint32_t(4));
182216
api::PipelineBarrier pipeline_barrier{};
217+
api::SpecVarList specialization_constants = {};
183218
api::context()->submit_compute_job(
184219
VK_KERNEL_FROM_STR(kernel_name),
185220
pipeline_barrier,
186221
{64, 1, 1},
187222
{len_div4, 1, 1},
223+
specialization_constants,
188224
VK_NULL_HANDLE,
189225
buffer.buffer(),
190226
params.buffer());
@@ -875,11 +911,13 @@ void run_from_gpu_test(
875911

876912
{
877913
api::PipelineBarrier pipeline_barrier{};
914+
api::SpecVarList specialization_constants = {};
878915
api::context()->submit_compute_job(
879916
VK_KERNEL_FROM_STR(kernel_name),
880917
pipeline_barrier,
881918
vten.virtual_extents(),
882919
{4, 4, 4},
920+
specialization_constants,
883921
VK_NULL_HANDLE,
884922
vten.image(
885923
pipeline_barrier,

0 commit comments

Comments
 (0)