
Commit 6c50546

Update base for Update on "[ET-VK][Ops] aten.convolution (SlidingWindow)"
## The Operator

`nn.Module` invocations of [`nn.Conv2d`](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d) and [`nn.ConvTranspose2d`](https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html#torch.nn.ConvTranspose2d) are compiled to `aten.convolution.default` in the Edge Dialect, which carries the signature

```
- func: convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, SymInt[] padding, int[] dilation, bool transposed, SymInt[] output_padding, int groups) -> Tensor
```

## Summary (cases handled)

We introduce support for the convolution cases covered by [ATen-VK's default SlidingWindow implementation](https://github.com/pytorch/pytorch/blob/09c72eaa3f69f90402c86a30abf4fc621298578c/aten/src/ATen/native/vulkan/ops/Convolution.cpp#L73). This is achieved by

- reusing the [existing `conv2d.glsl`](https://github.com/pytorch/pytorch/blob/09c72eaa3f69f90402c86a30abf4fc621298578c/aten/src/ATen/native/vulkan/glsl/conv2d.glsl), and
- [moving the special weights prepacking from CPU](https://github.com/pytorch/pytorch/blob/09c72eaa3f69f90402c86a30abf4fc621298578c/aten/src/ATen/native/vulkan/ops/Convolution.cpp#L134-L235) to the GPU in `conv2d_prepack_weights.glsl`.

We also include resizing support for dynamic shapes. Note that only the height and width of the input can vary.

## Cases not handled

The implementation is on par with ATen-VK's SlidingWindow. This means the following cases are missing (a hypothetical guard capturing these constraints is sketched below):

1. **Groups G > 1.** Largely not covered by ATen-VK. `G = in_channels` is covered by ATen-VK's Depthwise impl and will be added soon.
2. **Batch (input) N > 1.** Not covered by ATen-VK.
3. **Padding > 0 while Dilation, Kernel > 1.** Not covered by ATen-VK.

## Coming soon

1. Transpose convolution
2. Depthwise convolution (for completeness)
3. Pointwise convolution (for optimization)
4. Null bias

Differential Revision: [D55346778](https://our.internmc.facebook.com/intern/diff/D55346778/)

[ghstack-poisoned]
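To make the handled/unhandled cases concrete, the constraint set reduces to a small predicate. A minimal sketch follows; `sliding_window_supported` is a hypothetical helper for illustration, not part of this diff:

```cpp
#include <cstdint>

// Hypothetical guard mirroring the "cases not handled" list above: the
// SlidingWindow path assumes groups == 1, batch == 1, and no padding
// combined with dilation and kernel size both > 1.
bool sliding_window_supported(
    int64_t groups,
    int64_t batch,
    int64_t padding,
    int64_t dilation,
    int64_t kernel) {
  if (groups > 1) {
    return false; // case 1: grouped convolution
  }
  if (batch > 1) {
    return false; // case 2: batched input
  }
  if (padding > 0 && dilation > 1 && kernel > 1) {
    return false; // case 3: padding with a dilated kernel
  }
  return true;
}
```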
1 parent 1f1a2c2 commit 6c50546

File tree

10 files changed: +168 −72 lines changed

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 16 additions & 5 deletions

```diff
@@ -132,16 +132,27 @@ ValueRef ComputeGraph::add_tensor(
       sizes, dtype, suggested_storage_type(), memory_layout, shared_object_idx);
 }
 
+ValueRef ComputeGraph::add_tensor_like(
+    const ValueRef vref,
+    const api::StorageType storage_type,
+    const api::GPUMemoryLayout memory_layout) {
+  TensorRef& tref = get_val(vref).toTensorRef();
+  return add_tensor(tref.sizes, tref.dtype, storage_type, memory_layout);
+}
+
+ValueRef ComputeGraph::add_tensor_like(
+    const ValueRef vref,
+    const api::GPUMemoryLayout memory_layout) {
+  TensorRef& tref = get_val(vref).toTensorRef();
+  return add_tensor(tref.sizes, tref.dtype, memory_layout);
+}
+
 ValueRef ComputeGraph::add_tensor(
     const std::vector<int64_t>& sizes,
     const api::ScalarType dtype,
     const int64_t shared_object_idx) {
   return add_tensor(
-      sizes,
-      dtype,
-      suggested_storage_type(),
-      suggested_memory_layout(sizes),
-      shared_object_idx);
+      sizes, dtype, suggested_memory_layout(sizes), shared_object_idx);
 }
 
 ValueRef ComputeGraph::add_tensorref(
```
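The new `add_tensor_like` overloads let op implementations clone a `TensorRef`'s sizes and dtype without unpacking them by hand (see the simplified `prepack` in Staging.cpp below). A minimal call-site sketch, assuming a `ComputeGraph graph` and a `ValueRef weight_ref` holding a `TensorRef`; the specific enum values are illustrative:

```cpp
// Explicit storage type and memory layout:
ValueRef packed = graph.add_tensor_like(
    weight_ref,
    api::StorageType::TEXTURE_2D,
    api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED);

// Or let the graph pick its suggested storage type:
ValueRef staged = graph.add_tensor_like(
    weight_ref, api::GPUMemoryLayout::TENSOR_WIDTH_PACKED);
```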

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 18 additions & 2 deletions

```diff
@@ -172,7 +172,7 @@ class ComputeGraph final {
       const api::ScalarType dtype,
       const api::StorageType storage_type,
       const api::GPUMemoryLayout memory_layout,
-      const int64_t shared_object_idx);
+      const int64_t shared_object_idx = -1);
 
   /*
    * Add a `vTensor` value to the graph with the specified properties. The
@@ -191,9 +191,25 @@ class ComputeGraph final {
    */
   ValueRef add_tensor(
       const std::vector<int64_t>& sizes,
-      const api::ScalarType dtype = api::ScalarType::Float,
+      const api::ScalarType dtype,
       const int64_t shared_object_idx = -1);
 
+  /*
+   * Add a `vTensor` value to the graph with the properties of `vref`.
+   */
+  ValueRef add_tensor_like(
+      const ValueRef vref,
+      const api::StorageType storage_type,
+      const api::GPUMemoryLayout memory_layout);
+
+  /*
+   * Add a `vTensor` value to the graph with the properties of `vref`. The
+   * suggested storage type will be used to construct the `vTensor`.
+   */
+  ValueRef add_tensor_like(
+      const ValueRef vref,
+      const api::GPUMemoryLayout memory_layout);
+
   /*
    * Add a `TensorRef` value to the graph with the specific properties. A
    * `TensorRef` is a reference to a `vTensor` whose data is stored in an
```

backends/vulkan/runtime/graph/ops/PrepackNode.cpp

Lines changed: 2 additions & 2 deletions

```diff
@@ -36,8 +36,8 @@ void PrepackNode::encode(ComputeGraph* graph) {
   api::Context* const context = graph->context();
   api::PipelineBarrier pipeline_barrier{};
 
-  TensorRef tref = graph->get_val(tref_).toTensorRef();
-  vTensor packed = graph->get_val(packed_).toTensor();
+  TensorRef& tref = graph->get_val(tref_).toTensorRef();
+  vTensor& packed = graph->get_val(packed_).toTensor();
 
   size_t numel = api::utils::multiply_integers(tref.sizes);
   api::StorageBuffer staging(graph->context(), tref.dtype, numel);
```
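Binding `TensorRef&` and `vTensor&` here avoids copying graph-owned values on every encode; the same `const auto&` pattern appears in Sum.cpp and the tests below. A sketch of the rule of thumb, with a hypothetical `new_sizes`:

```cpp
// Bind references to values that live in the graph. A by-value copy is
// wasteful, and mutations applied to the copy may not be reflected in the
// graph's own Value:
vTensor& packed = graph->get_val(packed_).toTensor();
packed.virtual_resize(new_sizes);  // hypothetical resize, visible to later nodes
```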

backends/vulkan/runtime/graph/ops/impl/Pool.cpp

Lines changed: 14 additions & 33 deletions

```diff
@@ -28,38 +28,23 @@ void resize_max_pool2d_node(
   size_t ndim = self.sizes().size();
   std::vector<int64_t> new_out_sizes(ndim);
 
-  // Batch
+  // Batch, Channel
   if (ndim == 4) {
     new_out_sizes.at(ndim - 4) = self.sizes().at(ndim - 4);
   }
-  // Channel
   new_out_sizes.at(ndim - 3) = self.sizes().at(ndim - 3);
 
-  const auto kernel_size = reverse(*graph, extra_args[0]);
-  const auto stride = reverse(*graph, extra_args[1]);
-  const auto padding = reverse(*graph, extra_args[2]);
-  const auto dilation = reverse(*graph, extra_args[3]);
-  const bool ceil_mode = graph->get_val(extra_args[4]).toBool();
-
-  // Height
-  new_out_sizes.at(ndim - 2) = calc_out_size(
-      self.sizes().at(ndim - 2),
-      kernel_size.data[1],
-      stride.data[1],
-      padding.data[1],
-      dilation.data[1],
-      ceil_mode);
-  // Width
-  new_out_sizes.at(ndim - 1) = calc_out_size(
-      self.sizes().at(ndim - 1),
-      kernel_size.data[0],
-      stride.data[0],
-      padding.data[0],
-      dilation.data[0],
-      ceil_mode);
-
-  VK_CHECK_COND(new_out_sizes.at(ndim - 2) >= 1);
-  VK_CHECK_COND(new_out_sizes.at(ndim - 1) >= 1);
+  // Height, Width
+  const auto new_out_sizes_hw = calc_out_sizes_hw(
+      *graph,
+      self.sizes(),
+      extra_args[0],
+      extra_args[1],
+      extra_args[2],
+      extra_args[3],
+      extra_args[4]);
+  new_out_sizes.at(ndim - 2) = new_out_sizes_hw.at(0);
+  new_out_sizes.at(ndim - 1) = new_out_sizes_hw.at(1);
 
   out.virtual_resize(new_out_sizes);
   indices.virtual_resize(new_out_sizes);
@@ -96,12 +81,8 @@ void add_max_pool2d_node(
   kernel_name << "max_pool2d";
   apply_dtype_suffix(kernel_name, t_out);
 
-  KernelParams kernel_params{
-      reverse(graph, kernel_size),
-      reverse(graph, stride),
-      reverse(graph, padding),
-      reverse(graph, dilation),
-  };
+  KernelParams kernel_params =
+      create_kernel_params(graph, kernel_size, stride, padding, dilation);
 
   graph.execute_nodes().emplace_back(new ExecuteNode(
       graph,
```
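This refactor routes pooling through the same `calc_out_sizes_hw` helper that the new convolution op uses, so both compute their H/W extents identically. For orientation, a comment-only reference of the index conventions visible in this diff:

```cpp
// Size vectors are NCHW, indexed from the back:
//   sizes.at(ndim - 1) -> W, sizes.at(ndim - 2) -> H,
//   sizes.at(ndim - 3) -> C, sizes.at(ndim - 4) -> N (4-D inputs only).
// Parameter lists arrive as {H, W}; after make_ivec2(..., /*reverse = */ true)
// the resulting ivec2 holds width in data[0] and height in data[1].
```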

backends/vulkan/runtime/graph/ops/impl/Staging.cpp

Lines changed: 2 additions & 3 deletions

```diff
@@ -63,9 +63,8 @@ ValueRef prepack(
     ComputeGraph& graph,
     const ValueRef vref,
     const api::GPUMemoryLayout layout) {
-  TensorRef& tref = graph.get_val(vref).toTensorRef();
-  ValueRef v = graph.add_tensor(tref.sizes, tref.dtype, layout);
-  vTensor t = graph.get_val(v).toTensor();
+  ValueRef v = graph.add_tensor_like(vref, layout);
+  vTensor& t = graph.get_val(v).toTensor();
 
   api::ShaderInfo shader = get_nchw_to_image_shader(t);
 
```

backends/vulkan/runtime/graph/ops/impl/Sum.cpp

Lines changed: 1 addition & 1 deletion

```diff
@@ -120,7 +120,7 @@ void add_sum_dim_IntList(
   vTensor& in_tensor = graph.get_val(in).toTensor();
 
   std::set<int64_t> dims_set;
-  auto dims_to_sum = graph.get_val(opt_dim).toIntList();
+  const auto& dims_to_sum = graph.get_val(opt_dim).toIntList();
   int64_t in_dim = in_tensor.sizes().size();
 
   for (const auto& dim : dims_to_sum) {
```

backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp

Lines changed: 2 additions & 2 deletions

```diff
@@ -82,8 +82,8 @@ float get_val_or_inf(ComputeGraph& graph, const ValueRef& val, bool max) {
   return add_unary_op_node( \
       graph, \
       args[0], \
-      get_val_or_inf(graph, args[1], /*max =*/false), \
-      get_val_or_inf(graph, args[2], /*max =*/true), \
+      get_val_or_inf(graph, args[1], /*max = */ false), \
+      get_val_or_inf(graph, args[2], /*max = */ true), \
       args[3], \
       kClampShaderName); \
 }
```

backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.cpp

Lines changed: 60 additions & 5 deletions

```diff
@@ -10,25 +10,80 @@
 
 namespace vkcompute {
 
+api::utils::ivec2 make_ivec2_from_list(ComputeGraph& graph, ValueRef vref) {
+  return api::utils::make_ivec2(
+      graph.get_val(vref).toIntList(), /*reverse = */ true);
+}
+
+KernelParams create_kernel_params(
+    ComputeGraph& graph,
+    const ValueRef kernel_size,
+    const ValueRef stride,
+    const ValueRef padding,
+    const ValueRef dilation) {
+  return {
+      make_ivec2_from_list(graph, kernel_size),
+      make_ivec2_from_list(graph, stride),
+      make_ivec2_from_list(graph, padding),
+      make_ivec2_from_list(graph, dilation),
+  };
+}
+
 int64_t calc_out_size(
     const int64_t in_size,
-    const int64_t kernel,
+    const int64_t kernel_size,
     const int64_t stride,
     const int64_t padding,
     const int64_t dilation,
     const bool ceil_mode) {
   int64_t c = ceil_mode ? stride - 1 : 0;
   int64_t out_size =
-      (in_size + 2 * padding - dilation * (kernel - 1) - 1 + c) / stride + 1;
+      (in_size + 2 * padding - dilation * (kernel_size - 1) - 1 + c) / stride +
+      1;
   if (ceil_mode && (out_size - 1) * stride >= in_size + padding) {
     --out_size;
   }
   return out_size;
 }
 
-api::utils::ivec2 reverse(ComputeGraph& graph, ValueRef vref) {
-  return api::utils::make_ivec2(
-      graph.get_val(vref).toIntList(), /*reverse=*/true);
+std::vector<int64_t> calc_out_sizes_hw(
+    ComputeGraph& graph,
+    const std::vector<int64_t>& in_sizes,
+    const ValueRef kernel_size,
+    const ValueRef stride,
+    const ValueRef padding,
+    const ValueRef dilation,
+    const ValueRef ceil_mode) {
+  const int64_t ndim = in_sizes.size();
+  std::vector<int64_t> out_sizes(2);
+
+  const auto kernel_vec = make_ivec2_from_list(graph, kernel_size);
+  const auto stride_vec = make_ivec2_from_list(graph, stride);
+  const auto padding_vec = make_ivec2_from_list(graph, padding);
+  const auto dilation_vec = make_ivec2_from_list(graph, dilation);
+  const bool ceil_mode_val = graph.get_val(ceil_mode).toBool();
+
+  // Height
+  out_sizes.at(0) = calc_out_size(
+      in_sizes.at(ndim - 2),
+      kernel_vec.data[1],
+      stride_vec.data[1],
+      padding_vec.data[1],
+      dilation_vec.data[1],
+      ceil_mode_val);
+  // Width
+  out_sizes.at(1) = calc_out_size(
+      in_sizes.at(ndim - 1),
+      kernel_vec.data[0],
+      stride_vec.data[0],
+      padding_vec.data[0],
+      dilation_vec.data[0],
+      ceil_mode_val);
+
+  VK_CHECK_COND(out_sizes.at(0) >= 1);
+  VK_CHECK_COND(out_sizes.at(1) >= 1);
+
+  return out_sizes;
 }
 
 } // namespace vkcompute
```
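`calc_out_size` implements the standard sliding-window output formula, out = floor((in + 2·padding − dilation·(kernel − 1) − 1) / stride) + 1, with an extra `stride − 1` slack term (and a correction check) when `ceil_mode` is set. A quick standalone sanity check of the floor-mode arithmetic, with illustrative values:

```cpp
#include <cassert>
#include <cstdint>

// Standalone copy of the floor-mode arithmetic from calc_out_size,
// for checking values outside the graph runtime (illustrative only).
int64_t out_size(int64_t in, int64_t k, int64_t s, int64_t p, int64_t d) {
  return (in + 2 * p - d * (k - 1) - 1) / s + 1;
}

int main() {
  assert(out_size(6, 3, 1, 0, 1) == 4);  // valid convolution: 6 -> 4
  assert(out_size(6, 3, 1, 1, 1) == 6);  // "same"-style padding: 6 -> 6
  assert(out_size(6, 3, 2, 0, 1) == 2);  // stride 2: floor(3 / 2) + 1 = 2
  assert(out_size(6, 3, 1, 0, 2) == 2);  // dilation 2 widens the window
  return 0;
}
```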

backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h

Lines changed: 15 additions & 9 deletions

```diff
@@ -23,14 +23,20 @@ struct KernelParams final {
   api::utils::ivec2 dilation;
 };
 
-int64_t calc_out_size(
-    const int64_t in_size,
-    const int64_t kernel_size,
-    const int64_t stride,
-    const int64_t padding,
-    const int64_t dilation,
-    const bool ceil_mode);
-
-api::utils::ivec2 reverse(ComputeGraph& graph, ValueRef vref);
+KernelParams create_kernel_params(
+    ComputeGraph& graph,
+    const ValueRef kernel_size,
+    const ValueRef stride,
+    const ValueRef padding,
+    const ValueRef dilation);
+
+std::vector<int64_t> calc_out_sizes_hw(
+    ComputeGraph& graph,
+    const std::vector<int64_t>& in_sizes,
+    const ValueRef kernel_size,
+    const ValueRef stride,
+    const ValueRef padding,
+    const ValueRef dilation,
+    const ValueRef ceil_mode);
 
 } // namespace vkcompute
```

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 38 additions & 10 deletions

```diff
@@ -401,7 +401,7 @@ TEST(VulkanComputeGraphTest, test_values_scalar_list_inplace_constructed) {
   ComputeGraph graph(config);
 
   ValueRef idx = graph.add_scalar_list<int64_t>({1, 2, 3, 4});
-  std::vector<int64_t>& arr = graph.get_val(idx).toIntList();
+  const auto& arr = graph.get_val(idx).toIntList();
   EXPECT_TRUE(arr.size() == 4);
   for (int i = 0; i < 4; i++) {
     EXPECT_TRUE(arr[i] == i + 1);
@@ -417,7 +417,7 @@ TEST(VulkanComputeGraphTest, test_values_scalar_list_outside_constructed) {
     std::vector<double> data = {5.0, 4.0, 3.0, 2.0, 1.0};
     idx = graph.add_scalar_list(std::move(data));
   }
-  std::vector<double>& arr = graph.get_val(idx).toDoubleList();
+  const auto& arr = graph.get_val(idx).toDoubleList();
   EXPECT_TRUE(arr.size() == 5);
   for (int i = 0; i < 5; i++) {
     EXPECT_TRUE(arr[i] == (5 - i));
@@ -1044,11 +1044,39 @@ void test_mm(
 }
 
 TEST(VulkanComputeGraphOpsTest, mm_smoke_test) {
-#define RUN_TESTS(dtype, layout, prepack) \
-  test_mm(/*B=*/1, /*M=*/31, /*K=*/127, /*N=*/23, dtype, layout, prepack); \
-  test_mm(/*B=*/5, /*M=*/31, /*K=*/127, /*N=*/23, dtype, layout, prepack); \
-  test_mm(/*B=*/7, /*M=*/13, /*K=*/89, /*N=*/17, dtype, layout, prepack); \
-  test_mm(/*B=*/1, /*M=*/13, /*K=*/89, /*N=*/17, dtype, layout, prepack);
+#define RUN_TESTS(dtype, layout, prepack) \
+  test_mm(                                \
+      /*B = */ 1,                         \
+      /*M = */ 31,                        \
+      /*K = */ 127,                       \
+      /*N = */ 23,                        \
+      dtype,                              \
+      layout,                             \
+      prepack);                           \
+  test_mm(                                \
+      /*B = */ 5,                         \
+      /*M = */ 31,                        \
+      /*K = */ 127,                       \
+      /*N = */ 23,                        \
+      dtype,                              \
+      layout,                             \
+      prepack);                           \
+  test_mm(                                \
+      /*B = */ 7,                         \
+      /*M = */ 13,                        \
+      /*K = */ 89,                        \
+      /*N = */ 17,                        \
+      dtype,                              \
+      layout,                             \
+      prepack);                           \
+  test_mm(                                \
+      /*B = */ 1,                         \
+      /*M = */ 13,                        \
+      /*K = */ 89,                        \
+      /*N = */ 17,                        \
+      dtype,                              \
+      layout,                             \
+      prepack);
 
   CALL_TEST_FN_FOR_W_PACKED(RUN_TESTS);
   CALL_TEST_FN_FOR_C_PACKED(RUN_TESTS);
@@ -1102,7 +1130,7 @@ void test_max_pool2d(
 
   // Run graph
 
-  fill_vtensor(graph, graph.inputs().at(0), base_val, /*iota=*/true);
+  fill_vtensor(graph, graph.inputs().at(0), base_val, /*iota = */ true);
 
   vTensor& t_in = graph.get_val(in_ioval.value).toTensor();
   std::vector<float> input_data(t_in.gpu_numel());
@@ -1140,7 +1168,7 @@
 TEST(VulkanComputeGraphOpsTest, max_pool2d_smoke_test) {
   std::vector<int64_t> kernel = {2, 3};
   test_max_pool2d(
-      /*in_size=*/{1, 4, 6},
-      /*base_val=*/10.0f,
+      /*in_size = */ {1, 4, 6},
+      /*base_val = */ 10.0f,
       kernel);
 }
```
