Skip to content

Commit 9c38cf7

Browse files
SS-JIAfacebook-github-bot
authored andcommitted
Add API to read value of SymInt andParamsBuffer (#5754)
Summary: Pull Request resolved: #5754 ## Context This diff adds an API to read the value from a `SymInt`. This functionality will be useful because `SymInt`s may be needed to set tensor sizes, in addition to being used as arguments to shaders. ghstack-source-id: 245517196 exported-using-ghexport Reviewed By: jorgep31415 Differential Revision: D63642093 fbshipit-source-id: 14ce9daecd5520cf078d3b3259eaafb5b654d834
1 parent a5a76f7 commit 9c38cf7

File tree

6 files changed

+33
-1
lines changed

6 files changed

+33
-1
lines changed

backends/vulkan/runtime/api/containers/ParamsBuffer.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,29 @@ class ParamsBuffer final {
5656
}
5757
// Fill the uniform buffer with data in block
5858
{
59-
vkapi::MemoryMap mapping(vulkan_buffer_, vkapi::MemoryAccessType::WRITE);
59+
vkapi::MemoryMap mapping(vulkan_buffer_, vkapi::kWrite);
6060
Block* data_ptr = mapping.template data<Block>();
6161

6262
*data_ptr = block;
6363
}
6464
}
65+
66+
template <typename T>
67+
T read() const {
68+
T val;
69+
if (sizeof(val) != nbytes_) {
70+
VK_THROW(
71+
"Attempted to store value from ParamsBuffer to type of different size");
72+
}
73+
// Read value from uniform buffer and store in val
74+
{
75+
vkapi::MemoryMap mapping(vulkan_buffer_, vkapi::kRead);
76+
T* data_ptr = mapping.template data<T>();
77+
78+
val = *data_ptr;
79+
}
80+
return val;
81+
}
6582
};
6683

6784
} // namespace api

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,10 @@ void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) {
416416
get_symint(idx)->set(val);
417417
}
418418

419+
int32_t ComputeGraph::read_symint(const ValueRef idx) {
420+
return get_symint(idx)->get();
421+
}
422+
419423
SharedObject& ComputeGraph::get_shared_object(const int64_t idx) {
420424
if (idx >= shared_objects_.size()) {
421425
shared_objects_.resize(static_cast<size_t>(idx + 1));

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,8 @@ class ComputeGraph final {
536536

537537
void set_symint(const ValueRef idx, const int32_t val);
538538

539+
int32_t read_symint(const ValueRef idx);
540+
539541
/*
540542
* Convenience function to add an input tensor along with its staging buffer
541543
*/

backends/vulkan/runtime/graph/containers/SymInt.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ void SymInt::set(const int32_t val) {
1717
gpu_buffer.update(val);
1818
}
1919

20+
int32_t SymInt::get() {
21+
return gpu_buffer.read<int32_t>();
22+
}
23+
2024
void SymInt::operator=(const int32_t val) {
2125
gpu_buffer.update(val);
2226
}

backends/vulkan/runtime/graph/containers/SymInt.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ struct SymInt final {
3535

3636
void set(const int32_t val);
3737

38+
int32_t get();
39+
3840
void operator=(const int32_t val);
3941
};
4042

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1433,6 +1433,9 @@ TEST(VulkanComputeGraphTest, test_simple_graph_with_symint) {
14331433
int scalar_val = i - 3.0f;
14341434
graph.set_symint(scalar, scalar_val);
14351435

1436+
int32_t scalar_val_read = graph.read_symint(scalar);
1437+
EXPECT_TRUE(scalar_val_read == scalar_val);
1438+
14361439
float val_a = i + 2.0f;
14371440
float val_out = val_a + scalar_val;
14381441

0 commit comments

Comments
 (0)