[8/n][ET-VK] Support staging any 8-bit texture #5934

Closed · wants to merge 7 commits
@@ -10,14 +10,16 @@

 #define PRECISION ${PRECISION}

+${define_active_storage_type(STORAGE)}
+
 #include "indexing_utils.h"

 layout(std430) buffer;

 #extension GL_EXT_control_flow_attributes : require

 ${layout_declare_buffer(B, "w", "nchw_out", "int")}
-${layout_declare_tensor(B, "r", "t_in", "int8", "texture3d")}
+${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
 ${layout_declare_ubo(B, "ivec4", "tensor_sizes")}
 ${layout_declare_ubo(B, "ivec4", "axis_map")}
 ${layout_declare_ubo(B, "int", "out_numel")}
@@ -44,7 +46,7 @@ void main() {
     const ivec4 tidx = nchwi_to_tidx(in_buf_idx, tensor_sizes);
     const ivec4 texture_pos = to_texture_elem_pos(
         tidx, tensor_sizes, packed_dim);
-    values[i] = load_texel(t_in, texture_pos.xyz)[texture_pos.w];
+    values[i] = ivec4(load_texel(t_in, texture_pos.xyz))[texture_pos.w];
     in_buf_idx++;
   }
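For orientation, each invocation of this image-to-NCHW template gathers four 8-bit texel components in NCHW order and packs them into one 32-bit word of the `nchw_out` buffer, which is why the dispatch logic further down divides the staging numel by 4. A minimal CPU-side sketch of that packing, with made-up names and a little-endian lane order assumed, not code from this PR:

```cpp
#include <cstdint>

// Pack four 8-bit values into one 32-bit word (lane i occupies byte i).
int32_t pack_four_int8(const int8_t values[4]) {
  uint32_t packed = 0;
  for (int i = 0; i < 4; ++i) {
    // Mask each byte so sign extension does not spill into other lanes.
    packed |= (static_cast<uint32_t>(values[i]) & 0xFFu) << (8 * i);
  }
  return static_cast<int32_t>(packed);
}
```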

@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+bitw8_image_to_nchw_nobitw8buffer:
+  parameter_names_with_default_values:
+    STORAGE: texture3d
+    DTYPE: int8
+  generate_variant_forall:
+    DTYPE:
+      - VALUE: int8
+      - VALUE: uint8
+    STORAGE:
+      - VALUE: texture2d
+      - VALUE: texture3d
+  shader_variants:
+    - NAME: bitw8_image_to_nchw_nobitw8buffer
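Presumably `generate_variant_forall` expands the template above into one compiled shader per (DTYPE, STORAGE) pair, and the runtime selects a variant by suffixing the base name (see `StagingUtils.cpp` further down). A rough sketch of the resulting lookup string, with suffix spellings assumed rather than taken from this PR:

```cpp
#include <string>

// Sketch only: the real suffixes come from add_dtype_suffix /
// add_storage_type_suffix in the runtime.
std::string variant_name(const std::string& dtype, const std::string& storage) {
  return "bitw8_image_to_nchw_nobitw8buffer_" + dtype + "_" + storage;
}

// e.g. variant_name("uint8", "texture2d")
//   -> "bitw8_image_to_nchw_nobitw8buffer_uint8_texture2d"
```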
@@ -10,13 +10,17 @@

 #define PRECISION ${PRECISION}

+#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
+
+${define_active_storage_type(STORAGE)}
+
 #include "indexing_utils.h"

 layout(std430) buffer;

 #extension GL_EXT_control_flow_attributes : require

-${layout_declare_tensor(B, "w", "t_out", "int8", "texture3d")}
+${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
 ${layout_declare_buffer(B, "r", "nchw_in", "int")}
 ${layout_declare_ubo(B, "ivec4", "sizes")}
 ${layout_declare_ubo(B, "ivec4", "axis_map")}
@@ -71,5 +75,5 @@ void main() {
     return;
   }

-  write_texel(t_out, lpos_to_pos(lpos, axis_map), read_texel(tidx));
+  write_texel(t_out, lpos_to_pos(lpos, axis_map), VEC4_T(read_texel(tidx)));
 }
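Going the other direction, this NCHW-to-image template reads packed 32-bit words from `nchw_in`, recovers four 8-bit values, and writes them out as a texel of the target dtype, hence the `VEC4_T` cast above. A CPU-side sketch of the unpacking, under the same assumed byte order as the packing sketch earlier:

```cpp
#include <array>
#include <cstdint>

// Unpack one 32-bit word of the NCHW buffer back into four 8-bit values.
std::array<int8_t, 4> unpack_four_int8(int32_t packed) {
  std::array<int8_t, 4> out{};
  const uint32_t bits = static_cast<uint32_t>(packed);
  for (int i = 0; i < 4; ++i) {
    // Casting the low byte back to int8_t restores its sign.
    out[i] = static_cast<int8_t>((bits >> (8 * i)) & 0xFFu);
  }
  return out;
}
```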
@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+nchw_to_bitw8_image_nobitw8buffer:
+  parameter_names_with_default_values:
+    STORAGE: texture3d
+    DTYPE: int8
+  generate_variant_forall:
+    DTYPE:
+      - VALUE: int8
+      - VALUE: uint8
+    STORAGE:
+      - VALUE: texture2d
+      - VALUE: texture3d
+  shader_variants:
+    - NAME: nchw_to_bitw8_image_nobitw8buffer
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/ops/impl/Staging.cpp
@@ -80,7 +80,7 @@ void add_tensor_to_staging_node(
   // output buffer. Therefore, the global work group size for this shader will
   // be the number of elements in the output buffer divided by 4, as opposed to
   // the extents of the input texture.
-  if (shader.kernel_name == "int8_image_to_nchw_noint8") {
+  if (shader.kernel_name.starts_with("bitw8_image_to_nchw_nobitw8buffer")) {
     uint32_t buffer_len = graph.get_staging(out_staging)->numel() / 4;
     global_wg_size = {buffer_len, 1, 1};
     ubos.append({graph.numel_ubo(in_tensor)});
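As a worked example of the comment above (hypothetical sizes, not from the PR): a staging buffer holding 64 8-bit elements is backed by 64 / 4 = 16 packed 32-bit words, so 16 invocations are dispatched rather than one per texel of the input texture:

```cpp
#include <array>
#include <cstdint>

int main() {
  const uint32_t staging_numel = 64;              // 8-bit elements in staging
  const uint32_t buffer_len = staging_numel / 4;  // one invocation per int32
  const std::array<uint32_t, 3> global_wg_size = {buffer_len, 1u, 1u};
  (void)global_wg_size;  // {16, 1, 1}
  return 0;
}
```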
23 changes: 17 additions & 6 deletions backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp
@@ -15,15 +15,23 @@

 namespace vkcompute {

+bool is_bitw8(vkapi::ScalarType dtype) {
+  return dtype == vkapi::kByte || dtype == vkapi::kChar ||
+      dtype == vkapi::kQInt8 || dtype == vkapi::kQUInt8;
+}
+
 vkapi::ShaderInfo get_nchw_to_tensor_shader(
     const api::vTensor& v_dst,
     const bool int8_buffer_enabled) {
   std::string kernel_name;
   kernel_name.reserve(kShaderNameReserve);

-  if (v_dst.dtype() == vkapi::kChar &&
-      v_dst.storage_type() == utils::kTexture3D && !int8_buffer_enabled) {
-    return VK_KERNEL(nchw_to_int8_image_noint8);
+  if (is_bitw8(v_dst.dtype()) && v_dst.storage_type() != utils::kBuffer &&
+      !int8_buffer_enabled) {
+    kernel_name = "nchw_to_bitw8_image_nobitw8buffer";
+    add_dtype_suffix(kernel_name, v_dst);
+    add_storage_type_suffix(kernel_name, v_dst);
+    return VK_KERNEL_FROM_STR(kernel_name);
   }

   if (v_dst.storage_type() == utils::kBuffer) {
@@ -45,9 +53,12 @@ vkapi::ShaderInfo get_tensor_to_nchw_shader(
   std::string kernel_name;
   kernel_name.reserve(kShaderNameReserve);

-  if (v_src.dtype() == vkapi::kChar &&
-      v_src.storage_type() == utils::kTexture3D && !int8_buffer_enabled) {
-    return VK_KERNEL(int8_image_to_nchw_noint8);
+  if (is_bitw8(v_src.dtype()) && v_src.storage_type() != utils::kBuffer &&
+      !int8_buffer_enabled) {
+    kernel_name = "bitw8_image_to_nchw_nobitw8buffer";
+    add_dtype_suffix(kernel_name, v_src);
+    add_storage_type_suffix(kernel_name, v_src);
+    return VK_KERNEL_FROM_STR(kernel_name);
   }

   if (v_src.storage_type() == utils::kBuffer) {
9 changes: 7 additions & 2 deletions backends/vulkan/test/utils/test_utils.cpp
@@ -111,15 +111,20 @@ void record_image_to_nchw_op(
       v_src.axis_map_ubo());
 }

-void record_int8_image_to_nchw_noint8_op(
+void record_bitw8_image_to_nchw_nobitw8buffer_op(
     api::Context* const context,
     api::vTensor& v_src,
     api::StagingBuffer& dst_buffer) {
   vkapi::PipelineBarrier pipeline_barrier{};
   uint32_t buffer_len = utils::safe_downcast<uint32_t>(dst_buffer.numel() / 4);
   utils::uvec3 global_wg_size = {buffer_len, 1, 1};
+
+  std::string kernel_name = "bitw8_image_to_nchw_nobitw8buffer";
+  add_dtype_suffix(kernel_name, v_src);
+  add_storage_type_suffix(kernel_name, v_src);
+
   context->submit_compute_job(
-      VK_KERNEL(int8_image_to_nchw_noint8),
+      VK_KERNEL_FROM_STR(kernel_name),
       pipeline_barrier,
       global_wg_size,
       adaptive_work_group_size(global_wg_size),
2 changes: 1 addition & 1 deletion backends/vulkan/test/utils/test_utils.h
@@ -84,7 +84,7 @@ void record_image_to_nchw_op(
     vkcompute::api::vTensor& v_src,
     vkcompute::vkapi::VulkanBuffer& dst_buffer);

-void record_int8_image_to_nchw_noint8_op(
+void record_bitw8_image_to_nchw_nobitw8buffer_op(
     vkcompute::api::Context* const context,
     vkcompute::api::vTensor& v_src,
     vkcompute::api::StagingBuffer& dst_buffer);
6 changes: 4 additions & 2 deletions backends/vulkan/test/vulkan_compute_api_test.cpp
@@ -2365,7 +2365,8 @@ void run_from_gpu_test(

   if (dtype == vkapi::kChar &&
       !context()->adapter_ptr()->has_full_int8_buffers_support()) {
-    record_int8_image_to_nchw_noint8_op(context(), vten, staging_buffer);
+    record_bitw8_image_to_nchw_nobitw8buffer_op(
+        context(), vten, staging_buffer);
   } else {
     record_image_to_nchw_op(context(), vten, staging_buffer.buffer());
   }
@@ -2412,7 +2413,8 @@
   // Copy data in and out of the tensor
   if (dtype == vkapi::kChar &&
       !context()->adapter_ptr()->has_full_int8_buffers_support()) {
-    record_int8_image_to_nchw_noint8_op(context(), vten, staging_buffer_out);
+    record_bitw8_image_to_nchw_nobitw8buffer_op(
+        context(), vten, staging_buffer_out);
   } else {
     record_image_to_nchw_op(context(), vten, staging_buffer_out.buffer());
   }