Skip to content

Commit 6f9ac70

Browse files
committed
Update on "[ET-VK] Add convolution cases to codegen"
TSIA Differential Revision: [D55829466](https://our.internmc.facebook.com/intern/diff/D55829466/) [ghstack-poisoned]
2 parents 965fa2b + 5350ede commit 6f9ac70

File tree

3 files changed

+6
-7
lines changed

3 files changed

+6
-7
lines changed

backends/vulkan/runtime/graph/ops/PrepackNode.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ api::StorageBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) {
4141
size_t numel = api::utils::multiply_integers(packed.sizes());
4242
api::StorageBuffer staging(graph->context(), packed.dtype(), numel);
4343
size_t nbytes = numel * api::element_size(packed.dtype());
44-
copy_zeros_to_staging(staging, nbytes);
44+
set_staging_zeros(staging, nbytes);
4545
return staging;
4646
}
4747

backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,10 @@ void copy_staging_to_ptr(
8989
memcpy_from_mapping(mapping, dst, nbytes, staging.dtype());
9090
}
9191

92-
void copy_zeros_to_staging(api::StorageBuffer& staging, const size_t nbytes) {
93-
void* data = malloc(nbytes);
94-
memset(data, 0, nbytes);
95-
copy_ptr_to_staging(data, staging, nbytes);
96-
free(data);
92+
void set_staging_zeros(api::StorageBuffer& staging, const size_t nbytes) {
93+
api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE);
94+
uint8_t* data_ptr = mapping.template data<uint8_t>();
95+
memset(data_ptr, 0, staging.nbytes());
9796
}
9897

9998
api::ShaderInfo get_nchw_to_image_shader(const vTensor& v_dst) {

backends/vulkan/runtime/graph/ops/utils/StagingUtils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ void copy_staging_to_ptr(
2525
void* dst,
2626
const size_t nbytes);
2727

28-
void copy_zeros_to_staging(api::StorageBuffer& staging, const size_t nbytes);
28+
void set_staging_zeros(api::StorageBuffer& staging, const size_t nbytes);
2929

3030
//
3131
// Functions to get shaders

0 commit comments

Comments
 (0)