Skip to content

Commit a4f5f02

Browse files
committed
[ET-VK] Refactor Pool.cpp
Pull Request resolved: pytorch/executorch#2836 This change adds more lines than it subtracts, but it'll be worth it once we reuse the methods for `aten.convolution`. ghstack-source-id: 221721759 @exported-using-ghexport Differential Revision: [D55706057](https://our.internmc.facebook.com/intern/diff/D55706057/)
1 parent 55f8060 commit a4f5f02

File tree

3 files changed

+89
-47
lines changed

3 files changed

+89
-47
lines changed

backends/vulkan/runtime/graph/ops/impl/Pool.cpp

Lines changed: 14 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -28,38 +28,23 @@ void resize_max_pool2d_node(
2828
size_t ndim = self.sizes().size();
2929
std::vector<int64_t> new_out_sizes(ndim);
3030

31-
// Batch
31+
// Batch, Channel
3232
if (ndim == 4) {
3333
new_out_sizes.at(ndim - 4) = self.sizes().at(ndim - 4);
3434
}
35-
// Channel
3635
new_out_sizes.at(ndim - 3) = self.sizes().at(ndim - 3);
3736

38-
const auto kernel_size = reverse(*graph, extra_args[0]);
39-
const auto stride = reverse(*graph, extra_args[1]);
40-
const auto padding = reverse(*graph, extra_args[2]);
41-
const auto dilation = reverse(*graph, extra_args[3]);
42-
const bool ceil_mode = graph->get_val(extra_args[4]).toBool();
43-
44-
// Height
45-
new_out_sizes.at(ndim - 2) = calc_out_size(
46-
self.sizes().at(ndim - 2),
47-
kernel_size.data[1],
48-
stride.data[1],
49-
padding.data[1],
50-
dilation.data[1],
51-
ceil_mode);
52-
// Width
53-
new_out_sizes.at(ndim - 1) = calc_out_size(
54-
self.sizes().at(ndim - 1),
55-
kernel_size.data[0],
56-
stride.data[0],
57-
padding.data[0],
58-
dilation.data[0],
59-
ceil_mode);
60-
61-
VK_CHECK_COND(new_out_sizes.at(ndim - 2) >= 1);
62-
VK_CHECK_COND(new_out_sizes.at(ndim - 1) >= 1);
37+
// Height, Width
38+
const auto new_out_sizes_hw = calc_out_sizes_hw(
39+
*graph,
40+
self.sizes(),
41+
extra_args[0],
42+
extra_args[1],
43+
extra_args[2],
44+
extra_args[3],
45+
extra_args[4]);
46+
new_out_sizes.at(ndim - 2) = new_out_sizes_hw.at(0);
47+
new_out_sizes.at(ndim - 1) = new_out_sizes_hw.at(1);
6348

6449
out.virtual_resize(new_out_sizes);
6550
indices.virtual_resize(new_out_sizes);
@@ -96,12 +81,8 @@ void add_max_pool2d_node(
9681
kernel_name << "max_pool2d";
9782
apply_dtype_suffix(kernel_name, t_out);
9883

99-
KernelParams kernel_params{
100-
reverse(graph, kernel_size),
101-
reverse(graph, stride),
102-
reverse(graph, padding),
103-
reverse(graph, dilation),
104-
};
84+
KernelParams kernel_params =
85+
create_kernel_params(graph, kernel_size, stride, padding, dilation);
10586

10687
graph.execute_nodes().emplace_back(new ExecuteNode(
10788
graph,

backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.cpp

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,25 +10,80 @@
1010

1111
namespace vkcompute {
1212

13+
api::utils::ivec2 make_ivec2_from_list(ComputeGraph& graph, ValueRef vref) {
14+
return api::utils::make_ivec2(
15+
graph.get_val(vref).toIntList(), /*reverse = */ true);
16+
}
17+
18+
KernelParams create_kernel_params(
19+
ComputeGraph& graph,
20+
const ValueRef kernel_size,
21+
const ValueRef stride,
22+
const ValueRef padding,
23+
const ValueRef dilation) {
24+
return {
25+
make_ivec2_from_list(graph, kernel_size),
26+
make_ivec2_from_list(graph, stride),
27+
make_ivec2_from_list(graph, padding),
28+
make_ivec2_from_list(graph, dilation),
29+
};
30+
}
31+
1332
int64_t calc_out_size(
1433
const int64_t in_size,
15-
const int64_t kernel,
34+
const int64_t kernel_size,
1635
const int64_t stride,
1736
const int64_t padding,
1837
const int64_t dilation,
1938
const bool ceil_mode) {
2039
int64_t c = ceil_mode ? stride - 1 : 0;
2140
int64_t out_size =
22-
(in_size + 2 * padding - dilation * (kernel - 1) - 1 + c) / stride + 1;
41+
(in_size + 2 * padding - dilation * (kernel_size - 1) - 1 + c) / stride +
42+
1;
2343
if (ceil_mode && (out_size - 1) * stride >= in_size + padding) {
2444
--out_size;
2545
}
2646
return out_size;
2747
}
2848

29-
api::utils::ivec2 reverse(ComputeGraph& graph, ValueRef vref) {
30-
return api::utils::make_ivec2(
31-
graph.get_val(vref).toIntList(), /*reverse=*/true);
49+
std::vector<int64_t> calc_out_sizes_hw(
50+
ComputeGraph& graph,
51+
const std::vector<int64_t>& in_sizes,
52+
const ValueRef kernel_size,
53+
const ValueRef stride,
54+
const ValueRef padding,
55+
const ValueRef dilation,
56+
const ValueRef ceil_mode) {
57+
const int64_t ndim = in_sizes.size();
58+
std::vector<int64_t> out_sizes(2);
59+
60+
const auto kernel_vec = make_ivec2_from_list(graph, kernel_size);
61+
const auto stride_vec = make_ivec2_from_list(graph, stride);
62+
const auto padding_vec = make_ivec2_from_list(graph, padding);
63+
const auto dilation_vec = make_ivec2_from_list(graph, dilation);
64+
const bool ceil_mode_val = graph.get_val(ceil_mode).toBool();
65+
66+
// Height
67+
out_sizes.at(0) = calc_out_size(
68+
in_sizes.at(ndim - 2),
69+
kernel_vec.data[1],
70+
stride_vec.data[1],
71+
padding_vec.data[1],
72+
dilation_vec.data[1],
73+
ceil_mode_val);
74+
// Width
75+
out_sizes.at(1) = calc_out_size(
76+
in_sizes.at(ndim - 1),
77+
kernel_vec.data[0],
78+
stride_vec.data[0],
79+
padding_vec.data[0],
80+
dilation_vec.data[0],
81+
ceil_mode_val);
82+
83+
VK_CHECK_COND(out_sizes.at(0) >= 1);
84+
VK_CHECK_COND(out_sizes.at(1) >= 1);
85+
86+
return out_sizes;
3287
}
3388

3489
} // namespace vkcompute

backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,20 @@ struct KernelParams final {
2323
api::utils::ivec2 dilation;
2424
};
2525

26-
int64_t calc_out_size(
27-
const int64_t in_size,
28-
const int64_t kernel_size,
29-
const int64_t stride,
30-
const int64_t padding,
31-
const int64_t dilation,
32-
const bool ceil_mode);
33-
34-
api::utils::ivec2 reverse(ComputeGraph& graph, ValueRef vref);
26+
KernelParams create_kernel_params(
27+
ComputeGraph& graph,
28+
const ValueRef kernel_size,
29+
const ValueRef stride,
30+
const ValueRef padding,
31+
const ValueRef dilation);
32+
33+
std::vector<int64_t> calc_out_sizes_hw(
34+
ComputeGraph& graph,
35+
const std::vector<int64_t>& in_sizes,
36+
const ValueRef kernel_size,
37+
const ValueRef stride,
38+
const ValueRef padding,
39+
const ValueRef dilation,
40+
const ValueRef ceil_mode);
3541

3642
} // namespace vkcompute

0 commit comments

Comments
 (0)