Skip to content

Commit 7c8ccc5

Browse files
committed
Update on "[ET-VK][Ops] aten.convolution (Bias=False)"
The final touches to get ET-VK convolution on-par with ATen-VK's convolution. ## Idea In our shaders, we add the bias to our sum. ``` ${VEC4_T[DTYPE]} sum = texelFetch(bias_in, ivec2(pos.z, 0), 0); ``` To keep our shaders as is, we implement having no bias by allocating a buffer of zeros. Then, our shader adds zero to our sum. ## Issue If `Bias=False`, dummy buffer of zeros is not serialized with the graph. The bias ValueRef is deserialized in the runtime as `TypeTag::NONE`, not `TypeTag::TENSORREF`. ## Solution If `TypeTag::NONE` is given, (1) create the `vTensor` using the `out_channels` value from the weights and (2) allocate a StagingBuffer of that size. The StagingBuffer will be transferred to GPU memory and initialized to zeros. Differential Revision: [D55814589](https://our.internmc.facebook.com/intern/diff/D55814589/) [ghstack-poisoned]
2 parents ca686ae + a85f33e commit 7c8ccc5

File tree

3 files changed

+6
-7
lines changed

3 files changed

+6
-7
lines changed

backends/vulkan/runtime/graph/ops/PrepackNode.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ api::StorageBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) {
4141
size_t numel = api::utils::multiply_integers(packed.sizes());
4242
api::StorageBuffer staging(graph->context(), packed.dtype(), numel);
4343
size_t nbytes = numel * api::element_size(packed.dtype());
44-
copy_zeros_to_staging(staging, nbytes);
44+
set_staging_zeros(staging, nbytes);
4545
return staging;
4646
}
4747

backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,10 @@ void copy_staging_to_ptr(
8989
memcpy_from_mapping(mapping, dst, nbytes, staging.dtype());
9090
}
9191

92-
void copy_zeros_to_staging(api::StorageBuffer& staging, const size_t nbytes) {
93-
void* data = malloc(nbytes);
94-
memset(data, 0, nbytes);
95-
copy_ptr_to_staging(data, staging, nbytes);
96-
free(data);
92+
void set_staging_zeros(api::StorageBuffer& staging, const size_t nbytes) {
93+
api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE);
94+
uint8_t* data_ptr = mapping.template data<uint8_t>();
95+
memset(data_ptr, 0, staging.nbytes());
9796
}
9897

9998
api::ShaderInfo get_nchw_to_image_shader(const vTensor& v_dst) {

backends/vulkan/runtime/graph/ops/utils/StagingUtils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ void copy_staging_to_ptr(
2525
void* dst,
2626
const size_t nbytes);
2727

28-
void copy_zeros_to_staging(api::StorageBuffer& staging, const size_t nbytes);
28+
void set_staging_zeros(api::StorageBuffer& staging, const size_t nbytes);
2929

3030
//
3131
// Functions to get shaders

0 commit comments

Comments
 (0)