Skip to content

Commit d3cd09c

Browse files
jorgep31415 authored and facebook-github-bot committed
Support staging any bitw8 image, take 2 (#6028)
Summary: Pull Request resolved: #6028. "bitw8" = bit width 8, which is equivalent to 8-bit. We use "bitw8" as the name since shader compilation disallows names starting with a digit. Changes follow from #4485 to support `texture2d` and support `uint8`, respectively. (Redo of #5934.) | ghstack-source-id: 247078906 | Reviewed By: SS-JIA | Differential Revision: D64076249 | fbshipit-source-id: 541cfddf92c55ebd4c6e39c6bd041fb5aa78b3b9
1 parent fb63da9 commit d3cd09c

9 files changed

+86
-16
lines changed

backends/vulkan/runtime/graph/ops/glsl/int8_image_to_nchw_noint8.glsl renamed to backends/vulkan/runtime/graph/ops/glsl/bitw8_image_to_nchw_nobitw8buffer.glsl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,16 @@
1010

1111
#define PRECISION ${PRECISION}
1212

13+
${define_active_storage_type(STORAGE)}
14+
1315
#include "indexing_utils.h"
1416

1517
layout(std430) buffer;
1618

1719
#extension GL_EXT_control_flow_attributes : require
1820

1921
${layout_declare_buffer(B, "w", "nchw_out", "int")}
20-
${layout_declare_tensor(B, "r", "t_in", "int8", "texture3d")}
22+
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
2123
${layout_declare_ubo(B, "ivec4", "tensor_sizes")}
2224
${layout_declare_ubo(B, "ivec4", "axis_map")}
2325
${layout_declare_ubo(B, "int", "out_numel")}
@@ -44,7 +46,7 @@ void main() {
4446
const ivec4 tidx = nchwi_to_tidx(in_buf_idx, tensor_sizes);
4547
const ivec4 texture_pos = to_texture_elem_pos(
4648
tidx, tensor_sizes, packed_dim);
47-
values[i] = load_texel(t_in, texture_pos.xyz)[texture_pos.w];
49+
values[i] = ivec4(load_texel(t_in, texture_pos.xyz))[texture_pos.w];
4850
in_buf_idx++;
4951
}
5052

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
bitw8_image_to_nchw_nobitw8buffer:
8+
parameter_names_with_default_values:
9+
STORAGE: texture3d
10+
DTYPE: int8
11+
generate_variant_forall:
12+
DTYPE:
13+
- VALUE: int8
14+
- VALUE: uint8
15+
STORAGE:
16+
- VALUE: texture2d
17+
- VALUE: texture3d
18+
shader_variants:
19+
- NAME: bitw8_image_to_nchw_nobitw8buffer

backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8_image_noint8.glsl renamed to backends/vulkan/runtime/graph/ops/glsl/nchw_to_bitw8_image_nobitw8buffer.glsl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,17 @@
1010

1111
#define PRECISION ${PRECISION}
1212

13+
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
14+
15+
${define_active_storage_type(STORAGE)}
16+
1317
#include "indexing_utils.h"
1418

1519
layout(std430) buffer;
1620

1721
#extension GL_EXT_control_flow_attributes : require
1822

19-
${layout_declare_tensor(B, "w", "t_out", "int8", "texture3d")}
23+
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
2024
${layout_declare_buffer(B, "r", "nchw_in", "int")}
2125
${layout_declare_ubo(B, "ivec4", "sizes")}
2226
${layout_declare_ubo(B, "ivec4", "axis_map")}
@@ -71,5 +75,5 @@ void main() {
7175
return;
7276
}
7377

74-
write_texel(t_out, lpos_to_pos(lpos, axis_map), read_texel(tidx));
78+
write_texel(t_out, lpos_to_pos(lpos, axis_map), VEC4_T(read_texel(tidx)));
7579
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
nchw_to_bitw8_image_nobitw8buffer:
8+
parameter_names_with_default_values:
9+
STORAGE: texture3d
10+
DTYPE: int8
11+
generate_variant_forall:
12+
DTYPE:
13+
- VALUE: int8
14+
- VALUE: uint8
15+
STORAGE:
16+
- VALUE: texture2d
17+
- VALUE: texture3d
18+
shader_variants:
19+
- NAME: nchw_to_bitw8_image_nobitw8buffer

backends/vulkan/runtime/graph/ops/impl/Staging.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,14 @@ void add_staging_to_tensor_node(
5151
{}));
5252
}
5353

54+
const std::string kBitw8PrefixStr = "bitw8_image_to_nchw_nobitw8buffer";
55+
56+
// Reports whether `shader` belongs to the
// "bitw8_image_to_nchw_nobitw8buffer" kernel family, i.e. whether
// `shader.kernel_name` begins with kBitw8PrefixStr.
//
// This is a prefix test rather than an equality test, so kernel names that
// carry extra trailing text after the base name are accepted as well.
//
// NOTE(review): despite the general-sounding name, only the image -> NCHW
// direction is recognized here; the "nchw_to_bitw8_image_nobitw8buffer"
// kernels do not match this prefix.
bool is_bitw8_shader(const vkapi::ShaderInfo& shader) {
  // compare(pos, len, str) clamps `len` to the characters actually available,
  // so a kernel name shorter than the prefix compares unequal instead of
  // throwing, and no temporary substring has to be allocated for the check.
  return shader.kernel_name.compare(
             0, kBitw8PrefixStr.size(), kBitw8PrefixStr) == 0;
}
61+
5462
void add_tensor_to_staging_node(
5563
ComputeGraph& graph,
5664
const ValueRef in_tensor,
@@ -80,7 +88,7 @@ void add_tensor_to_staging_node(
8088
// output buffer. Therefore, the global work group size for this shader will
8189
// be the number of elements in the output buffer divided by 4, as opposed to
8290
// the extents of the input texture.
83-
if (shader.kernel_name == "int8_image_to_nchw_noint8") {
91+
if (is_bitw8_shader(shader)) {
8492
uint32_t buffer_len = graph.get_staging(out_staging)->numel() / 4;
8593
global_wg_size = {buffer_len, 1, 1};
8694
ubos.append({graph.numel_ubo(in_tensor)});

backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,23 @@
1515

1616
namespace vkcompute {
1717

18+
// Returns true when `dtype` is one of the four scalar types this file treats
// as "bitw8" (bit width 8): kByte, kChar, kQInt8 or kQUInt8.
bool is_bitw8(vkapi::ScalarType dtype) {
  const bool plain_8bit = dtype == vkapi::kByte || dtype == vkapi::kChar;
  const bool quantized_8bit =
      dtype == vkapi::kQInt8 || dtype == vkapi::kQUInt8;
  return plain_8bit || quantized_8bit;
}
22+
1823
vkapi::ShaderInfo get_nchw_to_tensor_shader(
1924
const api::vTensor& v_dst,
2025
const bool int8_buffer_enabled) {
2126
std::string kernel_name;
2227
kernel_name.reserve(kShaderNameReserve);
2328

24-
if (v_dst.dtype() == vkapi::kChar &&
25-
v_dst.storage_type() == utils::kTexture3D && !int8_buffer_enabled) {
26-
return VK_KERNEL(nchw_to_int8_image_noint8);
29+
if (is_bitw8(v_dst.dtype()) && v_dst.storage_type() != utils::kBuffer &&
30+
!int8_buffer_enabled) {
31+
kernel_name = "nchw_to_bitw8_image_nobitw8buffer";
32+
add_dtype_suffix(kernel_name, v_dst);
33+
add_storage_type_suffix(kernel_name, v_dst);
34+
return VK_KERNEL_FROM_STR(kernel_name);
2735
}
2836

2937
if (v_dst.storage_type() == utils::kBuffer) {
@@ -45,9 +53,12 @@ vkapi::ShaderInfo get_tensor_to_nchw_shader(
4553
std::string kernel_name;
4654
kernel_name.reserve(kShaderNameReserve);
4755

48-
if (v_src.dtype() == vkapi::kChar &&
49-
v_src.storage_type() == utils::kTexture3D && !int8_buffer_enabled) {
50-
return VK_KERNEL(int8_image_to_nchw_noint8);
56+
if (is_bitw8(v_src.dtype()) && v_src.storage_type() != utils::kBuffer &&
57+
!int8_buffer_enabled) {
58+
kernel_name = "bitw8_image_to_nchw_nobitw8buffer";
59+
add_dtype_suffix(kernel_name, v_src);
60+
add_storage_type_suffix(kernel_name, v_src);
61+
return VK_KERNEL_FROM_STR(kernel_name);
5162
}
5263

5364
if (v_src.storage_type() == utils::kBuffer) {

backends/vulkan/test/utils/test_utils.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,15 +111,20 @@ void record_image_to_nchw_op(
111111
v_src.axis_map_ubo());
112112
}
113113

114-
void record_int8_image_to_nchw_noint8_op(
114+
void record_bitw8_image_to_nchw_nobitw8buffer_op(
115115
api::Context* const context,
116116
api::vTensor& v_src,
117117
api::StagingBuffer& dst_buffer) {
118118
vkapi::PipelineBarrier pipeline_barrier{};
119119
uint32_t buffer_len = utils::safe_downcast<uint32_t>(dst_buffer.numel() / 4);
120120
utils::uvec3 global_wg_size = {buffer_len, 1, 1};
121+
122+
std::string kernel_name = "bitw8_image_to_nchw_nobitw8buffer";
123+
add_dtype_suffix(kernel_name, v_src);
124+
add_storage_type_suffix(kernel_name, v_src);
125+
121126
context->submit_compute_job(
122-
VK_KERNEL(int8_image_to_nchw_noint8),
127+
VK_KERNEL_FROM_STR(kernel_name),
123128
pipeline_barrier,
124129
global_wg_size,
125130
adaptive_work_group_size(global_wg_size),

backends/vulkan/test/utils/test_utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ void record_image_to_nchw_op(
8484
vkcompute::api::vTensor& v_src,
8585
vkcompute::vkapi::VulkanBuffer& dst_buffer);
8686

87-
void record_int8_image_to_nchw_noint8_op(
87+
void record_bitw8_image_to_nchw_nobitw8buffer_op(
8888
vkcompute::api::Context* const context,
8989
vkcompute::api::vTensor& v_src,
9090
vkcompute::api::StagingBuffer& dst_buffer);

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2365,7 +2365,8 @@ void run_from_gpu_test(
23652365

23662366
if (dtype == vkapi::kChar &&
23672367
!context()->adapter_ptr()->has_full_int8_buffers_support()) {
2368-
record_int8_image_to_nchw_noint8_op(context(), vten, staging_buffer);
2368+
record_bitw8_image_to_nchw_nobitw8buffer_op(
2369+
context(), vten, staging_buffer);
23692370
} else {
23702371
record_image_to_nchw_op(context(), vten, staging_buffer.buffer());
23712372
}
@@ -2412,7 +2413,8 @@ void round_trip_test(
24122413
// Copy data in and out of the tensor
24132414
if (dtype == vkapi::kChar &&
24142415
!context()->adapter_ptr()->has_full_int8_buffers_support()) {
2415-
record_int8_image_to_nchw_noint8_op(context(), vten, staging_buffer_out);
2416+
record_bitw8_image_to_nchw_nobitw8buffer_op(
2417+
context(), vten, staging_buffer_out);
24162418
} else {
24172419
record_image_to_nchw_op(context(), vten, staging_buffer_out.buffer());
24182420
}

0 commit comments

Comments
 (0)