Skip to content

Commit 53ce2f1

Browse files
committed
Update on "[8/n][ET-VK] Support staging any 8-bit texture"
"bitw8" = bit width 8, which is equivalent to 8-bit. We use "bitw8" as the name since shader compilation disallows names starting with a digit. Changes follow from #4485 to support `texture2d` and support `uint8`, respectively. Differential Revision: [D63918659](https://our.internmc.facebook.com/intern/diff/D63918659/) [ghstack-poisoned]
2 parents 7c182ea + 7a9251e commit 53ce2f1

File tree

7 files changed

+10
-11
lines changed

7 files changed

+10
-11
lines changed

backends/vulkan/runtime/graph/ops/glsl/int8_image_to_nchw_noint8.yaml renamed to backends/vulkan/runtime/graph/ops/glsl/bitw8_image_to_nchw_nobitw8buffer.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
int8_image_to_nchw_noint8:
7+
bitw8_image_to_nchw_nobitw8buffer:
88
parameter_names_with_default_values:
99
STORAGE: texture3d
1010
DTYPE: int8
@@ -16,4 +16,4 @@ int8_image_to_nchw_noint8:
1616
- VALUE: texture2d
1717
- VALUE: texture3d
1818
shader_variants:
19-
- NAME: int8_image_to_nchw_noint8
19+
- NAME: bitw8_image_to_nchw_nobitw8buffer

backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8_image_noint8.yaml renamed to backends/vulkan/runtime/graph/ops/glsl/nchw_to_bitw8_image_nobitw8buffer.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
nchw_to_int8_image_noint8:
7+
nchw_to_bitw8_image_nobitw8buffer:
88
parameter_names_with_default_values:
99
STORAGE: texture3d
1010
DTYPE: int8
@@ -16,4 +16,4 @@ nchw_to_int8_image_noint8:
1616
- VALUE: texture2d
1717
- VALUE: texture3d
1818
shader_variants:
19-
- NAME: nchw_to_int8_image_noint8
19+
- NAME: nchw_to_bitw8_image_nobitw8buffer

backends/vulkan/runtime/graph/ops/impl/Copy.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,6 @@ void add_copy_channel_offset_node(
116116
std::string kernel_name = "copy_channel_offset";
117117
kernel_name.reserve(kShaderNameReserve);
118118
add_dtype_suffix(kernel_name, *t_out);
119-
add_storage_type_suffix(kernel_name, *t_out);
120119

121120
int32_t out_channels = dim_at<kChannel4D>(out_sizes);
122121

backends/vulkan/runtime/graph/ops/impl/Staging.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ void add_tensor_to_staging_node(
8080
// output buffer. Therefore, the global work group size for this shader will
8181
// be the number of elements in the output buffer divided by 4, as opposed to
8282
// the extents of the input texture.
83-
if (shader.kernel_name.starts_with("int8_image_to_nchw_noint8")) {
83+
if (shader.kernel_name.starts_with("bitw8_image_to_nchw_nobitw8buffer")) {
8484
uint32_t buffer_len = graph.get_staging(out_staging)->numel() / 4;
8585
global_wg_size = {buffer_len, 1, 1};
8686
ubos.append({graph.numel_ubo(in_tensor)});

backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
namespace vkcompute {
1717

18-
bool is_8bit(vkapi::ScalarType dtype) {
18+
bool is_bitw8(vkapi::ScalarType dtype) {
1919
return dtype == vkapi::kByte || dtype == vkapi::kChar ||
2020
dtype == vkapi::kQInt8 || dtype == vkapi::kQUInt8;
2121
}
@@ -26,9 +26,9 @@ vkapi::ShaderInfo get_nchw_to_tensor_shader(
2626
std::string kernel_name;
2727
kernel_name.reserve(kShaderNameReserve);
2828

29-
if (is_8bit(v_dst.dtype()) && v_dst.storage_type() != utils::kBuffer &&
29+
if (is_bitw8(v_dst.dtype()) && v_dst.storage_type() != utils::kBuffer &&
3030
!int8_buffer_enabled) {
31-
kernel_name = "nchw_to_int8_image_noint8";
31+
kernel_name = "nchw_to_bitw8_image_nobitw8buffer";
3232
add_dtype_suffix(kernel_name, v_dst);
3333
add_storage_type_suffix(kernel_name, v_dst);
3434
return VK_KERNEL_FROM_STR(kernel_name);
@@ -53,9 +53,9 @@ vkapi::ShaderInfo get_tensor_to_nchw_shader(
5353
std::string kernel_name;
5454
kernel_name.reserve(kShaderNameReserve);
5555

56-
if (is_8bit(v_src.dtype()) && v_src.storage_type() != utils::kBuffer &&
56+
if (is_bitw8(v_src.dtype()) && v_src.storage_type() != utils::kBuffer &&
5757
!int8_buffer_enabled) {
58-
kernel_name = "int8_image_to_nchw_noint8";
58+
kernel_name = "bitw8_image_to_nchw_nobitw8buffer";
5959
add_dtype_suffix(kernel_name, v_src);
6060
add_storage_type_suffix(kernel_name, v_src);
6161
return VK_KERNEL_FROM_STR(kernel_name);

0 commit comments

Comments
 (0)