
Commit 448c7d3

SS-JIA authored and facebook-github-bot committed
Support int8 texture tensors without requiring int8 buffers (#4485)
Summary:

Pull Request resolved: #4485

## Context

By default, storage buffers in Vulkan must contain 32-bit data types; using 8-bit and 16-bit data types in buffers can optionally be enabled by supporting the [VK_KHR_8bit_storage](https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VK_KHR_8bit_storage.html) extension or the [VK_KHR_16bit_storage](https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VK_KHR_16bit_storage.html) extension, respectively.

Previously, 8-bit and 16-bit tensors were enabled by using those extensions; however, this meant that 8-bit and 16-bit tensors could not be used if the Vulkan driver did not support the corresponding extension.

This diff adds support for 8-bit texture-backed tensors without requiring the VK_KHR_8bit_storage extension. This is done by introducing shaders that manually pack and unpack 4 8-bit integers into and out of a single int32 value. Once the tensor data has been transferred to an image texture (which uses the `VK_FORMAT_R8G8B8A8_SINT` image format), the extension is no longer required.

Reviewed By: jorgep31415

Differential Revision: D60536832

fbshipit-source-id: 8d3d8b069582ab8c18d41701c864778621d2f6e3
1 parent 4483bb6 commit 448c7d3

16 files changed: +320 −73 lines
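The heart of the change is the byte packing performed in the new shaders. The following standalone C++ sketch (not part of the diff; `pack4` and `unpack_lane` are names invented for illustration) mirrors the pack and unpack arithmetic those shaders use, including the manual sign extension needed because the packed values are signed int8.

// Standalone illustration (not part of this diff) of the byte packing used by
// the new shaders. Function names here (pack4, unpack_lane) are invented for
// the example.
#include <array>
#include <cassert>
#include <cstdint>

// Pack four int8 values into one 32-bit word. Later values go into the more
// significant bytes, matching the packing expression in
// int8_tensor_to_nchw_noint8.glsl.
int32_t pack4(const std::array<int8_t, 4>& v) {
  const uint32_t packed = (uint32_t(uint8_t(v[3])) << 24) |
      (uint32_t(uint8_t(v[2])) << 16) | (uint32_t(uint8_t(v[1])) << 8) |
      uint32_t(uint8_t(v[0]));
  return static_cast<int32_t>(packed);
}

// Recover the int8 value stored in byte lane `lane` (0..3) of a packed word,
// including the manual sign extension performed by extend_sign() in
// nchw_to_int8_tensor_noint8.glsl.
int32_t unpack_lane(int32_t packed, int lane) {
  uint32_t byte = (static_cast<uint32_t>(packed) >> (8 * lane)) & 0xFFu;
  if ((byte >> 7) == 1u) {
    byte |= 0xFFFFFF00u;  // extend the sign bit of the 8-bit value
  }
  return static_cast<int32_t>(byte);
}

int main() {
  const std::array<int8_t, 4> vals = {5, -6, 7, -8};
  const int32_t packed = pack4(vals);
  for (int lane = 0; lane < 4; ++lane) {
    assert(unpack_lane(packed, lane) == vals[lane]);
  }
  return 0;
}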

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 9 additions & 9 deletions
@@ -319,24 +319,20 @@ utils::uvec3 ComputeGraph::create_global_wg_size(const ValueRef idx) {
   return image_extents_of(idx);
 }
 
-utils::uvec3 ComputeGraph::create_local_wg_size(const ValueRef idx) {
+utils::uvec3 ComputeGraph::create_local_wg_size(
+    const utils::uvec3 global_wg_size) {
   if (config_.enable_local_wg_size_override) {
     return config_.local_wg_size_override;
   }
 
-  if (is_buffer_storage(idx)) {
-    return {64u, 1u, 1u};
-  }
-
-  const utils::uvec3 image_extents = image_extents_of(idx);
   utils::uvec3 local_group_size = {4, 4, 4};
 
-  if (image_extents.data[2u] == 1) {
-    if (image_extents.data[1u] == 1) {
+  if (global_wg_size.data[2u] == 1) {
+    if (global_wg_size.data[1u] == 1) {
       local_group_size.data[0u] = 64;
       local_group_size.data[1u] = 1;
       local_group_size.data[2u] = 1;
-    } else if (image_extents.data[1u] < 8) {
+    } else if (global_wg_size.data[1u] < 8) {
       local_group_size.data[0u] = 16;
       local_group_size.data[1u] = 4;
       local_group_size.data[2u] = 1;
@@ -349,6 +345,10 @@ utils::uvec3 ComputeGraph::create_local_wg_size(const ValueRef idx) {
   return local_group_size;
 }
 
+utils::uvec3 ComputeGraph::create_local_wg_size(const ValueRef idx) {
+  return create_local_wg_size(image_extents_of(idx));
+}
+
 void ComputeGraph::copy_into_staging(
     const ValueRef idx,
     const void* data,
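To make the refactor above concrete, the branches visible in this hunk imply the following mappings from global to suggested local workgroup size (each totaling 64 invocations): a global size of (1024, 1, 1) yields (64, 1, 1); (128, 4, 1) yields (16, 4, 1); and (32, 32, 8) falls through to the default (4, 4, 4), since its z-extent is greater than 1.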

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 24 additions & 4 deletions
@@ -180,7 +180,9 @@ class ComputeGraph final {
     return values_.at(idx).type();
   }
 
-  // Get Tensor Property
+  //
+  // Tensor Properties Accessors
+  //
 
   std::vector<int64_t> sizes_of(const ValueRef idx) const;
 
@@ -226,7 +228,9 @@ class ComputeGraph final {
     return values_.at(idx).toTensor().ntexels_ubo();
   }
 
+  //
   // Scalar Value Extraction
+  //
 
   template <typename T>
   T extract_scalar(const ValueRef idx) {
@@ -459,16 +463,21 @@ class ComputeGraph final {
   utils::uvec3 create_global_wg_size(const ValueRef idx);
 
   /*
-   * Suggest a local workgroup size for a given `api::vTensor` value, assuming
-   * that every shader invocation calculates one texel element of the output
-   * tensor.
+   * Suggest a local workgroup size for a given global workgroup size.
    *
    * The local workgroup size will be formed to try and minimize the number of
    * inactive invocations.
    *
   * Currently, the local workgroup size is hard-coded to contain a total of 64
   * shader invocations. In the future, this value can be configured.
   */
+  utils::uvec3 create_local_wg_size(const utils::uvec3 global_wg_size);
+
+  /*
+   * Convenience function to suggest a local workgroup size for a given
+   * `api::vTensor` value, assuming that every shader invocation calculates one
+   * texel element of the output tensor.
+   */
   utils::uvec3 create_local_wg_size(const ValueRef idx);
 
   //
@@ -500,6 +509,17 @@ class ComputeGraph final {
   void resize_input(const int64_t idx, const std::vector<int64_t>& new_sizes);
   void propagate_resize();
 
+  //
+  // Miscellaneous Utilities
+  //
+
+  /*
+   * Check whether the GPU supports 8 bit buffers.
+   */
+  inline bool int8_buffers_enabled() const {
+    return context_->adapter_ptr()->has_full_int8_buffers_support();
+  }
+
   //
   // Debug support (implemented in Logging.cpp)
   //

backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h

Lines changed: 1 addition & 1 deletion
@@ -80,7 +80,7 @@ ivec4 from_nchw_buffer_i(int buf_i, ivec4 sizes) {
  * Returns: The (x, y, z, n) texel position corresponding to the first element
  * of the texel at the specified buffer index
  */
-ivec4 to_texel_pos(int buf_i, ivec4 strides, int packed_dim) {
+ivec4 to_tensor_idx(int buf_i, ivec4 strides, int packed_dim) {
   ivec4 idx;
   for (int i = 3; i >= 0; i--) {
     if (i != packed_dim) {
backends/vulkan/runtime/graph/ops/glsl/int8_tensor_to_nchw_noint8.glsl

Lines changed: 54 additions & 0 deletions

@@ -0,0 +1,54 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#include "indexing_utils.h"
+
+layout(std430) buffer;
+
+#extension GL_EXT_control_flow_attributes : require
+
+${layout_declare_tensor(0, "r", "t_in", "int8", "texture3d")}
+${layout_declare_buffer(1, "w", "nchw_out", "int")}
+${layout_declare_ubo(2, "ivec4", "tensor_sizes")}
+${layout_declare_ubo(3, "int", "out_ntexels")}
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+layout(constant_id = 3) const int packed_dim = C_DIM;
+
+void main() {
+  const int out_buf_idx = int(gl_GlobalInvocationID.x);
+  if (out_buf_idx >= out_ntexels) {
+    return;
+  }
+
+  ivec4 values;
+  int in_buf_idx = 4 * out_buf_idx;
+
+  [[unroll]] for (int i = 0; i < 4; ++i) {
+    const ivec4 tensor_idx = from_nchw_buffer_i(in_buf_idx, tensor_sizes);
+    const ivec4 texture_pos = to_texture_elem_pos(
+        tensor_idx, tensor_sizes, packed_dim);
+    values[i] = load_texel(t_in, texture_pos.xyz)[texture_pos.w];
+    in_buf_idx++;
+  }
+
+  // Manually pack 4x 8-bit integers into a 32 bit integer. Note that little
+  // endian is assumed, since most processors use little endian. Thus the
+  // "later" values are placed in the more significant bytes.
+  int packed = ((values[3] & 0xFF) << 24)
+      | ((values[2] & 0xFF) << 16)
+      | ((values[1] & 0xFF) << 8)
+      | ((values[0] & 0xFF));
+
+  nchw_out[out_buf_idx] = packed;
+}
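As a worked example of the packing expression above: a texel holding the int8 values (1, -2, 3, -4) has the bytes 0x01, 0xFE, 0x03, 0xFC, so

  packed = (0xFC << 24) | (0x03 << 16) | (0xFE << 8) | 0x01 = 0xFC03FE01

and on a little-endian host the staging buffer then contains the byte sequence 01 FE 03 FC, i.e. the four values in their original NCHW order.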
backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8_tensor_noint8.glsl

Lines changed: 74 additions & 0 deletions

@@ -0,0 +1,74 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#include "indexing_utils.h"
+
+layout(std430) buffer;
+
+#extension GL_EXT_control_flow_attributes : require
+
+${layout_declare_tensor(0, "w", "t_out", "int8", "texture3d")}
+${layout_declare_buffer(1, "r", "nchw_in", "int")}
+${layout_declare_ubo(2, "ivec4", "tensor_sizes")}
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+layout(constant_id = 3) const int packed_dim = C_DIM;
+
+/*
+ * Extends the sign of an int8 value.
+ */
+int extend_sign(int x) {
+  if (x >> 7 == 1) {
+    return x | 0xFFFFFF00;
+  }
+  return x;
+}
+
+ivec4 read_texel(ivec4 tensor_idx) {
+  const ivec4 buf_indices = get_texel_nchw_buffer_ixs(
+      tensor_idx, tensor_sizes, packed_dim);
+
+  int shift = (1 << 8) - 1;
+  ivec4 masks;
+  // Masks used to unpack 4x 8-bit values from a 32 bit integer. Note that
+  // little endian is assumed, as most processors use little endian. Thus the
+  // most significant bytes correspond to the "later" packed values.
+  masks.x = shift << (8 * (buf_indices.x % 4));
+  masks.y = shift << (8 * (buf_indices.y % 4));
+  masks.z = shift << (8 * (buf_indices.z % 4));
+  masks.w = shift << (8 * (buf_indices.w % 4));
+
+  ivec4 out_tex = ivec4(0);
+
+  [[unroll]] for (int i = 0; i < 4; ++i) {
+    if (tensor_idx[packed_dim] + i < tensor_sizes[packed_dim]) {
+      int in_texel = nchw_in[buf_indices[i] / 4];
+      int extracted_val = (in_texel & masks[i]) >> (8 * (buf_indices[i] % 4));
+      extracted_val = extend_sign(extracted_val);
+      out_tex[i] = extracted_val;
+    }
+  }
+
+  return out_tex;
+}
+
+void main() {
+  const ivec3 pos = ivec3(gl_GlobalInvocationID);
+  const ivec4 tensor_idx = to_tensor_idx(pos, tensor_sizes, packed_dim);
+
+  if (any(greaterThanEqual(tensor_idx, tensor_sizes))) {
+    return;
+  }
+
+  write_texel(t_out, pos, read_texel(tensor_idx));
+}
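A short worked example of the unpacking above: if buf_indices.y is 6, the packed word is read from nchw_in[6 / 4] = nchw_in[1], the mask is 0xFF << (8 * (6 % 4)) = 0x00FF0000, and the extracted byte is shifted back down by 16 bits. extend_sign() then restores negative values; a raw byte of 0xFE, for instance, becomes 0xFFFFFFFE, i.e. -2.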

backends/vulkan/runtime/graph/ops/glsl/nchw_to_tensor.glsl

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ void main() {
     return;
   }
 
-  ivec4 tensor_idx = to_texel_pos(t_id, gpu_strides, packed_dim);
+  ivec4 tensor_idx = to_tensor_idx(t_id, gpu_strides, packed_dim);
   tensor_idx[packed_dim] *= 4;
   t_out[t_id] = read_texel(tensor_idx);
 }

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ void main() {
     return;
   }
 
-  const ivec4 out_pos = to_texel_pos(t_id, out_strides, 0);
+  const ivec4 out_pos = to_tensor_idx(t_id, out_strides, 0);
 
   VEC4_T outtex = q_8w_linear(out_pos, mat1_sizes.x);
   write_texel(t_out, t_id, outtex);

backends/vulkan/runtime/graph/ops/glsl/tensor_to_nchw.glsl

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ void main() {
   }
 
   const VEC4_T intex = t_in[t_id];
-  ivec4 tensor_idx = to_texel_pos(t_id, gpu_strides, packed_dim);
+  ivec4 tensor_idx = to_tensor_idx(t_id, gpu_strides, packed_dim);
   tensor_idx[packed_dim] *= 4;
   write_out_texel(intex, tensor_idx);
 }

backends/vulkan/runtime/graph/ops/impl/Staging.cpp

Lines changed: 24 additions & 7 deletions
@@ -21,8 +21,8 @@ void add_staging_to_tensor_node(
     const ValueRef out_tensor) {
   VK_CHECK_COND(graph.val_is_staging(in_staging));
 
-  vkapi::ShaderInfo shader =
-      get_nchw_to_tensor_shader(*graph.get_tensor(out_tensor));
+  vkapi::ShaderInfo shader = get_nchw_to_tensor_shader(
+      *graph.get_tensor(out_tensor), graph.int8_buffers_enabled());
 
   vkapi::ParamsBindList ubos({graph.sizes_ubo(out_tensor)});
   if (graph.is_buffer_storage(out_tensor)) {
@@ -55,10 +55,26 @@ void add_tensor_to_staging_node(
     const ValueRef out_staging) {
   VK_CHECK_COND(graph.val_is_staging(out_staging));
 
-  vkapi::ShaderInfo shader =
-      get_tensor_to_nchw_shader(*graph.get_tensor(in_tensor));
+  vkapi::ShaderInfo shader = get_tensor_to_nchw_shader(
+      *graph.get_tensor(in_tensor), graph.int8_buffers_enabled());
 
+  utils::uvec3 global_wg_size = graph.create_global_wg_size(in_tensor);
   vkapi::ParamsBindList ubos({graph.sizes_ubo(in_tensor)});
+
+  // Normally, the tensor_to_nchw shader is structured so that each thread reads
+  // one texel from the input texture and writes each component of the texel
+  // into the corresponding location in the output buffer. However, this shader
+  // is structured slightly differently in that each thread writes out a
+  // complete 32 bit integer (containing 4 packed 8-bit integers) into the
+  // output buffer. Therefore, the global work group size for this shader will
+  // be the number of elements in the output buffer divided by 4, as opposed to
+  // the extents of the input texture.
+  if (shader.kernel_name == "int8_tensor_to_nchw_noint8") {
+    uint32_t buffer_len = graph.get_staging(out_staging)->numel() / 4;
+    global_wg_size = {buffer_len, 1, 1};
+    ubos.append({graph.ntexels_ubo(in_tensor)});
+  }
+
   if (graph.is_buffer_storage(in_tensor)) {
     ubos.append({
         graph.texel_strides_ubo(in_tensor),
@@ -69,8 +85,8 @@ void add_tensor_to_staging_node(
   graph.execute_nodes().emplace_back(new ExecuteNode(
       graph,
       shader,
-      graph.create_global_wg_size(in_tensor),
-      graph.create_local_wg_size(in_tensor),
+      global_wg_size,
+      graph.create_local_wg_size(global_wg_size),
       // Input and Outputs
       {{in_tensor, vkapi::MemoryAccessType::READ},
       {out_staging, vkapi::MemoryAccessType::WRITE}},
@@ -86,7 +102,8 @@ ValueRef prepack(
     const utils::GPUMemoryLayout layout) {
   ValueRef v = graph.add_tensor_like(vref, layout);
 
-  vkapi::ShaderInfo shader = get_nchw_to_tensor_shader(*graph.get_tensor(v));
+  vkapi::ShaderInfo shader = get_nchw_to_tensor_shader(
+      *graph.get_tensor(v), graph.int8_buffers_enabled());
 
   vkapi::ParamsBindList ubos({graph.sizes_ubo(v)});
   if (graph.is_buffer_storage(v)) {
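For example (illustrative numbers only): if the int8 staging buffer holds 1024 values, the packing shader is dispatched with global_wg_size = {1024 / 4, 1, 1} = {256, 1, 1}, one invocation per packed 32-bit word, and create_local_wg_size({256, 1, 1}) then selects a {64, 1, 1} local size via the heuristic updated in ComputeGraph.cpp.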

backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp

Lines changed: 16 additions & 2 deletions
@@ -95,21 +95,35 @@ void set_staging_zeros(api::StorageBuffer& staging, const size_t nbytes) {
   memset(data_ptr, 0, staging.nbytes());
 }
 
-vkapi::ShaderInfo get_nchw_to_tensor_shader(const api::vTensor& v_dst) {
+vkapi::ShaderInfo get_nchw_to_tensor_shader(
+    const api::vTensor& v_dst,
+    const bool int8_buffer_enabled) {
   std::string kernel_name;
   kernel_name.reserve(kShaderNameReserve);
 
+  if (v_dst.dtype() == vkapi::kChar &&
+      v_dst.storage_type() == utils::kTexture3D && !int8_buffer_enabled) {
+    return VK_KERNEL(nchw_to_int8_tensor_noint8);
+  }
+
   kernel_name = "nchw_to_tensor";
   add_dtype_suffix(kernel_name, v_dst);
   add_storage_type_suffix(kernel_name, v_dst);
 
   return VK_KERNEL_FROM_STR(kernel_name);
 }
 
-vkapi::ShaderInfo get_tensor_to_nchw_shader(const api::vTensor& v_src) {
+vkapi::ShaderInfo get_tensor_to_nchw_shader(
+    const api::vTensor& v_src,
+    bool int8_buffer_enabled) {
   std::string kernel_name;
   kernel_name.reserve(kShaderNameReserve);
 
+  if (v_src.dtype() == vkapi::kChar &&
+      v_src.storage_type() == utils::kTexture3D && !int8_buffer_enabled) {
+    return VK_KERNEL(int8_tensor_to_nchw_noint8);
+  }
+
   kernel_name = "tensor_to_nchw";
   add_dtype_suffix(kernel_name, v_src);
   add_storage_type_suffix(kernel_name, v_src);

backends/vulkan/runtime/graph/ops/utils/StagingUtils.h

Lines changed: 6 additions & 2 deletions
@@ -31,7 +31,11 @@ void set_staging_zeros(api::StorageBuffer& staging, const size_t nbytes);
 // Functions to get shaders
 //
 
-vkapi::ShaderInfo get_nchw_to_tensor_shader(const api::vTensor& v_dst);
-vkapi::ShaderInfo get_tensor_to_nchw_shader(const api::vTensor& v_src);
+vkapi::ShaderInfo get_nchw_to_tensor_shader(
+    const api::vTensor& v_dst,
+    bool int8_buffer_enabled = true);
+vkapi::ShaderInfo get_tensor_to_nchw_shader(
+    const api::vTensor& v_src,
+    bool int8_buffer_enabled = true);
 
 } // namespace vkcompute
