Skip to content

Commit 952f3cd

Browse files
committed
[8/n][ET-VK] Support staging any 8-bit texture
Pull Request resolved: #5934. Changes following from #4485 so that the 8-bit texture staging shaders support `texture2d` storage and the `uint8` dtype. ghstack-source-id: 246658454 @exported-using-ghexport Differential Revision: [D63918659](https://our.internmc.facebook.com/intern/diff/D63918659/)
1 parent f1bdb50 commit 952f3cd

File tree

6 files changed

+66
-11
lines changed

6 files changed

+66
-11
lines changed

backends/vulkan/runtime/graph/ops/glsl/int8_image_to_nchw_noint8.glsl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,16 @@
1010

1111
#define PRECISION ${PRECISION}
1212

13+
${define_active_storage_type(STORAGE)}
14+
1315
#include "indexing_utils.h"
1416

1517
layout(std430) buffer;
1618

1719
#extension GL_EXT_control_flow_attributes : require
1820

1921
${layout_declare_buffer(B, "w", "nchw_out", "int")}
20-
${layout_declare_tensor(B, "r", "t_in", "int8", "texture3d")}
22+
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
2123
${layout_declare_ubo(B, "ivec4", "tensor_sizes")}
2224
${layout_declare_ubo(B, "ivec4", "axis_map")}
2325
${layout_declare_ubo(B, "int", "out_numel")}
@@ -44,7 +46,7 @@ void main() {
4446
const ivec4 tidx = nchwi_to_tidx(in_buf_idx, tensor_sizes);
4547
const ivec4 texture_pos = to_texture_elem_pos(
4648
tidx, tensor_sizes, packed_dim);
47-
values[i] = load_texel(t_in, texture_pos.xyz)[texture_pos.w];
49+
values[i] = ivec4(load_texel(t_in, texture_pos.xyz))[texture_pos.w];
4850
in_buf_idx++;
4951
}
5052

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
int8_image_to_nchw_noint8:
8+
parameter_names_with_default_values:
9+
STORAGE: texture3d
10+
DTYPE: int8
11+
generate_variant_forall:
12+
DTYPE:
13+
- VALUE: int8
14+
- VALUE: uint8
15+
STORAGE:
16+
- VALUE: texture2d
17+
- VALUE: texture3d
18+
shader_variants:
19+
- NAME: int8_image_to_nchw_noint8

backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8_image_noint8.glsl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,17 @@
1010

1111
#define PRECISION ${PRECISION}
1212

13+
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
14+
15+
${define_active_storage_type(STORAGE)}
16+
1317
#include "indexing_utils.h"
1418

1519
layout(std430) buffer;
1620

1721
#extension GL_EXT_control_flow_attributes : require
1822

19-
${layout_declare_tensor(B, "w", "t_out", "int8", "texture3d")}
23+
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
2024
${layout_declare_buffer(B, "r", "nchw_in", "int")}
2125
${layout_declare_ubo(B, "ivec4", "sizes")}
2226
${layout_declare_ubo(B, "ivec4", "axis_map")}
@@ -71,5 +75,5 @@ void main() {
7175
return;
7276
}
7377

74-
write_texel(t_out, lpos_to_pos(lpos, axis_map), read_texel(tidx));
78+
write_texel(t_out, lpos_to_pos(lpos, axis_map), VEC4_T(read_texel(tidx)));
7579
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
nchw_to_int8_image_noint8:
8+
parameter_names_with_default_values:
9+
STORAGE: texture3d
10+
DTYPE: int8
11+
generate_variant_forall:
12+
DTYPE:
13+
- VALUE: int8
14+
- VALUE: uint8
15+
STORAGE:
16+
- VALUE: texture2d
17+
- VALUE: texture3d
18+
shader_variants:
19+
- NAME: nchw_to_int8_image_noint8

backends/vulkan/runtime/graph/ops/impl/Staging.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ void add_tensor_to_staging_node(
8080
// output buffer. Therefore, the global work group size for this shader will
8181
// be the number of elements in the output buffer divided by 4, as opposed to
8282
// the extents of the input texture.
83-
if (shader.kernel_name == "int8_image_to_nchw_noint8") {
83+
if (shader.kernel_name.starts_with("int8_image_to_nchw_noint8")) {
8484
uint32_t buffer_len = graph.get_staging(out_staging)->numel() / 4;
8585
global_wg_size = {buffer_len, 1, 1};
8686
ubos.append({graph.numel_ubo(in_tensor)});

backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,23 @@
1515

1616
namespace vkcompute {
1717

18+
bool is_8bit(vkapi::ScalarType dtype) {
19+
return dtype == vkapi::kByte || dtype == vkapi::kChar ||
20+
dtype == vkapi::kQInt8 || dtype == vkapi::kQUInt8;
21+
}
22+
1823
vkapi::ShaderInfo get_nchw_to_tensor_shader(
1924
const api::vTensor& v_dst,
2025
const bool int8_buffer_enabled) {
2126
std::string kernel_name;
2227
kernel_name.reserve(kShaderNameReserve);
2328

24-
if (v_dst.dtype() == vkapi::kChar &&
25-
v_dst.storage_type() == utils::kTexture3D && !int8_buffer_enabled) {
26-
return VK_KERNEL(nchw_to_int8_image_noint8);
29+
if (is_8bit(v_dst.dtype()) && v_dst.storage_type() != utils::kBuffer &&
30+
!int8_buffer_enabled) {
31+
kernel_name = "nchw_to_int8_image_noint8";
32+
add_dtype_suffix(kernel_name, v_dst);
33+
add_storage_type_suffix(kernel_name, v_dst);
34+
return VK_KERNEL_FROM_STR(kernel_name);
2735
}
2836

2937
if (v_dst.storage_type() == utils::kBuffer) {
@@ -45,9 +53,12 @@ vkapi::ShaderInfo get_tensor_to_nchw_shader(
4553
std::string kernel_name;
4654
kernel_name.reserve(kShaderNameReserve);
4755

48-
if (v_src.dtype() == vkapi::kChar &&
49-
v_src.storage_type() == utils::kTexture3D && !int8_buffer_enabled) {
50-
return VK_KERNEL(int8_image_to_nchw_noint8);
56+
if (is_8bit(v_src.dtype()) && v_src.storage_type() != utils::kBuffer &&
57+
!int8_buffer_enabled) {
58+
kernel_name = "int8_image_to_nchw_noint8";
59+
add_dtype_suffix(kernel_name, v_src);
60+
add_storage_type_suffix(kernel_name, v_src);
61+
return VK_KERNEL_FROM_STR(kernel_name);
5162
}
5263

5364
if (v_src.storage_type() == utils::kBuffer) {

0 commit comments

Comments
 (0)