Skip to content

[ET-VK] Enable additional specialization constants in compute shaders #3079

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions backends/vulkan/runtime/api/Context.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ class Context final {
PipelineBarrier&,
const utils::uvec3&,
const utils::uvec3&,
const SpecVarList&,
VkFence fence_handle,
Arguments&&...);

Expand Down Expand Up @@ -494,6 +495,7 @@ inline bool Context::submit_compute_job(
PipelineBarrier& pipeline_barrier,
const utils::uvec3& global_work_group,
const utils::uvec3& local_work_group_size,
const SpecVarList& specialization_constants,
VkFence fence_handle,
Arguments&&... arguments) {
// If any of the provided arguments does not have memory associated with it,
Expand Down Expand Up @@ -536,8 +538,8 @@ inline bool Context::submit_compute_job(
#endif /* USE_VULKAN_GPU_DIAGNOSTICS */

// Factor out template parameter independent code to minimize code bloat.
DescriptorSet descriptor_set =
get_descriptor_set(shader, local_work_group_size);
DescriptorSet descriptor_set = get_descriptor_set(
shader, local_work_group_size, specialization_constants);

detail::bind(
descriptor_set,
Expand Down
8 changes: 5 additions & 3 deletions backends/vulkan/runtime/graph/ops/ExecuteNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,16 @@ ExecuteNode::ExecuteNode(
const std::vector<ArgGroup>& args,
const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params,
const ResizeFunction& resize_fn,
const std::vector<ValueRef>& resize_args)
const std::vector<ValueRef>& resize_args,
const api::SpecVarList& spec_vars)
: shader_(shader),
global_workgroup_size_(global_workgroup_size),
local_workgroup_size_(local_workgroup_size),
args_(args),
params_(params),
resize_fn_(resize_fn),
resize_args_(resize_args) {
resize_args_(resize_args),
spec_vars_(spec_vars) {
graph.update_descriptor_counts(shader, /*execute = */ true);
}

Expand All @@ -40,7 +42,7 @@ void ExecuteNode::encode(ComputeGraph* graph) {
std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock();

api::DescriptorSet descriptor_set =
context->get_descriptor_set(shader_, local_workgroup_size_);
context->get_descriptor_set(shader_, local_workgroup_size_, spec_vars_);

uint32_t idx = 0;
idx = bind_values_to_descriptor_set(
Expand Down
4 changes: 3 additions & 1 deletion backends/vulkan/runtime/graph/ops/ExecuteNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ class ExecuteNode final {
const std::vector<ArgGroup>& args,
const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params,
const ResizeFunction& resize_fn = nullptr,
const std::vector<ValueRef>& resize_args = {});
const std::vector<ValueRef>& resize_args = {},
const api::SpecVarList& spec_vars = {});

~ExecuteNode() = default;

Expand All @@ -76,6 +77,7 @@ class ExecuteNode final {
std::vector<std::shared_ptr<api::UniformParamsBuffer>> params_;
const ResizeFunction resize_fn_;
const std::vector<ValueRef> resize_args_;
const api::SpecVarList spec_vars_;
};

} // namespace vkcompute
46 changes: 46 additions & 0 deletions backends/vulkan/test/glsl/fill_buffer.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

$PRECISION = "highp"
$DTYPE = "float"

#define PRECISION ${PRECISION}

#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}

#include "indexing_utils.h"

layout(std430) buffer;

layout(set = 0, binding = 0) buffer PRECISION restrict writeonly Buffer {
VEC4_T data[];
}
buffer_in;

layout(set = 0, binding = 1) uniform PRECISION restrict Params {
int len;
}
params;



layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const float scale = 1;
layout(constant_id = 4) const float offset = 0;

void main() {
const int i = ivec3(gl_GlobalInvocationID).x;

const int base = 4 * i;
if (base < params.len) {
buffer_in.data[i] = scale * (VEC4_T(base) + VEC4_T(0, 1, 2, 3)) + offset;
}
}
9 changes: 9 additions & 0 deletions backends/vulkan/test/utils/test_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@ void record_nchw_to_image_op(
api::VulkanBuffer& src_buffer,
vTensor& v_dst) {
api::PipelineBarrier pipeline_barrier{};
api::SpecVarList specialization_constants = {};

context->submit_compute_job(
get_nchw_to_image_shader(v_dst),
pipeline_barrier,
v_dst.virtual_extents(),
adaptive_work_group_size(v_dst.virtual_extents()),
specialization_constants,
VK_NULL_HANDLE,
v_dst.image(
pipeline_barrier,
Expand All @@ -42,11 +44,14 @@ void record_image_to_nchw_op(
vTensor& v_src,
api::VulkanBuffer& dst_buffer) {
api::PipelineBarrier pipeline_barrier{};
api::SpecVarList specialization_constants = {};

context->submit_compute_job(
get_image_to_nchw_shader(v_src),
pipeline_barrier,
v_src.virtual_extents(),
adaptive_work_group_size(v_src.virtual_extents()),
specialization_constants,
VK_NULL_HANDLE,
v_src.image(pipeline_barrier, api::PipelineStage::COMPUTE),
dst_buffer,
Expand Down Expand Up @@ -78,11 +83,13 @@ void record_conv2d_prepack_weights_op(
api::UniformParamsBuffer padded_sizes_ubo(
context, api::utils::make_ivec2(padded_sizes, /*reverse = */ true));

api::SpecVarList specialization_constants = {};
context->submit_compute_job(
shader,
pipeline_barrier,
v_dst.virtual_extents(),
adaptive_work_group_size(v_dst.virtual_extents()),
specialization_constants,
VK_NULL_HANDLE,
v_dst.image(
pipeline_barrier,
Expand All @@ -104,11 +111,13 @@ void record_binary_op(
add_dtype_suffix(kernel_name, v_dst);

api::PipelineBarrier pipeline_barrier{};
api::SpecVarList specialization_constants = {};
context->submit_compute_job(
VK_KERNEL_FROM_STR(kernel_name),
pipeline_barrier,
v_dst.virtual_extents(),
adaptive_work_group_size(v_dst.virtual_extents()),
specialization_constants,
VK_NULL_HANDLE,
v_dst.image(
pipeline_barrier,
Expand Down
38 changes: 38 additions & 0 deletions backends/vulkan/test/vulkan_compute_api_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,38 @@ TEST_F(VulkanComputeAPITest, spec_var_classes_test) {
ASSERT_TRUE(*(reinterpret_cast<const float*>(copy_data + 6)) == 5.5f);
}

TEST_F(VulkanComputeAPITest, spec_var_shader_test) {
size_t len = 16;
api::StorageBuffer buffer(api::context(), api::kFloat, len);

float scale = 3.0f;
float offset = 1.5f;

{
api::UniformParamsBuffer params(api::context(), int32_t(len));
uint32_t len_div4 = api::utils::div_up(uint32_t(len), uint32_t(4));
api::PipelineBarrier pipeline_barrier{};
api::context()->submit_compute_job(
VK_KERNEL(fill_buffer),
pipeline_barrier,
{64, 1, 1},
{len_div4, 1, 1},
{SV(scale), SV(offset)},
VK_NULL_HANDLE,
buffer.buffer(),
params.buffer());
}

submit_to_gpu();

std::vector<float> data(len);
copy_staging_to_ptr(buffer, data.data(), buffer.nbytes());

for (size_t i = 0; i < len; ++i) {
CHECK_VALUE(data, i, scale * i + offset);
}
}

TEST_F(VulkanComputeAPITest, update_params_between_submit) {
api::context()->set_cmd(/*reusable = */ true);
std::vector<int64_t> sizes = {4, 4, 2};
Expand All @@ -121,11 +153,13 @@ TEST_F(VulkanComputeAPITest, update_params_between_submit) {

{
api::PipelineBarrier pipeline_barrier{};
api::SpecVarList specialization_constants = {};
api::context()->submit_compute_job(
VK_KERNEL_FROM_STR(kernel_name),
pipeline_barrier,
{4, 4, 4},
{4, 4, 4},
specialization_constants,
VK_NULL_HANDLE,
a.image(
pipeline_barrier,
Expand Down Expand Up @@ -180,11 +214,13 @@ void test_storage_buffer_type(const size_t len) {
{
uint32_t len_div4 = api::utils::div_up(uint32_t(len), uint32_t(4));
api::PipelineBarrier pipeline_barrier{};
api::SpecVarList specialization_constants = {};
api::context()->submit_compute_job(
VK_KERNEL_FROM_STR(kernel_name),
pipeline_barrier,
{64, 1, 1},
{len_div4, 1, 1},
specialization_constants,
VK_NULL_HANDLE,
buffer.buffer(),
params.buffer());
Expand Down Expand Up @@ -875,11 +911,13 @@ void run_from_gpu_test(

{
api::PipelineBarrier pipeline_barrier{};
api::SpecVarList specialization_constants = {};
api::context()->submit_compute_job(
VK_KERNEL_FROM_STR(kernel_name),
pipeline_barrier,
vten.virtual_extents(),
{4, 4, 4},
specialization_constants,
VK_NULL_HANDLE,
vten.image(
pipeline_barrier,
Expand Down