Skip to content

Commit 54a3c55

Browse files
committed
[ET-VK][Ops] aten.convolution (Bias=False)
Pull Request resolved: #2887

The final touches to get ET-VK convolution on par with ATen-VK's convolution.

## Idea

In our shaders, we add the bias to our sum.

```
${VEC4_T[DTYPE]} sum = texelFetch(bias_in, ivec2(pos.z, 0), 0);
```

To keep our shaders as is, we implement having no bias by allocating a buffer of zeros. Then, our shader adds zero to our sum.

## Issue

If `Bias=False`, the dummy buffer of zeros is not serialized with the graph. The bias ValueRef is deserialized in the runtime as `TypeTag::NONE`, not `TypeTag::TENSORREF`.

## Solution

If `TypeTag::NONE` is given, (1) create the `vTensor` using the `out_channels` value from the weights, (2) allocate a staging buffer (`api::StorageBuffer`) of that size, and (3) `memset` its data to zero. Failure to do (3) will result in undefined behavior.

ghstack-source-id: 221887675
@exported-using-ghexport

Differential Revision: [D55814589](https://our.internmc.facebook.com/intern/diff/D55814589/)
1 parent 2ea9b5c commit 54a3c55

File tree

6 files changed

+71
-12
lines changed

6 files changed

+71
-12
lines changed

backends/vulkan/runtime/graph/ops/PrepackNode.cpp

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,33 @@ PrepackNode::PrepackNode(
3232
graph.update_descriptor_counts(shader, /*execute = */ false);
3333
}
3434

35-
void PrepackNode::encode(ComputeGraph* graph) {
36-
api::Context* const context = graph->context();
37-
api::PipelineBarrier pipeline_barrier{};
38-
39-
TensorRef& tref = graph->get_val(tref_).toTensorRef();
35+
// Creates the staging buffer consumed by encode().
//
// - If `tref_` refers to a TensorRef, the buffer is sized from the TensorRef's
//   sizes/dtype and filled with the TensorRef's data.
// - If `tref_` is None (no data was serialized with the graph, e.g. a
//   convolution with no bias), the buffer is sized from the metadata of the
//   `packed_` vTensor and filled with zeros.
api::StorageBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) {
  // Looked up unconditionally so that a non-tensor `packed_` is rejected on
  // both paths.
  vTensor& packed = graph->get_val(packed_).toTensor();

  const bool has_tensor_ref = !graph->get_val(tref_).isNone();
  if (has_tensor_ref) {
    TensorRef& source = graph->get_val(tref_).toTensorRef();
    const size_t source_numel = api::utils::multiply_integers(source.sizes);
    api::StorageBuffer source_staging(
        graph->context(), source.dtype, source_numel);
    copy_ptr_to_staging(
        source.data,
        source_staging,
        source_numel * api::element_size(source.dtype));
    return source_staging;
  }

  // No TensorRef was provided: fall back to a zero-filled buffer described by
  // the vTensor metadata.
  const size_t zero_numel = api::utils::multiply_integers(packed.sizes());
  api::StorageBuffer zero_staging(
      graph->context(), packed.dtype(), zero_numel);
  copy_zeros_to_staging(
      zero_staging, zero_numel * api::element_size(packed.dtype()));
  return zero_staging;
}
55+
56+
void PrepackNode::encode(ComputeGraph* graph) {
57+
api::Context* const context = graph->context();
58+
api::PipelineBarrier pipeline_barrier{};
59+
60+
vTensor& packed = graph->get_val(packed_).toTensor();
61+
api::StorageBuffer staging = create_staging_buffer(graph);
4662

4763
std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock();
4864

backends/vulkan/runtime/graph/ops/PrepackNode.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ class PrepackNode final {
4747
const ValueRef packed_;
4848
// TODO(T180906457): allow re-computing param buffers.
4949
std::vector<std::shared_ptr<api::UniformParamsBuffer>> params_;
50+
51+
private:
52+
api::StorageBuffer create_staging_buffer(ComputeGraph* graph);
5053
};
5154

5255
} // namespace vkcompute

backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,17 @@ void resize_conv2d_node(
5252
out.virtual_resize(new_out_sizes);
5353
}
5454

55-
ValueRef prepack_biases(ComputeGraph& graph, const ValueRef vref) {
56-
if (graph.get_val(vref).isNone()) {
57-
VK_THROW("aten.convolution.default: Null bias is not supported yet!");
58-
}
55+
ValueRef prepack_biases(
56+
ComputeGraph& graph,
57+
const ValueRef vref,
58+
const ValueRef weight,
59+
const bool transposed) {
60+
TensorRef& tref = graph.get_val(weight).toTensorRef();
61+
const int64_t out_channels = transposed ? tref.sizes.at(1) : tref.sizes.at(0);
5962

60-
ValueRef v = graph.add_tensor_like(
61-
vref,
63+
ValueRef v = graph.add_tensor(
64+
{out_channels},
65+
tref.dtype,
6266
api::StorageType::TEXTURE_2D,
6367
api::GPUMemoryLayout::TENSOR_WIDTH_PACKED);
6468
vTensor& t = graph.get_val(v).toTensor();
@@ -301,7 +305,7 @@ void add_conv2d_node(
301305

302306
ValueRef arg_in = prepack_if_tensor_ref(graph, in);
303307
ValueRef arg_weight = prepack_weights(graph, weight, method);
304-
ValueRef arg_bias = prepack_biases(graph, bias);
308+
ValueRef arg_bias = prepack_biases(graph, bias, weight, transposed_val);
305309

306310
vTensor& t_in = graph.get_val(arg_in).toTensor();
307311
vTensor& t_out = graph.get_val(out).toTensor();

backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,13 @@ void copy_staging_to_ptr(
8989
memcpy_from_mapping(mapping, dst, nbytes, staging.dtype());
9090
}
9191

92+
// Writes `nbytes` bytes of zeros into `staging` by copying from a temporary
// zero-initialized host allocation via copy_ptr_to_staging().
//
// Throws (via VK_THROW) if the temporary allocation cannot be obtained,
// instead of handing a null pointer to the copy.
void copy_zeros_to_staging(api::StorageBuffer& staging, const size_t nbytes) {
  // calloc() returns zero-initialized memory, so no separate memset is needed.
  // Always request at least one byte: a zero-sized request may legally return
  // a null pointer, which would be indistinguishable from a real failure.
  void* const zeros = calloc(1, nbytes > 0 ? nbytes : 1);
  if (zeros == nullptr) {
    VK_THROW("copy_zeros_to_staging: failed to allocate zero buffer!");
  }

  // Releases the temporary allocation on every exit path, including the case
  // where copy_ptr_to_staging() throws.
  struct ScopedFree final {
    void* ptr;
    ~ScopedFree() {
      free(ptr);
    }
  } release_zeros{zeros};

  copy_ptr_to_staging(zeros, staging, nbytes);
}
98+
9299
api::ShaderInfo get_nchw_to_image_shader(const vTensor& v_dst) {
93100
if (v_dst.is_quantized()) {
94101
VK_THROW("Quantized Tensors are currently not supported!");

backends/vulkan/runtime/graph/ops/utils/StagingUtils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ void copy_staging_to_ptr(
2525
void* dst,
2626
const size_t nbytes);
2727

28+
// Writes `nbytes` bytes of zeros into `staging`. The zeros are copied in
// through copy_ptr_to_staging() from a temporary zero-filled host buffer.
void copy_zeros_to_staging(api::StorageBuffer& staging, const size_t nbytes);
29+
2830
//
2931
// Functions to get shaders
3032
//

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -601,3 +601,30 @@ def forward(self, x):
601601
sample_inputs,
602602
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
603603
)
604+
605+
def test_vulkan_backend_conv2d_bias_false(self):
    """Lower a Conv2d built with ``bias=False`` and check its output.

    Exercises the path where no bias tensor is serialized with the graph.
    """
    conv_kwargs = {
        "in_channels": 6,
        "out_channels": 8,
        "kernel_size": (3, 3),
        "padding": (2, 3),
        "stride": (1, 2),
        "dilation": 1,
        "groups": 1,
        "bias": False,
    }

    class Conv2dModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(**conv_kwargs)

        def forward(self, x):
            return self.conv(x)

    # Build the module before drawing the inputs so the random number
    # generator is consumed in the same order as before.
    module = Conv2dModule()
    inputs = (torch.randn(size=(1, 6, 40, 50), dtype=torch.float32),)
    layouts = [vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED]

    self.lower_module_and_test_output(
        module,
        inputs,
        memory_layouts=layouts,
    )

0 commit comments

Comments
 (0)