
Commit 5675406

Update on "[ET-VK] Enable dynamic operator registration"
This change follows (1) the approach of #2222 for static initialization and (2) the popular `TorchLibraryImpl` pattern of wrapping registration in macros: https://www.internalfb.com/code/fbsource/[b6860acf0fd7a95224f2ed3f6fe48f699a9a45c0]/fbcode/caffe2/torch/library.h?lines=1004%2C1012-1026

Contributors can now write their operators and register them within the same file using `REGISTER_OPERATORS` + `VK_REGISTER_OP()`, as shown in `Arithmetic.h/cpp`.

Typically, in Linux/Android C++ environments, the symbols corresponding to `OperatorRegisterInit` static instances are discarded at link time, since they aren't used for anything other than static initialization. Hence, we need to set `link_whole = True` on the `vulkan_graph_runtime` library.

We also update our Compute API tests to verify that we can go through `OperatorRegistry` with proper static initialization.

Differential Revision: [D54641117](https://our.internmc.facebook.com/intern/diff/D54641117/)

[ghstack-poisoned]
2 parents: 7eade98 + 4c08350
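For context, below is a minimal sketch of how this style of static-initialization registration is commonly wired up. Only the names `OperatorRegistry`, `OperatorRegisterInit`, `REGISTER_OPERATORS`, and `VK_REGISTER_OP` come from this commit; the macro bodies, the `OpFunction` signature, and the registry internals are illustrative assumptions, not the actual ET-VK implementation (which lives in `OperatorRegistry.h` and `Arithmetic.h/cpp`).

```cpp
// Minimal sketch of static-initialization-based operator registration,
// assuming a simple name -> function table. Illustrative only.
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>

using OpFunction = std::function<void()>; // real signature takes graph + args

class OperatorRegistry final {
 public:
  static OperatorRegistry& get() {
    // Function-local static: initialized on first use, safe across TUs.
    static OperatorRegistry instance;
    return instance;
  }
  void register_op(std::string name, OpFunction fn) {
    table_.emplace(std::move(name), std::move(fn));
  }
  bool has_op(const std::string& name) const {
    return table_.count(name) > 0;
  }

 private:
  std::unordered_map<std::string, OpFunction> table_;
};

// A static instance of this struct runs its constructor before main(),
// which is what performs the registration.
struct OperatorRegisterInit final {
  explicit OperatorRegisterInit(void (*init_fn)()) {
    init_fn();
  }
};

// REGISTER_OPERATORS opens a registration block; VK_REGISTER_OP adds one op.
#define REGISTER_OPERATORS                                      \
  static void register_ops();                                   \
  static const OperatorRegisterInit op_reg_init(&register_ops); \
  static void register_ops()

#define VK_REGISTER_OP(name, fn) \
  OperatorRegistry::get().register_op(#name, fn)

// Usage, e.g. in an operator implementation file:
static void add_op() { /* record the compute job */ }

REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.add.Tensor, add_op);
}
```

The `link_whole = True` requirement follows directly from this pattern: nothing in the program references the `OperatorRegisterInit` instance, so the linker is otherwise free to drop the object file containing it, and the constructor that performs registration would never run.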

File tree: 2 files changed (+159 lines, −33 lines)

backends/vulkan/runtime/graph/ops/impl/Staging.h

Lines changed: 3 additions & 0 deletions

```diff
@@ -35,6 +35,9 @@ struct StagingParams final {
 
 ValueRef prepack_if_tensor_ref(ComputeGraph& graph, const ValueRef v);
 
+// Expose for the Vulkan Compute API tests.
+StagingParams create_staging_params(const vTensor& t);
+
 } // namespace vulkan
 } // namespace native
 } // namespace at
```

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 156 additions & 33 deletions

```diff
@@ -10,8 +10,7 @@
 
 #include <ATen/native/vulkan/api/api.h>
 
-#include <ATen/native/vulkan/impl/Arithmetic.h>
-#include <ATen/native/vulkan/impl/Packing.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/OpUtils.h>
 
 #include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
 #include <executorch/backends/vulkan/runtime/graph/ops/StagingUtils.h>
@@ -21,10 +20,6 @@
 
 using namespace at::native::vulkan;
 
-//
-// Utilities
-//
-
 #define CREATE_FLOAT_TEXTURE(sizes, allocate_memory) \
   vTensor(                                           \
       api::context(),                                \
@@ -43,23 +38,159 @@ using namespace at::native::vulkan;
       api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED, \
       allocate_memory);
 
+//
+// Simplified versions of ATen Vulkan legacy functions
+//
+
+void record_nchw_to_buffer_op(
+    api::Context* const context,
+    api::VulkanBuffer& src_buffer,
+    vTensor& v_dst) {
+  uint32_t buf_len = api::utils::safe_downcast<uint32_t>(v_dst.gpu_numel());
+  api::utils::uvec3 global_size = {buf_len, 1u, 1u};
+  api::utils::uvec3 local_size = {32u, 1u, 1u};
+
+  api::UniformParamsBuffer cpu_buffer_metadata(
+      context, v_dst.get_cpu_buffer_metadata());
+  api::PipelineBarrier pipeline_barrier{};
+
+  context->submit_compute_job(
+      VK_KERNEL(buffer_to_buffer),
+      pipeline_barrier,
+      global_size,
+      local_size,
+      VK_NULL_HANDLE,
+      v_dst.buffer(
+          pipeline_barrier,
+          api::PipelineStage::COMPUTE,
+          api::MemoryAccessType::WRITE),
+      v_dst.buffer_metadata(),
+      src_buffer,
+      cpu_buffer_metadata.buffer());
+}
+
+bool record_buffer_to_nchw_op(
+    api::Context* const context,
+    vTensor& v_src,
+    api::VulkanBuffer& dst_buffer) {
+  uint32_t buf_len = api::utils::safe_downcast<uint32_t>(v_src.numel());
+  api::utils::uvec3 global_size = {buf_len, 1u, 1u};
+  api::utils::uvec3 local_size = {4u, 1u, 1u};
+
+  api::UniformParamsBuffer cpu_buffer_metadata(
+      context, v_src.get_cpu_buffer_metadata());
+  api::PipelineBarrier pipeline_barrier{};
+
+  return context->submit_compute_job(
+      VK_KERNEL(buffer_to_buffer),
+      pipeline_barrier,
+      global_size,
+      local_size,
+      VK_NULL_HANDLE,
+      dst_buffer,
+      cpu_buffer_metadata.buffer(),
+      v_src.buffer(
+          pipeline_barrier,
+          api::PipelineStage::COMPUTE,
+          api::MemoryAccessType::WRITE),
+      v_src.buffer_metadata());
+}
+
+void record_nchw_to_image_op(
+    api::Context* const context,
+    api::VulkanBuffer& src_buffer,
+    vTensor& v_dst) {
+  api::utils::uvec3 global_size = v_dst.extents();
+  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
+
+  api::UniformParamsBuffer params(context, create_staging_params(v_dst));
+  api::PipelineBarrier pipeline_barrier{};
+
+  context->submit_compute_job(
+      get_nchw_to_image_shader(v_dst),
+      pipeline_barrier,
+      global_size,
+      local_size,
+      VK_NULL_HANDLE,
+      v_dst.image(
+          pipeline_barrier,
+          api::PipelineStage::COMPUTE,
+          api::MemoryAccessType::WRITE),
+      src_buffer,
+      params.buffer());
+}
+
+bool record_image_to_nchw_op(
+    api::Context* const context,
+    vTensor& v_src,
+    api::VulkanBuffer& dst_buffer) {
+  api::utils::uvec3 global_size = v_src.extents();
+  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
+
+  api::UniformParamsBuffer params(context, create_staging_params(v_src));
+  api::PipelineBarrier pipeline_barrier{};
+
+  return context->submit_compute_job(
+      get_image_to_nchw_shader(v_src),
+      pipeline_barrier,
+      global_size,
+      local_size,
+      VK_NULL_HANDLE,
+      v_src.image(
+          pipeline_barrier,
+          api::PipelineStage::COMPUTE,
+          api::MemoryAccessType::WRITE),
+      dst_buffer,
+      params.buffer());
+}
+
+void record_arithmetic_op(
+    api::Context* const context,
+    const api::ShaderInfo& compute_shader,
+    vTensor& v_in1,
+    vTensor& v_in2,
+    vTensor& v_dst,
+    const float alpha) {
+  api::utils::uvec3 global_size = v_dst.extents();
+  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
+
+  ArithmeticParams block{
+      get_size_as_ivec4(v_dst),
+      get_size_as_ivec4(v_in1),
+      get_size_as_ivec4(v_in2),
+      alpha,
+  };
+  api::UniformParamsBuffer params(context, block);
+  api::PipelineBarrier pipeline_barrier{};
+
+  context->submit_compute_job(
+      compute_shader,
+      pipeline_barrier,
+      global_size,
+      local_size,
+      VK_NULL_HANDLE,
+      v_dst.image(
+          pipeline_barrier,
+          api::PipelineStage::COMPUTE,
+          api::MemoryAccessType::WRITE),
+      v_in1.image(pipeline_barrier, api::PipelineStage::COMPUTE),
+      v_in2.image(pipeline_barrier, api::PipelineStage::COMPUTE),
+      params.buffer());
+}
+
+//
+// Utilities
+//
+
 void fill_vtensor(vTensor& vten, std::vector<float>& data) {
   api::StorageBuffer staging_buffer(api::context(), api::kFloat, data.size());
 
   copy_ptr_to_staging(data.data(), staging_buffer, vten.gpu_nbytes());
 
   if (vten.storage_type() == api::StorageType::BUFFER) {
-    packing::record_nchw_to_buffer_op(
-        api::context(), staging_buffer.buffer(), vten, {}, VK_NULL_HANDLE);
+    record_nchw_to_buffer_op(api::context(), staging_buffer.buffer(), vten);
   } else {
-    api::ShaderInfo compute_shader = packing::get_nchw_to_image_shader(vten);
-    packing::record_nchw_to_image_op(
-        api::context(),
-        compute_shader,
-        staging_buffer.buffer(),
-        vten,
-        {},
-        VK_NULL_HANDLE);
+    record_nchw_to_image_op(api::context(), staging_buffer.buffer(), vten);
   }
 }
 
@@ -75,17 +206,9 @@ void extract_vtensor(vTensor& vten, std::vector<float>& data) {
       api::context(), api::kFloat, vten.gpu_numel());
 
   if (vten.storage_type() == api::StorageType::BUFFER) {
-    packing::record_buffer_to_nchw_op(
-        api::context(), vten, staging_buffer.buffer(), {}, VK_NULL_HANDLE);
+    record_buffer_to_nchw_op(api::context(), vten, staging_buffer.buffer());
   } else {
-    api::ShaderInfo compute_shader = packing::get_image_to_nchw_shader(vten);
-    packing::record_image_to_nchw_op(
-        api::context(),
-        compute_shader,
-        vten,
-        staging_buffer.buffer(),
-        {},
-        VK_NULL_HANDLE);
+    record_image_to_nchw_op(api::context(), vten, staging_buffer.buffer());
   }
 
   api::VulkanFence fence = api::context()->fences().get_fence();
@@ -208,14 +331,14 @@ TEST_F(VulkanComputeAPITest, texture_add_sanity_check) {
   std::fill(data_b.begin(), data_b.end(), 1.5f);
 
   // Add shader kernel
-  api::ShaderInfo kernel = arithmetic::get_shader(arithmetic::OpType::ADD);
+  api::ShaderInfo kernel = VK_KERNEL(add);
 
   // Fill input tensors
   fill_vtensor(a, data_a);
   fill_vtensor(b, data_b);
 
   // a + b -> c
-  arithmetic::record_op(api::context(), kernel, a, b, c, 1.0f);
+  record_arithmetic_op(api::context(), kernel, a, b, c, 1.0f);
 
   // Extract output tensor
   std::vector<float> data_out(c.gpu_numel());
@@ -244,7 +367,7 @@ TEST_F(VulkanComputeAPITest, texture_deferred_allocation_test) {
   std::vector<float> data_b(b.gpu_numel());
   std::fill(data_b.begin(), data_b.end(), 1.5f);
 
-  api::ShaderInfo kernel = arithmetic::get_shader(arithmetic::OpType::ADD);
+  api::ShaderInfo kernel = VK_KERNEL(add);
 
   // Allocate memory at the last possible opportunity
   api::MemoryAllocation a_mem = allocate_memory_for(a);
@@ -260,7 +383,7 @@ TEST_F(VulkanComputeAPITest, texture_deferred_allocation_test) {
   fill_vtensor(a, data_a);
   fill_vtensor(b, data_b);
 
-  arithmetic::record_op(api::context(), kernel, a, b, c, 1.0f);
+  record_arithmetic_op(api::context(), kernel, a, b, c, 1.0f);
 
   std::vector<float> data_c(c.gpu_numel());
   extract_vtensor(c, data_c);
@@ -310,20 +433,20 @@ TEST_F(VulkanComputeAPITest, texture_resource_aliasing_test) {
   std::fill(data_d.begin(), data_d.end(), 1.0f);
 
   // Get shader kernel for add
-  api::ShaderInfo kernel = arithmetic::get_shader(arithmetic::OpType::ADD);
+  api::ShaderInfo kernel = VK_KERNEL(add);
 
   // First, fill a and b with data
   fill_vtensor(a, data_a);
   fill_vtensor(b, data_b);
 
   // a + b -> c
-  arithmetic::record_op(api::context(), kernel, a, b, c, 1.0f);
+  record_arithmetic_op(api::context(), kernel, a, b, c, 1.0f);
 
   // Now d can be filled with data
   fill_vtensor(d, data_d);
 
   // c + d -> e
-  arithmetic::record_op(api::context(), kernel, c, d, e, 1.0f);
+  record_arithmetic_op(api::context(), kernel, c, d, e, 1.0f);
 
   // Extract data from e
   std::vector<float> data_e(e.gpu_numel());
```
