Skip to content

Commit f0315db

Browse files
committed
Update on "[8/n][ET-VK] Support staging any 8-bit texture"
"bitw8" = bit width 8, which is equivalent to 8-bit. We use "bitw8" as the name since shader compilation disallows names starting with a digit. Changes follow from #4485 to support `texture2d` and support `uint8`, respectively. Differential Revision: [D63918659](https://our.internmc.facebook.com/intern/diff/D63918659/) [ghstack-poisoned]
2 parents 458b373 + 3854e9e commit f0315db

File tree

3 files changed

+10
-5
lines changed

3 files changed

+10
-5
lines changed

backends/vulkan/test/utils/test_utils.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,15 +111,20 @@ void record_image_to_nchw_op(
111111
v_src.axis_map_ubo());
112112
}
113113

114-
void record_int8_image_to_nchw_noint8_op(
114+
void record_bitw8_image_to_nchw_nobitw8buffer_op(
115115
api::Context* const context,
116116
api::vTensor& v_src,
117117
api::StagingBuffer& dst_buffer) {
118118
vkapi::PipelineBarrier pipeline_barrier{};
119119
uint32_t buffer_len = utils::safe_downcast<uint32_t>(dst_buffer.numel() / 4);
120120
utils::uvec3 global_wg_size = {buffer_len, 1, 1};
121+
122+
std::string kernel_name = "bitw8_image_to_nchw_nobitw8buffer";
123+
add_dtype_suffix(kernel_name, v_src);
124+
add_storage_type_suffix(kernel_name, v_src);
125+
121126
context->submit_compute_job(
122-
VK_KERNEL(int8_image_to_nchw_noint8),
127+
VK_KERNEL_FROM_STR(kernel_name),
123128
pipeline_barrier,
124129
global_wg_size,
125130
adaptive_work_group_size(global_wg_size),

backends/vulkan/test/utils/test_utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ void record_image_to_nchw_op(
8484
vkcompute::api::vTensor& v_src,
8585
vkcompute::vkapi::VulkanBuffer& dst_buffer);
8686

87-
void record_int8_image_to_nchw_noint8_op(
87+
void record_bitw8_image_to_nchw_nobitw8buffer_op(
8888
vkcompute::api::Context* const context,
8989
vkcompute::api::vTensor& v_src,
9090
vkcompute::api::StagingBuffer& dst_buffer);

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2365,7 +2365,7 @@ void run_from_gpu_test(
23652365

23662366
if (dtype == vkapi::kChar &&
23672367
!context()->adapter_ptr()->has_full_int8_buffers_support()) {
2368-
record_int8_image_to_nchw_noint8_op(context(), vten, staging_buffer);
2368+
record_bitw8_image_to_nchw_nobitw8buffer_op(context(), vten, staging_buffer);
23692369
} else {
23702370
record_image_to_nchw_op(context(), vten, staging_buffer.buffer());
23712371
}
@@ -2412,7 +2412,7 @@ void round_trip_test(
24122412
// Copy data in and out of the tensor
24132413
if (dtype == vkapi::kChar &&
24142414
!context()->adapter_ptr()->has_full_int8_buffers_support()) {
2415-
record_int8_image_to_nchw_noint8_op(context(), vten, staging_buffer_out);
2415+
record_bitw8_image_to_nchw_nobitw8buffer_op(context(), vten, staging_buffer_out);
24162416
} else {
24172417
record_image_to_nchw_op(context(), vten, staging_buffer_out.buffer());
24182418
}

0 commit comments

Comments
 (0)