Skip to content

Commit 96fb2bf

Browse files
jorgep31415facebook-github-bot
authored andcommitted
Move ParamsBuffer and StorageBuffer to standalone files (#4120)
Summary: Pull Request resolved: #4120 Includes the renaming of `UniformParamsBuffer` to `ParamsBuffer` for brevity. These objects aren't tightly coupled to `Context` and hence they are better placed in standalone files. ghstack-source-id: 232366091 exported-using-ghexport bypass-github-export-checks bypass-github-pytorch-ci-checks bypass-github-executorch-ci-checks Reviewed By: SS-JIA Differential Revision: D59281543 fbshipit-source-id: 93f72cccee1f9959f40410c29ef483dedc569fae
1 parent da04a60 commit 96fb2bf

File tree

11 files changed

+222
-179
lines changed

11 files changed

+222
-179
lines changed

backends/vulkan/runtime/api/Context.cpp

Lines changed: 0 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,6 @@
88

99
#include <executorch/backends/vulkan/runtime/api/Context.h>
1010

11-
#include <cstdint>
12-
#include <cstring>
13-
#include <memory>
14-
#include <sstream>
15-
1611
#ifndef VULKAN_DESCRIPTOR_POOL_SIZE
1712
#define VULKAN_DESCRIPTOR_POOL_SIZE 1024u
1813
#endif
@@ -220,59 +215,5 @@ Context* context() {
220215
return context.get();
221216
}
222217

223-
//
224-
// UniformParamsBuffer
225-
//
226-
227-
namespace {
228-
229-
void memcpy_to_buffer(const VulkanBuffer& src, VulkanBuffer& dst) {
230-
MemoryMap dst_mapping(dst, MemoryAccessType::WRITE);
231-
232-
MemoryMap src_mapping(src, MemoryAccessType::READ);
233-
src_mapping.invalidate();
234-
235-
void* dst_ptr = dst_mapping.template data<void>();
236-
void* src_ptr = src_mapping.template data<void>();
237-
238-
// @lint-ignore CLANGTIDY facebook-security-vulnerable-memcpy
239-
memcpy(dst_ptr, src_ptr, src.mem_size());
240-
}
241-
242-
} // namespace
243-
244-
UniformParamsBuffer::UniformParamsBuffer(const UniformParamsBuffer& other)
245-
: context_p_(other.context_p_), vulkan_buffer_{} {
246-
if (other.vulkan_buffer_) {
247-
vulkan_buffer_ = context_p_->adapter_ptr()->vma().create_uniform_buffer(
248-
other.vulkan_buffer_.mem_size());
249-
250-
memcpy_to_buffer(other.vulkan_buffer_, vulkan_buffer_);
251-
}
252-
}
253-
254-
UniformParamsBuffer& UniformParamsBuffer::operator=(
255-
const UniformParamsBuffer& other) {
256-
if (&other != this) {
257-
context_p_ = other.context_p_;
258-
259-
// Move vulkan_buffer_ to another VulkanBuffer for cleanup
260-
if (vulkan_buffer_) {
261-
VulkanBuffer temp_buffer(std::move(vulkan_buffer_));
262-
context_p_->register_buffer_cleanup(temp_buffer);
263-
}
264-
// vulkan_buffer_ should now be empty
265-
266-
if (other.vulkan_buffer_) {
267-
vulkan_buffer_ = context_p_->adapter_ptr()->vma().create_uniform_buffer(
268-
other.vulkan_buffer_.mem_size());
269-
270-
memcpy_to_buffer(other.vulkan_buffer_, vulkan_buffer_);
271-
}
272-
}
273-
274-
return *this;
275-
}
276-
277218
} // namespace api
278219
} // namespace vkcompute

backends/vulkan/runtime/api/Context.h

Lines changed: 0 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,12 @@
1010

1111
// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
1212

13-
#include <executorch/backends/vulkan/runtime/api/vk_api.h>
14-
1513
#include <executorch/backends/vulkan/runtime/api/Adapter.h>
1614
#include <executorch/backends/vulkan/runtime/api/Command.h>
1715
#include <executorch/backends/vulkan/runtime/api/Descriptor.h>
1816
#include <executorch/backends/vulkan/runtime/api/Fence.h>
19-
#include <executorch/backends/vulkan/runtime/api/Pipeline.h>
2017
#include <executorch/backends/vulkan/runtime/api/QueryPool.h>
2118
#include <executorch/backends/vulkan/runtime/api/Runtime.h>
22-
#include <executorch/backends/vulkan/runtime/api/Shader.h>
23-
#include <executorch/backends/vulkan/runtime/api/Utils.h>
24-
25-
#include <executorch/backends/vulkan/runtime/api/memory/Buffer.h>
2619

2720
namespace vkcompute {
2821
namespace api {
@@ -218,103 +211,6 @@ class Context final {
218211
void flush();
219212
};
220213

221-
class UniformParamsBuffer final {
222-
private:
223-
Context* context_p_;
224-
size_t nbytes_;
225-
VulkanBuffer vulkan_buffer_;
226-
227-
public:
228-
UniformParamsBuffer() : context_p_{nullptr}, vulkan_buffer_{} {}
229-
230-
template <typename Block>
231-
UniformParamsBuffer(Context* context_p, const Block& block)
232-
: context_p_(context_p),
233-
nbytes_(sizeof(block)),
234-
vulkan_buffer_(
235-
context_p_->adapter_ptr()->vma().create_params_buffer(block)) {}
236-
237-
UniformParamsBuffer(const UniformParamsBuffer&);
238-
UniformParamsBuffer& operator=(const UniformParamsBuffer&);
239-
240-
UniformParamsBuffer(UniformParamsBuffer&&) = default;
241-
UniformParamsBuffer& operator=(UniformParamsBuffer&&) = default;
242-
243-
~UniformParamsBuffer() {
244-
if (vulkan_buffer_) {
245-
context_p_->register_buffer_cleanup(vulkan_buffer_);
246-
}
247-
}
248-
249-
const VulkanBuffer& buffer() const {
250-
return vulkan_buffer_;
251-
}
252-
253-
template <typename Block>
254-
void update(const Block& block) {
255-
if (sizeof(block) != nbytes_) {
256-
VK_THROW(
257-
"Attempted to update UniformParamsBuffer with data of different size");
258-
}
259-
// Fill the uniform buffer with data in block
260-
{
261-
MemoryMap mapping(vulkan_buffer_, MemoryAccessType::WRITE);
262-
Block* data_ptr = mapping.template data<Block>();
263-
264-
*data_ptr = block;
265-
}
266-
}
267-
};
268-
269-
class StorageBuffer final {
270-
private:
271-
Context* context_p_;
272-
ScalarType dtype_;
273-
size_t numel_;
274-
size_t nbytes_;
275-
VulkanBuffer vulkan_buffer_;
276-
277-
public:
278-
StorageBuffer(
279-
Context* context_p,
280-
const ScalarType dtype,
281-
const size_t numel,
282-
const bool gpuonly = false)
283-
: context_p_(context_p),
284-
dtype_(dtype),
285-
numel_(numel),
286-
nbytes_(element_size(dtype_) * numel_),
287-
vulkan_buffer_(context_p_->adapter_ptr()->vma().create_storage_buffer(
288-
nbytes_,
289-
gpuonly)) {}
290-
291-
StorageBuffer(const StorageBuffer&) = delete;
292-
StorageBuffer& operator=(const StorageBuffer&) = delete;
293-
294-
StorageBuffer(StorageBuffer&&) = default;
295-
StorageBuffer& operator=(StorageBuffer&&) = default;
296-
297-
~StorageBuffer() {
298-
context_p_->register_buffer_cleanup(vulkan_buffer_);
299-
}
300-
301-
inline ScalarType dtype() {
302-
return dtype_;
303-
}
304-
305-
inline VulkanBuffer& buffer() {
306-
return vulkan_buffer_;
307-
}
308-
309-
inline size_t numel() {
310-
return numel_;
311-
}
312-
313-
inline size_t nbytes() {
314-
return nbytes_;
315-
}
316-
};
317-
318214
bool available();
319215

320216
// The global runtime is retrieved using this function, where it is declared as
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/api/ParamsBuffer.h>
10+
11+
#include <cstring>
12+
13+
namespace vkcompute {
14+
namespace api {
15+
16+
namespace {
17+
18+
void memcpy_to_buffer(const VulkanBuffer& src, VulkanBuffer& dst) {
19+
MemoryMap dst_mapping(dst, MemoryAccessType::WRITE);
20+
21+
MemoryMap src_mapping(src, MemoryAccessType::READ);
22+
src_mapping.invalidate();
23+
24+
void* dst_ptr = dst_mapping.template data<void>();
25+
void* src_ptr = src_mapping.template data<void>();
26+
27+
// @lint-ignore CLANGTIDY facebook-security-vulnerable-memcpy
28+
memcpy(dst_ptr, src_ptr, src.mem_size());
29+
}
30+
31+
} // namespace
32+
33+
ParamsBuffer::ParamsBuffer(const ParamsBuffer& other)
34+
: context_p_(other.context_p_), vulkan_buffer_{} {
35+
if (other.vulkan_buffer_) {
36+
vulkan_buffer_ = context_p_->adapter_ptr()->vma().create_uniform_buffer(
37+
other.vulkan_buffer_.mem_size());
38+
39+
memcpy_to_buffer(other.vulkan_buffer_, vulkan_buffer_);
40+
}
41+
}
42+
43+
ParamsBuffer& ParamsBuffer::operator=(const ParamsBuffer& other) {
44+
if (&other != this) {
45+
context_p_ = other.context_p_;
46+
47+
// Move vulkan_buffer_ to another VulkanBuffer for cleanup
48+
if (vulkan_buffer_) {
49+
VulkanBuffer temp_buffer(std::move(vulkan_buffer_));
50+
context_p_->register_buffer_cleanup(temp_buffer);
51+
}
52+
// vulkan_buffer_ should now be empty
53+
54+
if (other.vulkan_buffer_) {
55+
vulkan_buffer_ = context_p_->adapter_ptr()->vma().create_uniform_buffer(
56+
other.vulkan_buffer_.mem_size());
57+
58+
memcpy_to_buffer(other.vulkan_buffer_, vulkan_buffer_);
59+
}
60+
}
61+
62+
return *this;
63+
}
64+
65+
} // namespace api
66+
} // namespace vkcompute
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
12+
13+
#include <executorch/backends/vulkan/runtime/api/Context.h>
14+
15+
#include <executorch/backends/vulkan/runtime/api/memory/Buffer.h>
16+
17+
namespace vkcompute {
18+
namespace api {
19+
20+
class ParamsBuffer final {
21+
private:
22+
Context* context_p_;
23+
size_t nbytes_;
24+
VulkanBuffer vulkan_buffer_;
25+
26+
public:
27+
ParamsBuffer() : context_p_{nullptr}, vulkan_buffer_{} {}
28+
29+
template <typename Block>
30+
ParamsBuffer(Context* context_p, const Block& block)
31+
: context_p_(context_p),
32+
nbytes_(sizeof(block)),
33+
vulkan_buffer_(
34+
context_p_->adapter_ptr()->vma().create_params_buffer(block)) {}
35+
36+
ParamsBuffer(const ParamsBuffer&);
37+
ParamsBuffer& operator=(const ParamsBuffer&);
38+
39+
ParamsBuffer(ParamsBuffer&&) = default;
40+
ParamsBuffer& operator=(ParamsBuffer&&) = default;
41+
42+
~ParamsBuffer() {
43+
if (vulkan_buffer_) {
44+
context_p_->register_buffer_cleanup(vulkan_buffer_);
45+
}
46+
}
47+
48+
const VulkanBuffer& buffer() const {
49+
return vulkan_buffer_;
50+
}
51+
52+
template <typename Block>
53+
void update(const Block& block) {
54+
if (sizeof(block) != nbytes_) {
55+
VK_THROW("Attempted to update ParamsBuffer with data of different size");
56+
}
57+
// Fill the uniform buffer with data in block
58+
{
59+
MemoryMap mapping(vulkan_buffer_, MemoryAccessType::WRITE);
60+
Block* data_ptr = mapping.template data<Block>();
61+
62+
*data_ptr = block;
63+
}
64+
}
65+
};
66+
67+
} // namespace api
68+
} // namespace vkcompute

0 commit comments

Comments
 (0)