Skip to content

Commit ad47a9a

Browse files
SS-JIAfacebook-github-bot
authored andcommitted
Support int8 texture tensors without requiring int8 buffers
Summary: ## Context By default, storage buffers in Vulkan must contain 32 bit data types; using 8 bit and 16 bit data types in buffers can be enabled optionally by supporting the [VK_KHR_8bit_storage](https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VK_KHR_8bit_storage.html) extension or the [VK_KHR_16bit_storage](https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VK_KHR_16bit_storage.html) extension respectively. Previously, 8-bit and 16-bit tensors were enabled by using those extensions; however, this meant that 8-bit and 16-bit tensors could not be used if the Vulkan driver does not support the corresponding extension. This diff adds support for 8-bit texture-backed tensors without the need for the VK_KHR_8bit_storage extension. This is done by introducing shaders that manually pack and repack 4 8-bit integers into a single int32 value. Once the tensor data has been transferred to an image texture (which will use the `VK_FORMAT_R8G8B8A8_SINT` image format) the extension will no longer be required. Differential Revision: D60536832
1 parent f611219 commit ad47a9a

14 files changed

+313
-63
lines changed

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -319,24 +319,20 @@ utils::uvec3 ComputeGraph::create_global_wg_size(const ValueRef idx) {
319319
return image_extents_of(idx);
320320
}
321321

322-
utils::uvec3 ComputeGraph::create_local_wg_size(const ValueRef idx) {
322+
utils::uvec3 ComputeGraph::create_local_wg_size(
323+
const utils::uvec3 global_wg_size) {
323324
if (config_.enable_local_wg_size_override) {
324325
return config_.local_wg_size_override;
325326
}
326327

327-
if (is_buffer_storage(idx)) {
328-
return {64u, 1u, 1u};
329-
}
330-
331-
const utils::uvec3 image_extents = image_extents_of(idx);
332328
utils::uvec3 local_group_size = {4, 4, 4};
333329

334-
if (image_extents.data[2u] == 1) {
335-
if (image_extents.data[1u] == 1) {
330+
if (global_wg_size.data[2u] == 1) {
331+
if (global_wg_size.data[1u] == 1) {
336332
local_group_size.data[0u] = 64;
337333
local_group_size.data[1u] = 1;
338334
local_group_size.data[2u] = 1;
339-
} else if (image_extents.data[1u] < 8) {
335+
} else if (global_wg_size.data[1u] < 8) {
340336
local_group_size.data[0u] = 16;
341337
local_group_size.data[1u] = 4;
342338
local_group_size.data[2u] = 1;
@@ -349,6 +345,10 @@ utils::uvec3 ComputeGraph::create_local_wg_size(const ValueRef idx) {
349345
return local_group_size;
350346
}
351347

348+
utils::uvec3 ComputeGraph::create_local_wg_size(const ValueRef idx) {
349+
return create_local_wg_size(image_extents_of(idx));
350+
}
351+
352352
void ComputeGraph::copy_into_staging(
353353
const ValueRef idx,
354354
const void* data,

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,9 @@ class ComputeGraph final {
180180
return values_.at(idx).type();
181181
}
182182

183-
// Get Tensor Property
183+
//
184+
// Get Tensor Properties
185+
//
184186

185187
std::vector<int64_t> sizes_of(const ValueRef idx) const;
186188

@@ -226,7 +228,9 @@ class ComputeGraph final {
226228
return values_.at(idx).toTensor().ntexels_ubo();
227229
}
228230

231+
//
229232
// Scalar Value Extraction
233+
//
230234

231235
template <typename T>
232236
T extract_scalar(const ValueRef idx) {
@@ -459,16 +463,21 @@ class ComputeGraph final {
459463
utils::uvec3 create_global_wg_size(const ValueRef idx);
460464

461465
/*
462-
* Suggest a local workgroup size for a given `api::vTensor` value, assuming
463-
* that every shader invocation calculates one texel element of the output
464-
* tensor.
466+
* Suggest a local workgroup size for a given global workgroup size.
465467
*
466468
* The local workgroup size will be formed to try and minimize the number of
467469
* inactive invocations.
468470
*
469471
* Currently, the local workgroup size is hard-coded to contain a total of 64
470472
* shader invocations. In the future, this value can be configured.
471473
*/
474+
utils::uvec3 create_local_wg_size(const utils::uvec3 global_wg_size);
475+
476+
/*
477+
* Convenience function to suggest a local workgroup size for a given
478+
* `api::vTensor` value, assuming that every shader invocation calculates one
479+
* texel element of the output tensor.
480+
*/
472481
utils::uvec3 create_local_wg_size(const ValueRef idx);
473482

474483
//
@@ -500,6 +509,17 @@ class ComputeGraph final {
500509
void resize_input(const int64_t idx, const std::vector<int64_t>& new_sizes);
501510
void propagate_resize();
502511

512+
//
513+
// Miscellaneous Utilities
514+
//
515+
516+
/*
517+
* Check whether the GPU supports 8 bit buffers.
518+
*/
519+
inline bool int8_buffers_enabled() const {
520+
return context_->adapter_ptr()->has_full_int8_buffers_support();
521+
}
522+
503523
//
504524
// Debug support (implemented in Logging.cpp)
505525
//

backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ ivec4 from_nchw_buffer_i(int buf_i, ivec4 sizes) {
8080
* Returns: The (x, y, z, n) texel position corresponding to the first element
8181
* of the texel at the specified buffer index
8282
*/
83-
ivec4 to_texel_pos(int buf_i, ivec4 strides, int packed_dim) {
83+
ivec4 to_tensor_idx(int buf_i, ivec4 strides, int packed_dim) {
8484
ivec4 idx;
8585
for (int i = 3; i >= 0; i--) {
8686
if (i != packed_dim) {
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#include "indexing_utils.h"
14+
15+
layout(std430) buffer;
16+
17+
#extension GL_EXT_control_flow_attributes : require
18+
19+
${layout_declare_tensor(0, "r", "t_in", "int8", "texture3d")}
20+
${layout_declare_buffer(1, "w", "nchw_out", "int")}
21+
${layout_declare_ubo(2, "ivec4", "tensor_sizes")}
22+
23+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
24+
25+
layout(constant_id = 3) const int packed_dim = C_DIM;
26+
27+
void main() {
28+
const int out_buf_idx = int(gl_GlobalInvocationID.x);
29+
int in_buf_idx = 4 * out_buf_idx;
30+
31+
ivec4 values;
32+
33+
[[unroll]] for (int i = 0; i < 4; ++i) {
34+
const ivec4 tensor_idx = from_nchw_buffer_i(in_buf_idx, tensor_sizes);
35+
const ivec4 texture_pos = to_texture_elem_pos(
36+
tensor_idx, tensor_sizes, packed_dim);
37+
values[i] = load_texel(t_in, texture_pos.xyz)[texture_pos.w];
38+
in_buf_idx++;
39+
}
40+
41+
// Manually pack 4x 8-bit integers into a 32 bit integer. Note that little
42+
// endian is assumed, since most processors use little endian. Thus the
43+
// "later" values are placed in most significant bytes.
44+
int packed = ((values[3] & 0xFF) << 24)
45+
| ((values[2] & 0xFF) << 16)
46+
| ((values[1] & 0xFF) << 8)
47+
| ((values[0] & 0xFF));
48+
49+
nchw_out[out_buf_idx] = packed;
50+
}
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#include "indexing_utils.h"
14+
15+
layout(std430) buffer;
16+
17+
#extension GL_EXT_control_flow_attributes : require
18+
19+
${layout_declare_tensor(0, "w", "t_out", "int8", "texture3d")}
20+
${layout_declare_buffer(1, "r", "nchw_in", "int")}
21+
${layout_declare_ubo(2, "ivec4", "tensor_sizes")}
22+
23+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
24+
25+
layout(constant_id = 3) const int packed_dim = C_DIM;
26+
27+
/*
28+
* Extends sign of int8
29+
*/
30+
int extend_sign(int x) {
31+
if (x >> 7 == 1) {
32+
return x | 0xFFFFFF00;
33+
}
34+
return x;
35+
}
36+
37+
ivec4 read_texel(ivec4 tensor_idx) {
38+
const ivec4 buf_indices = get_texel_nchw_buffer_ixs(
39+
tensor_idx, tensor_sizes, packed_dim);
40+
41+
int shift = (1 << 8) - 1;
42+
ivec4 masks;
43+
// Masks used to unpack 4x 8-bit values from a 32 bit integer. Note that
44+
// little endian is assumed, as most processors use little endian. Thus the
45+
// most significant bytes correspond to the "latter" packed values.
46+
masks.x = shift << (8 * (buf_indices.x % 4));
47+
masks.y = shift << (8 * (buf_indices.y % 4));
48+
masks.z = shift << (8 * (buf_indices.z % 4));
49+
masks.w = shift << (8 * (buf_indices.w % 4));
50+
51+
ivec4 out_tex = ivec4(0);
52+
53+
[[unroll]] for (int i = 0; i < 4; ++i) {
54+
if (tensor_idx[packed_dim] + i < tensor_sizes[packed_dim]) {
55+
int in_texel = nchw_in[buf_indices[i] / 4];
56+
int extracted_val = (in_texel & masks[i]) >> (8 * (buf_indices[i] % 4));
57+
extracted_val = extend_sign(extracted_val);
58+
out_tex[i] = extracted_val;
59+
}
60+
}
61+
62+
return out_tex;
63+
}
64+
65+
void main() {
66+
const ivec3 pos = ivec3(gl_GlobalInvocationID);
67+
const ivec4 tensor_idx = to_tensor_idx(pos, tensor_sizes, packed_dim);
68+
69+
if (any(greaterThanEqual(tensor_idx, tensor_sizes))) {
70+
return;
71+
}
72+
73+
write_texel(t_out, pos, read_texel(tensor_idx));
74+
}

backends/vulkan/runtime/graph/ops/glsl/nchw_to_tensor.glsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ void main() {
6262
return;
6363
}
6464

65-
ivec4 tensor_idx = to_texel_pos(t_id, gpu_strides, packed_dim);
65+
ivec4 tensor_idx = to_tensor_idx(t_id, gpu_strides, packed_dim);
6666
tensor_idx[packed_dim] *= 4;
6767
t_out[t_id] = read_texel(tensor_idx);
6868
}

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ void main() {
5353
return;
5454
}
5555

56-
const ivec4 out_pos = to_texel_pos(t_id, out_strides, 0);
56+
const ivec4 out_pos = to_tensor_idx(t_id, out_strides, 0);
5757

5858
VEC4_T outtex = q_8w_linear(out_pos, mat1_sizes.x);
5959
write_texel(t_out, t_id, outtex);

backends/vulkan/runtime/graph/ops/glsl/tensor_to_nchw.glsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ void main() {
6161
}
6262

6363
const VEC4_T intex = t_in[t_id];
64-
ivec4 tensor_idx = to_texel_pos(t_id, gpu_strides, packed_dim);
64+
ivec4 tensor_idx = to_tensor_idx(t_id, gpu_strides, packed_dim);
6565
tensor_idx[packed_dim] *= 4;
6666
write_out_texel(intex, tensor_idx);
6767
}

backends/vulkan/runtime/graph/ops/impl/Staging.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@ void add_staging_to_tensor_node(
2424
vkapi::ShaderInfo shader =
2525
get_nchw_to_tensor_shader(*graph.get_tensor(out_tensor));
2626

27+
if (graph.dtype_of(out_tensor) == vkapi::kChar &&
28+
graph.storage_type_of(out_tensor) == utils::kTexture3D &&
29+
!graph.int8_buffers_enabled()) {
30+
shader = VK_KERNEL(nchw_to_int8_tensor_noint8);
31+
}
32+
2733
vkapi::ParamsBindList ubos({graph.sizes_ubo(out_tensor)});
2834
if (graph.is_buffer_storage(out_tensor)) {
2935
ubos.append({
@@ -58,6 +64,15 @@ void add_tensor_to_staging_node(
5864
vkapi::ShaderInfo shader =
5965
get_tensor_to_nchw_shader(*graph.get_tensor(in_tensor));
6066

67+
utils::uvec3 global_wg_size = graph.create_global_wg_size(in_tensor);
68+
69+
if (graph.dtype_of(in_tensor) == vkapi::kChar &&
70+
!graph.is_buffer_storage(in_tensor) && !graph.int8_buffers_enabled()) {
71+
shader = VK_KERNEL(int8_tensor_to_nchw_noint8);
72+
uint32_t buffer_len = graph.get_staging(out_staging)->numel() / 4;
73+
global_wg_size = {buffer_len, 1, 1};
74+
}
75+
6176
vkapi::ParamsBindList ubos({graph.sizes_ubo(in_tensor)});
6277
if (graph.is_buffer_storage(in_tensor)) {
6378
ubos.append({
@@ -69,8 +84,8 @@ void add_tensor_to_staging_node(
6984
graph.execute_nodes().emplace_back(new ExecuteNode(
7085
graph,
7186
shader,
72-
graph.create_global_wg_size(in_tensor),
73-
graph.create_local_wg_size(in_tensor),
87+
global_wg_size,
88+
graph.create_local_wg_size(global_wg_size),
7489
// Input and Outputs
7590
{{in_tensor, vkapi::MemoryAccessType::READ},
7691
{out_staging, vkapi::MemoryAccessType::WRITE}},

backends/vulkan/test/glsl/all_shaders.yaml

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -47,21 +47,12 @@ idx_fill_buffer:
4747
idx_fill_texture:
4848
parameter_names_with_default_values:
4949
DTYPE: float
50-
NDIM: 3
51-
PACKING: CHANNELS_PACKED
5250
generate_variant_forall:
53-
PACKING:
54-
- VALUE: "CHANNELS_PACKED"
55-
SUFFIX: "C_packed"
56-
- VALUE: "WIDTH_PACKED"
57-
SUFFIX: "W_packed"
58-
- VALUE: "HEIGHT_PACKED"
59-
SUFFIX: "H_packed"
6051
DTYPE:
61-
- VALUE: "half"
62-
SUFFIX: "half"
63-
- VALUE: "float"
64-
SUFFIX: "float"
52+
- VALUE: half
53+
- VALUE: float
54+
- VALUE: int
55+
- VALUE: int8
6556
shader_variants:
6657
- NAME: idx_fill_texture
6758

backends/vulkan/test/glsl/idx_fill_texture.glsl

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,17 @@
1212

1313
#define VEC4_T ${texel_type(DTYPE)}
1414

15-
#define POS ${get_pos[NDIM]("pos")}
16-
1715
#include "indexing_utils.h"
1816

1917
layout(std430) buffer;
2018

21-
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
22-
23-
layout(set = 0, binding = 1) uniform PRECISION restrict Sizes {
24-
ivec4 sizes;
25-
};
19+
${layout_declare_tensor(0, "w", "image_out", DTYPE, "texture3d")}
20+
${layout_declare_ubo(1, "ivec4", "sizes")}
2621

2722
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
2823

2924
layout(constant_id = 3) const int packed_dim = C_DIM;
25+
layout(constant_id = 4) const int offset = 10;
3026

3127
void main() {
3228
const ivec3 pos = ivec3(gl_GlobalInvocationID);
@@ -37,6 +33,6 @@ void main() {
3733
}
3834

3935
const ivec4 buf_indices = get_texel_nchw_buffer_ixs(idx, sizes, packed_dim);
40-
VEC4_T texel = VEC4_T(buf_indices);
41-
imageStore(image_out, POS, texel);
36+
VEC4_T texel = VEC4_T(buf_indices) + offset;
37+
imageStore(image_out, pos, texel);
4238
}

0 commit comments

Comments
 (0)