Skip to content

Commit bf72506

Browse files
committed
Update on "[ET-VK][Ops] aten.convolution (Bias=False)"
The final touches to get ET-VK convolution on-par with ATen-VK's convolution. ## Idea In our shaders, we add the bias to our sum. ``` ${VEC4_T[DTYPE]} sum = texelFetch(bias_in, ivec2(pos.z, 0), 0); ``` To keep our shaders as is, we implement having no bias by allocating a buffer of zeros. Then, our shader adds zero to our sum. ## Issue If `Bias=False`, dummy buffer of zeros is not serialized with the graph. The bias ValueRef is deserialized in the runtime as `TypeTag::NONE`, not `TypeTag::TENSORREF`. ## Solution If `TypeTag::NONE` is given, (1) create the `vTensor` using the `out_channels` value from the weights and (2) allocate a StagingBuffer of that size. The StagingBuffer will be transferred to GPU memory and initialized to zeros. Differential Revision: [D55814589](https://our.internmc.facebook.com/intern/diff/D55814589/) [ghstack-poisoned]
1 parent 6ca821f commit bf72506

File tree

3 files changed

+12
-1
lines changed

3 files changed

+12
-1
lines changed

backends/vulkan/runtime/graph/ops/PrepackNode.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,13 @@ PrepackNode::PrepackNode(
3535
api::StorageBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) {
3636
vTensor& packed = graph->get_val(packed_).toTensor();
3737

38-
// If no TensorRef is provided, create a zeroed staging buffer according to
38+
// If no TensorRef is provided, create a staging buffer of zeros according to
3939
// the vTensor metadata.
4040
if (graph->get_val(tref_).isNone()) {
4141
size_t numel = api::utils::multiply_integers(packed.sizes());
4242
api::StorageBuffer staging(graph->context(), packed.dtype(), numel);
43+
size_t nbytes = numel * api::element_size(packed.dtype());
44+
copy_zeros_to_staging(staging, nbytes);
4345
return staging;
4446
}
4547

backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,13 @@ void copy_staging_to_ptr(
8989
memcpy_from_mapping(mapping, dst, nbytes, staging.dtype());
9090
}
9191

92+
void copy_zeros_to_staging(api::StorageBuffer& staging, const size_t nbytes) {
93+
void* data = malloc(nbytes);
94+
memset(data, 0, nbytes);
95+
copy_ptr_to_staging(data, staging, nbytes);
96+
free(data);
97+
}
98+
9299
api::ShaderInfo get_nchw_to_image_shader(const vTensor& v_dst) {
93100
if (v_dst.is_quantized()) {
94101
VK_THROW("Quantized Tensors are currently not supported!");

backends/vulkan/runtime/graph/ops/utils/StagingUtils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ void copy_staging_to_ptr(
2525
void* dst,
2626
const size_t nbytes);
2727

28+
void copy_zeros_to_staging(api::StorageBuffer& staging, const size_t nbytes);
29+
2830
//
2931
// Functions to get shaders
3032
//

0 commit comments

Comments
 (0)