Skip to content

Commit 4a975ea

Browse files
jorgep31415facebook-github-bot
authored andcommitted
Merge ArithmeticPrepack into PrepackNode (#2261)
Summary: bypass-github-export-checks Pull Request resolved: #2261 There's a lot of shared logic between - `add_staging_to_tensor_node()`, which handles I/O data on execute(), and - `ArithmeticPrepack`'s simple prepacking, on prepack(). Both just copy data to and from GPU, without any manipulation. Hence, I've decided to consolidate shared logic in this diff as well. Here are the final results: + Make `PrepackNode` a final class. + Remove all references of `impl/Packing.h`. - Extract shared util functions to new `StagingUtils.h/cpp`. ghstack-source-id: 217439331 exported-using-ghexport Reviewed By: SS-JIA Differential Revision: D54504449 fbshipit-source-id: 358f2f5acb396a05bf7758cce1f3314d3a85ba55
1 parent b2862ea commit 4a975ea

File tree

10 files changed

+369
-280
lines changed

10 files changed

+369
-280
lines changed

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
1010

11+
#include <executorch/backends/vulkan/runtime/graph/ops/StagingUtils.h>
12+
1113
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
1214

1315
namespace at {
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
12+
13+
#include <executorch/backends/vulkan/runtime/graph/ops/StagingUtils.h>
14+
#include <executorch/backends/vulkan/runtime/graph/ops/Utils.h>
15+
16+
namespace at {
17+
namespace native {
18+
namespace vulkan {
19+
20+
// Encodes the prepacking work for this node into the graph's context: the data
// referenced by tref_ is copied into a temporary staging buffer, and a dispatch
// of shader_ is registered that writes into the tensor referenced by packed_.
void PrepackNode::encode(ComputeGraph* graph) {
  // Descriptor binding slots, in the order they are bound below.
  constexpr uint32_t kPackedTensorSlot = 0;
  constexpr uint32_t kStagingSlot = 1;
  constexpr uint32_t kParamsSlot = 2;

  api::Context* const ctx = graph->context();

  TensorRef source = graph->get_val(tref_).toTensorRef();
  vTensor destination = graph->get_val(packed_).toTensor();

  // TODO: Extract to standalone function, to support other types of prepacking.
  // The staging buffer takes its dtype and size from the destination tensor,
  // while the number of bytes copied comes from the source sizes and dtype.
  api::StorageBuffer staging(
      ctx, destination.dtype(), destination.gpu_nbytes());
  const size_t source_numel = api::utils::multiply_integers(source.sizes);
  const size_t source_nbytes = source_numel * api::element_size(source.dtype);
  copy_ptr_to_staging(source.data, staging, source_nbytes);

  // The dispatch lock is held from here until the end of the function.
  std::unique_lock<std::mutex> dispatch_guard = ctx->dispatch_lock();

  api::PipelineBarrier barrier{};
  api::DescriptorSet bindings =
      ctx->get_descriptor_set(shader_, local_workgroup_size_);

  bind_tensor_to_descriptor_set(
      destination,
      barrier,
      api::MemoryAccessType::WRITE,
      bindings,
      kPackedTensorSlot);
  bind_staging_to_descriptor_set(staging, bindings, kStagingSlot);
  bindings.bind(kParamsSlot, params_.buffer());

  ctx->register_shader_dispatch(
      bindings, barrier, shader_, global_workgroup_size_);
}
52+
53+
} // namespace vulkan
54+
} // namespace native
55+
} // namespace at

backends/vulkan/runtime/graph/ops/PrepackNode.h

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,37 @@ class ComputeGraph;
2828
* encoding of shaders transferring necessary data (such as weights and biases)
2929
* to the GPU.
3030
*/
31-
class PrepackNode {
31+
class PrepackNode final {
3232
friend class ComputeGraph;
3333

3434
public:
35-
PrepackNode(ValueRef tref, ValueRef packed) : tref_{tref}, packed_{packed} {}
35+
PrepackNode(
36+
const api::ShaderInfo& shader,
37+
const api::utils::uvec3& global_workgroup_size,
38+
const api::utils::uvec3& local_workgroup_size,
39+
const ValueRef tref,
40+
const ValueRef packed,
41+
api::UniformParamsBuffer&& params)
42+
: shader_(shader),
43+
global_workgroup_size_(global_workgroup_size),
44+
local_workgroup_size_(local_workgroup_size),
45+
tref_(tref),
46+
packed_(packed),
47+
params_(std::move(params)) {}
3648

37-
virtual ~PrepackNode() = default;
49+
~PrepackNode() = default;
3850

39-
protected:
40-
ValueRef tref_;
41-
ValueRef packed_;
51+
void encode(ComputeGraph* graph);
4252

43-
public:
44-
virtual void encode(ComputeGraph* graph) const = 0;
53+
protected:
54+
const api::ShaderInfo shader_;
55+
const api::utils::uvec3 global_workgroup_size_;
56+
const api::utils::uvec3 local_workgroup_size_;
57+
const ValueRef tref_;
58+
const ValueRef packed_;
59+
// TODO(T180906086): pass multiple buffers and index with ValueRef.
60+
// TODO(T180906457): allow re-computing param buffers.
61+
api::UniformParamsBuffer params_;
4562
};
4663

4764
} // namespace vulkan
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/StagingUtils.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ops/Utils.h>
12+
13+
#include <ATen/native/vulkan/impl/Common.h>
14+
15+
namespace at {
16+
namespace native {
17+
namespace vulkan {
18+
19+
void memcpy_to_mapping(
20+
const void* src,
21+
api::MemoryMap& dst_mapping,
22+
const size_t nbytes,
23+
const api::ScalarType dtype) {
24+
#define DTYPE_CASE(ctype, vkformat, name) \
25+
case api::ScalarType::name: \
26+
memcpy_to_mapping_impl<ctype>(src, dst_mapping, nbytes); \
27+
break;
28+
29+
switch (dtype) {
30+
VK_FORALL_SCALAR_TYPES(DTYPE_CASE)
31+
default:
32+
VK_THROW("Unrecognized dtype!");
33+
}
34+
#undef DTYPE_CASE
35+
}
36+
37+
// Copies nbytes out of src_mapping into dst. The dtype selects the element
// type with which memcpy_from_mapping_impl is instantiated; a dtype that is not
// listed by VK_FORALL_SCALAR_TYPES is reported through VK_THROW.
void memcpy_from_mapping(
    api::MemoryMap& src_mapping,
    void* dst,
    const size_t nbytes,
    const api::ScalarType dtype) {
#define COPY_OUT_FOR_DTYPE(ctype, vkformat, name)              \
  case api::ScalarType::name: {                                \
    memcpy_from_mapping_impl<ctype>(src_mapping, dst, nbytes); \
    return;                                                    \
  }

  switch (dtype) {
    VK_FORALL_SCALAR_TYPES(COPY_OUT_FOR_DTYPE)
    default:
      VK_THROW("Unrecognized dtype!");
  }
#undef COPY_OUT_FOR_DTYPE
}
54+
55+
void copy_ptr_to_staging(
56+
const void* src,
57+
api::StorageBuffer& staging,
58+
const size_t nbytes) {
59+
api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE);
60+
mapping.invalidate();
61+
memcpy_to_mapping(src, mapping, nbytes, staging.dtype());
62+
}
63+
64+
// Maps the staging buffer for reading and copies nbytes from it into dst,
// using the staging buffer's own dtype to pick the element type.
void copy_staging_to_ptr(
    api::StorageBuffer& staging,
    void* dst,
    const size_t nbytes) {
  api::MemoryMap read_map(staging.buffer(), api::MemoryAccessType::READ);
  read_map.invalidate();

  const auto element_dtype = staging.dtype();
  memcpy_from_mapping(read_map, dst, nbytes, element_dtype);
}
72+
73+
// Selects the shader used to write into v_dst, based on whether the tensor is
// quantized, its storage type and its dtype. Combinations without a kernel are
// reported through VK_THROW.
api::ShaderInfo get_nchw_to_image_shader(const vTensor& v_dst) {
  if (v_dst.is_quantized()) {
    const auto storage = v_dst.storage_type();

    // Non-texture storage gets its own message, ahead of the generic one.
    if (storage == api::StorageType::BUFFER ||
        storage == api::StorageType::UNKNOWN) {
      VK_THROW("Requested storage type must be a texture type.");
    }

    const bool is_3d = storage == api::StorageType::TEXTURE_3D;
    const bool is_2d = storage == api::StorageType::TEXTURE_2D;
    if (!is_3d && !is_2d) {
      VK_THROW("No kernel available!");
    }

    // Only TEXTURE_3D and TEXTURE_2D reach this point.
    switch (v_dst.dtype()) {
      case api::ScalarType::QUInt8:
        return is_3d ? VK_KERNEL(nchw_to_image_uint8)
                     : VK_KERNEL(nchw_to_image2d_uint8);
      case api::ScalarType::QInt8:
        return is_3d ? VK_KERNEL(nchw_to_image_int8)
                     : VK_KERNEL(nchw_to_image2d_int8);
      case api::ScalarType::QInt32:
        return is_3d ? VK_KERNEL(nchw_to_image_int32)
                     : VK_KERNEL(nchw_to_image2d_int32);
      default:
        VK_THROW(
            "Vulkan quantization currently not supported for dtype ",
            v_dst.dtype());
    }
  }

  const auto dtype = v_dst.dtype();
  const bool is_float = dtype == api::kFloat;
  const bool is_bool = dtype == api::kBool;
  if (!is_float && !is_bool) {
    VK_THROW("Unsupported dtype!");
  }

  const auto storage = v_dst.storage_type();
  if (storage == api::StorageType::TEXTURE_3D) {
    return is_float ? VK_KERNEL(nchw_to_image) : VK_KERNEL(nchw_to_image_bool);
  }
  // Only float has a 2D-texture kernel.
  if (is_float && storage == api::StorageType::TEXTURE_2D) {
    return VK_KERNEL(nchw_to_image2d);
  }
  VK_THROW("No kernel available!");
}
130+
131+
// Selects the shader used to read out of v_src, based on whether the tensor is
// quantized, its storage type and its dtype. Combinations without a kernel are
// reported through VK_THROW.
api::ShaderInfo get_image_to_nchw_shader(const vTensor& v_src) {
  if (v_src.is_quantized() || v_src.dtype() == api::kBool) {
    const auto height_times_width =
        dim_at<Dim4D::Height>(v_src) * dim_at<Dim4D::Width>(v_src);
    const auto storage = v_src.storage_type();

    // Non-texture storage gets its own message, ahead of the generic one.
    if (storage == api::StorageType::BUFFER ||
        storage == api::StorageType::UNKNOWN) {
      VK_THROW("Requested storage type must be a texture type.");
    }
    if (storage != api::StorageType::TEXTURE_3D) {
      VK_THROW("No kernel available!");
    }

    switch (v_src.dtype()) {
      case api::ScalarType::QUInt8:
      case api::ScalarType::QInt8:
      case api::kBool:
        // A dedicated kernel is chosen when height * width is a multiple of 4.
        if (height_times_width % 4 == 0) {
          return VK_KERNEL(image_to_nchw_quantized_mul4);
        }
        return VK_KERNEL(image_to_nchw_uint);
      case api::ScalarType::QInt32:
        return VK_KERNEL(image_to_nchw_int32);
      default:
        VK_THROW(
            "Vulkan quantization currently not supported for dtype ",
            v_src.dtype());
    }
  }

  if (v_src.dtype() != api::kFloat) {
    VK_THROW("Unsupported dtype!");
  }

  const auto storage = v_src.storage_type();
  if (storage == api::StorageType::TEXTURE_3D) {
    return VK_KERNEL(image_to_nchw);
  }
  if (storage == api::StorageType::TEXTURE_2D) {
    return VK_KERNEL(image2d_to_nchw);
  }
  VK_THROW("No kernel available!");
}
171+
172+
} // namespace vulkan
173+
} // namespace native
174+
} // namespace at
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#ifdef USE_VULKAN_API
12+
13+
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
14+
15+
#include <cstring>
16+
17+
namespace at {
18+
namespace native {
19+
namespace vulkan {
20+
21+
//
22+
// Functions to memcpy data into staging buffer
23+
//
24+
25+
void memcpy_to_mapping(
26+
const void* src,
27+
api::MemoryMap& dst_mapping,
28+
const size_t nbytes,
29+
const api::ScalarType dtype);
30+
void memcpy_from_mapping(
31+
const api::MemoryMap& src_mapping,
32+
void* dst,
33+
const size_t nbytes,
34+
const api::ScalarType dtype);
35+
36+
//
37+
// Utility functions for memcpy
38+
//
39+
40+
template <typename T>
41+
void memcpy_to_mapping_impl(
42+
const void* src,
43+
api::MemoryMap& dst_mapping,
44+
const size_t nbytes) {
45+
T* data_ptr = dst_mapping.template data<T>();
46+
memcpy(data_ptr, reinterpret_cast<const T*>(src), nbytes);
47+
}
48+
49+
template <typename T>
50+
void memcpy_from_mapping_impl(
51+
api::MemoryMap& src_mapping,
52+
void* dst,
53+
const size_t nbytes) {
54+
T* data_ptr = src_mapping.template data<T>();
55+
memcpy(reinterpret_cast<T*>(dst), data_ptr, nbytes);
56+
}
57+
58+
//
59+
// Functions to copy data into and out of a staging buffer
60+
//
61+
62+
void copy_ptr_to_staging(
63+
const void* src,
64+
api::StorageBuffer& staging,
65+
const size_t nbytes);
66+
void copy_staging_to_ptr(
67+
api::StorageBuffer& staging,
68+
void* dst,
69+
const size_t nbytes);
70+
71+
//
72+
// Functions to get shaders
73+
//
74+
75+
api::ShaderInfo get_nchw_to_image_shader(const vTensor& v_dst);
76+
api::ShaderInfo get_image_to_nchw_shader(const vTensor& v_src);
77+
78+
} // namespace vulkan
79+
} // namespace native
80+
} // namespace at
81+
82+
#endif /* USE_VULKAN_API */

0 commit comments

Comments
 (0)