Skip to content

Commit 87eb155

Browse files
copyrightly authored and facebook-github-bot committed
conv1d with bias=False
Summary: Under the same setting as the last diff, we support `bias=False`.

Reviewed By: jorgep31415

Differential Revision: D56285842

fbshipit-source-id: 41636d19d2cd7db07ba924606c9cd33999cffdab
1 parent 7b1f10d commit 87eb155

File tree

3 files changed

+53
-12
lines changed

3 files changed

+53
-12
lines changed

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -81,15 +81,14 @@ ValueRef prepack_biases(
8181
ComputeGraph& graph,
8282
const ValueRef vref,
8383
const ValueRef weight,
84-
const bool transposed) {
84+
const bool transposed,
85+
const api::StorageType storage_type,
86+
const api::GPUMemoryLayout memory_layout) {
8587
auto sizes = graph.get_sizes_of(weight);
8688
const int64_t out_channels = transposed ? sizes.at(1) : sizes.at(0);
8789

8890
ValueRef v = graph.add_tensor(
89-
{out_channels},
90-
graph.get_dtype_of(weight),
91-
api::kTexture2D,
92-
api::kWidthPacked);
91+
{out_channels}, graph.get_dtype_of(weight), storage_type, memory_layout);
9392
vTensorPtr t = graph.get_tensor(v);
9493

9594
api::ShaderInfo shader = get_nchw_to_image_shader(*t);
@@ -329,7 +328,13 @@ void add_conv2d_node(
329328

330329
ValueRef arg_in = prepack_if_tensor_ref(graph, in);
331330
ValueRef arg_weight = prepack_weights(graph, weight, method);
332-
ValueRef arg_bias = prepack_biases(graph, bias, weight, transposed_val);
331+
ValueRef arg_bias = prepack_biases(
332+
graph,
333+
bias,
334+
weight,
335+
transposed_val,
336+
/* storage_type = */ api::kTexture2D,
337+
/* memory_layout = */ api::kWidthPacked);
333338

334339
vTensorPtr t_in = graph.get_tensor(arg_in);
335340
vTensorPtr t_out = graph.get_tensor(out);
@@ -383,15 +388,16 @@ void add_conv1d_node(
383388
const ValueRef dilation,
384389
const ValueRef groups,
385390
const ValueRef out) {
386-
if (graph.val_is_none(bias)) {
387-
VK_THROW("conv1d: Null bias is not supported yet!");
388-
}
389-
390391
ValueRef arg_in = prepack_if_tensor_ref(graph, in);
391392
ValueRef arg_weight =
392393
prepack_if_tensor_ref(graph, weight, graph.memory_layout_of(arg_in));
393-
ValueRef arg_bias =
394-
prepack_if_tensor_ref(graph, bias, graph.memory_layout_of(arg_in));
394+
ValueRef arg_bias = prepack_biases(
395+
graph,
396+
bias,
397+
weight,
398+
/*transposed = */ false,
399+
/*storage_type = */ api::kTexture3D,
400+
/*memory_layout = */ api::kChannelsPacked);
395401

396402
vTensorPtr t_in = graph.get_tensor(arg_in);
397403
vTensorPtr t_weight = graph.get_tensor(arg_weight);

backends/vulkan/test/op_tests/cases.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,17 @@ def get_conv_inputs():
135135
[0],
136136
6,
137137
),
138+
(
139+
(1, 9, 11),
140+
(9, 1, 3),
141+
None,
142+
[1],
143+
[0],
144+
[1],
145+
False,
146+
[0],
147+
9,
148+
),
138149
]
139150
)
140151
return test_suite

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -672,6 +672,30 @@ def forward(self, x):
672672
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
673673
)
674674

675+
def test_vulkan_backend_conv1d_bias_false(self):
676+
class Conv1dModule(torch.nn.Module):
677+
def __init__(self):
678+
super().__init__()
679+
self.conv = torch.nn.Conv1d(
680+
in_channels=6,
681+
out_channels=6,
682+
kernel_size=3,
683+
groups=6,
684+
bias=False,
685+
)
686+
687+
def forward(self, x):
688+
return self.conv(x)
689+
690+
conv1d_module = Conv1dModule()
691+
sample_inputs = (torch.randn(size=(1, 6, 7), dtype=torch.float32),)
692+
693+
self.lower_module_and_test_output(
694+
conv1d_module,
695+
sample_inputs,
696+
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
697+
)
698+
675699
def test_vulkan_backend_native_layer_norm(self):
676700
class NativeLayerNormModule(torch.nn.Module):
677701
def __init__(self):

0 commit comments

Comments (0)