Skip to content

Commit c0807b1

Browse files
abhishekchandrafacebook-github-bot
authored andcommitted
Blit Node
Summary: Introduce a graph node to call vkcmdBlitImage which can convert between dtypes (and also perform scaling, filtering etc. but we don't need them right now). Differential Revision: D63839654
1 parent d038985 commit c0807b1

File tree

6 files changed

+163
-0
lines changed

6 files changed

+163
-0
lines changed

backends/vulkan/runtime/api/Context.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,14 @@ void Context::register_shader_dispatch(
143143
cmd_.dispatch(effective_global_wg);
144144
}
145145

146+
void Context::register_blit(
147+
vkapi::PipelineBarrier& pipeline_barrier,
148+
vkapi::VulkanImage& src,
149+
vkapi::VulkanImage& dst) {
150+
cmd_.insert_barrier(pipeline_barrier);
151+
cmd_.blit(src, dst);
152+
}
153+
146154
void Context::submit_cmd_to_gpu(VkFence fence_handle, const bool final_use) {
147155
if (cmd_) {
148156
cmd_.end();

backends/vulkan/runtime/api/Context.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,11 @@ class Context final {
196196
const vkapi::ShaderInfo&,
197197
const utils::uvec3&);
198198

199+
void register_blit(
200+
vkapi::PipelineBarrier&,
201+
vkapi::VulkanImage& src,
202+
vkapi::VulkanImage& dst);
203+
199204
template <typename... Arguments>
200205
bool submit_compute_job(
201206
const vkapi::ShaderInfo&,
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/BlitNode.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
12+
13+
namespace vkcompute {
14+
15+
BlitNode::BlitNode(
16+
ComputeGraph& graph,
17+
ValueRef src,
18+
ValueRef dst,
19+
// const vkapi::ScalarType& dtype,
20+
const ResizeFunction& resize_fn,
21+
const std::vector<ValueRef>& resize_args)
22+
: ExecuteNode(resize_fn, resize_args, {}, "Blit Node"),
23+
src_(src),
24+
dst_(dst) {
25+
(void)graph;
26+
}
27+
28+
void BlitNode::encode(ComputeGraph* graph) {
29+
auto src_tensor = graph->get_tensor(src_);
30+
auto dst_tensor = graph->get_tensor(dst_);
31+
VK_CHECK_COND(
32+
src_tensor->storage_type() != utils::kBuffer &&
33+
dst_tensor->storage_type() != utils::kBuffer,
34+
"BlitNode: Only texture backed tensors are supported.");
35+
36+
api::Context* const context = graph->context();
37+
vkapi::PipelineBarrier pipeline_barrier{};
38+
39+
std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock();
40+
41+
// Hack to get timing data for non shader op
42+
std::string kernel_name("Blit_");
43+
kernel_name.reserve(32);
44+
kernel_name += vkapi::to_string(src_tensor->dtype());
45+
kernel_name += "_to_";
46+
kernel_name += vkapi::to_string(dst_tensor->dtype());
47+
48+
context->report_shader_dispatch_start(
49+
kernel_name, utils::uvec3(), utils::uvec3(), node_id_);
50+
51+
context->register_blit(
52+
pipeline_barrier,
53+
src_tensor->image(
54+
pipeline_barrier,
55+
vkapi::PipelineStage::TRANSFER,
56+
vkapi::kRead),
57+
dst_tensor->image(
58+
pipeline_barrier,
59+
vkapi::PipelineStage::TRANSFER,
60+
vkapi::kWrite));
61+
62+
context->report_shader_dispatch_end();
63+
}
64+
65+
} // namespace vkcompute
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/backends/vulkan/runtime/api/api.h>
12+
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
13+
14+
#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>
15+
16+
#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
17+
18+
namespace vkcompute {
19+
20+
/*
21+
* Represents a tensor blit execution op in a ML model.
22+
*/
23+
class BlitNode final : public ExecuteNode {
24+
friend class ComputeGraph;
25+
26+
public:
27+
explicit BlitNode(
28+
ComputeGraph& graph,
29+
ValueRef src,
30+
ValueRef dst,
31+
/*const vkapi::ScalarType& dtype,*/
32+
const ResizeFunction& resize_fn = nullptr,
33+
const std::vector<ValueRef>& resize_args = {});
34+
35+
~BlitNode() = default;
36+
37+
void encode(ComputeGraph* graph) override;
38+
39+
protected:
40+
ValueRef src_;
41+
ValueRef dst_;
42+
// const vkapi::ScalarType &dtype_;
43+
};
44+
45+
} // namespace vkcompute

backends/vulkan/runtime/vk_api/Command.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,45 @@ void CommandBuffer::dispatch(const utils::uvec3& global_workgroup_size) {
179179
state_ = CommandBuffer::State::RECORDING;
180180
}
181181

182+
void CommandBuffer::blit(vkapi::VulkanImage& src, vkapi::VulkanImage& dst) {
183+
VK_CHECK_COND(
184+
state_ == CommandBuffer::State::BARRIERS_INSERTED,
185+
"Vulkan CommandBuffer: called blit() on a command buffer whose state "
186+
"is not BARRIERS_INSERTED.");
187+
188+
auto src_extents = src.extents();
189+
auto dst_extents = dst.extents();
190+
191+
VkImageBlit blit{};
192+
blit.srcOffsets[0] = {0, 0, 0},
193+
blit.srcOffsets[1] =
194+
{static_cast<int32_t>(src_extents.width),
195+
static_cast<int32_t>(src_extents.height),
196+
static_cast<int32_t>(src_extents.depth)},
197+
blit.srcSubresource.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT,
198+
blit.srcSubresource.mipLevel = 0, blit.srcSubresource.baseArrayLayer = 0,
199+
blit.srcSubresource.layerCount = 1, blit.dstOffsets[0] = {0, 0, 0},
200+
blit.dstOffsets[1] =
201+
{static_cast<int32_t>(dst_extents.width),
202+
static_cast<int32_t>(dst_extents.height),
203+
static_cast<int32_t>(dst_extents.depth)},
204+
blit.dstSubresource.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT,
205+
blit.dstSubresource.mipLevel = 0, blit.dstSubresource.baseArrayLayer = 0,
206+
blit.dstSubresource.layerCount = 1,
207+
208+
vkCmdBlitImage(
209+
handle_,
210+
src.handle(),
211+
src.layout(),
212+
dst.handle(),
213+
dst.layout(),
214+
1,
215+
&blit,
216+
VK_FILTER_NEAREST);
217+
218+
state_ = CommandBuffer::State::RECORDING;
219+
}
220+
182221
void CommandBuffer::write_timestamp(VkQueryPool querypool, const uint32_t idx)
183222
const {
184223
VK_CHECK_COND(

backends/vulkan/runtime/vk_api/Command.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ class CommandBuffer final {
9292

9393
void insert_barrier(PipelineBarrier& pipeline_barrier);
9494
void dispatch(const utils::uvec3&);
95+
void blit(vkapi::VulkanImage& src, vkapi::VulkanImage& dst);
9596

9697
void write_timestamp(VkQueryPool, const uint32_t) const;
9798
void reset_querypool(VkQueryPool, const uint32_t, const uint32_t) const;

0 commit comments

Comments
 (0)