Skip to content

Commit 8793427

Browse files
committed
[ET-VK] Refactor Pool.cpp
This change adds more lines than it subtracts, but it'll be worth it once we reuse the methods for `aten.convolution`. Differential Revision: [D55706057](https://our.internmc.facebook.com/intern/diff/D55706057/) ghstack-source-id: 221165542 Pull Request resolved: #2836
1 parent 6713c58 commit 8793427

File tree

3 files changed

+88
-38
lines changed

3 files changed

+88
-38
lines changed

backends/vulkan/runtime/graph/ops/impl/Pool.cpp

Lines changed: 14 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -28,38 +28,23 @@ void resize_max_pool2d_node(
2828
size_t ndim = self.sizes().size();
2929
std::vector<int64_t> new_out_sizes(ndim);
3030

31-
// Batch
31+
// Batch, Channel
3232
if (ndim == 4) {
3333
new_out_sizes.at(ndim - 4) = self.sizes().at(ndim - 4);
3434
}
35-
// Channel
3635
new_out_sizes.at(ndim - 3) = self.sizes().at(ndim - 3);
3736

38-
const auto kernel_size = reverse(*graph, extra_args[0]);
39-
const auto stride = reverse(*graph, extra_args[1]);
40-
const auto padding = reverse(*graph, extra_args[2]);
41-
const auto dilation = reverse(*graph, extra_args[3]);
42-
const bool ceil_mode = graph->get_val(extra_args[4]).toBool();
43-
44-
// Height
45-
new_out_sizes.at(ndim - 2) = calc_out_size(
46-
self.sizes().at(ndim - 2),
47-
kernel_size.data[1],
48-
stride.data[1],
49-
padding.data[1],
50-
dilation.data[1],
51-
ceil_mode);
52-
// Width
53-
new_out_sizes.at(ndim - 1) = calc_out_size(
54-
self.sizes().at(ndim - 1),
55-
kernel_size.data[0],
56-
stride.data[0],
57-
padding.data[0],
58-
dilation.data[0],
59-
ceil_mode);
60-
61-
VK_CHECK_COND(new_out_sizes.at(ndim - 2) >= 1);
62-
VK_CHECK_COND(new_out_sizes.at(ndim - 1) >= 1);
37+
// Height, Width
38+
const auto hw_sizes = calc_hw_out_sizes(
39+
*graph,
40+
self.sizes(),
41+
extra_args[0],
42+
extra_args[1],
43+
extra_args[2],
44+
extra_args[3],
45+
extra_args[4]);
46+
new_out_sizes.at(ndim - 2) = hw_sizes.at(0);
47+
new_out_sizes.at(ndim - 1) = hw_sizes.at(1);
6348

6449
out.virtual_resize(new_out_sizes);
6550
indices.virtual_resize(new_out_sizes);
@@ -96,12 +81,8 @@ void add_max_pool2d_node(
9681
kernel_name << "max_pool2d";
9782
apply_dtype_suffix(kernel_name, t_out);
9883

99-
KernelParams kernel_params{
100-
reverse(graph, kernel_size),
101-
reverse(graph, stride),
102-
reverse(graph, padding),
103-
reverse(graph, dilation),
104-
};
84+
KernelParams kernel_params =
85+
create_kernel_params(graph, kernel_size, stride, padding, dilation);
10586

10687
graph.execute_nodes().emplace_back(new ExecuteNode(
10788
graph,

backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.cpp

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,25 @@
1010

1111
namespace vkcompute {
1212

13+
api::utils::ivec2
14+
make_ivec2_int_list(ComputeGraph& graph, ValueRef vref, const bool reverse) {
15+
return api::utils::make_ivec2(graph.get_val(vref).toIntList(), reverse);
16+
}
17+
18+
KernelParams create_kernel_params(
19+
ComputeGraph& graph,
20+
const ValueRef kernel_size,
21+
const ValueRef stride,
22+
const ValueRef padding,
23+
const ValueRef dilation) {
24+
return {
25+
make_ivec2_int_list(graph, kernel_size, /*reverse=*/true),
26+
make_ivec2_int_list(graph, stride, /*reverse=*/true),
27+
make_ivec2_int_list(graph, padding, /*reverse=*/true),
28+
make_ivec2_int_list(graph, dilation, /*reverse=*/true),
29+
};
30+
}
31+
1332
int64_t calc_out_size(
1433
const int64_t in_size,
1534
const int64_t kernel,
@@ -26,9 +45,46 @@ int64_t calc_out_size(
2645
return out_size;
2746
}
2847

29-
api::utils::ivec2 reverse(ComputeGraph& graph, ValueRef vref) {
30-
return api::utils::make_ivec2(
31-
graph.get_val(vref).toIntList(), /*reverse=*/true);
48+
std::vector<int64_t> calc_hw_out_sizes(
49+
ComputeGraph& graph,
50+
const std::vector<int64_t>& in_sizes,
51+
const ValueRef kernel_size,
52+
const ValueRef stride,
53+
const ValueRef padding,
54+
const ValueRef dilation,
55+
const ValueRef ceil_mode) {
56+
const int64_t ndim = in_sizes.size();
57+
std::vector<int64_t> out_sizes(2);
58+
59+
const auto kernel_vec =
60+
make_ivec2_int_list(graph, kernel_size, /*reverse=*/false);
61+
const auto stride_vec = make_ivec2_int_list(graph, stride, /*reverse=*/false);
62+
const auto padding_vec =
63+
make_ivec2_int_list(graph, padding, /*reverse=*/false);
64+
const auto dilation_vec =
65+
make_ivec2_int_list(graph, dilation, /*reverse=*/false);
66+
67+
// Height
68+
out_sizes.at(0) = calc_out_size(
69+
in_sizes.at(ndim - 2),
70+
kernel_vec.data[0],
71+
stride_vec.data[0],
72+
padding_vec.data[0],
73+
dilation_vec.data[0],
74+
ceil_mode);
75+
// Width
76+
out_sizes.at(1) = calc_out_size(
77+
in_sizes.at(ndim - 1),
78+
kernel_vec.data[1],
79+
stride_vec.data[1],
80+
padding_vec.data[1],
81+
dilation_vec.data[1],
82+
ceil_mode);
83+
84+
VK_CHECK_COND(out_sizes.at(0) >= 1);
85+
VK_CHECK_COND(out_sizes.at(1) >= 1);
86+
87+
return out_sizes;
3288
}
3389

3490
} // namespace vkcompute

backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@ struct KernelParams final {
2323
api::utils::ivec2 dilation;
2424
};
2525

26+
KernelParams create_kernel_params(
27+
ComputeGraph& graph,
28+
const ValueRef kernel_size,
29+
const ValueRef stride,
30+
const ValueRef padding,
31+
const ValueRef dilation);
32+
2633
int64_t calc_out_size(
2734
const int64_t in_size,
2835
const int64_t kernel_size,
@@ -31,6 +38,12 @@ int64_t calc_out_size(
3138
const int64_t dilation,
3239
const bool ceil_mode);
3340

34-
api::utils::ivec2 reverse(ComputeGraph& graph, ValueRef vref);
35-
41+
std::vector<int64_t> calc_hw_out_sizes(
42+
ComputeGraph& graph,
43+
const std::vector<int64_t>& in_sizes,
44+
const ValueRef kernel_size,
45+
const ValueRef stride,
46+
const ValueRef padding,
47+
const ValueRef dilation,
48+
const ValueRef ceil_mode);
3649
} // namespace vkcompute

0 commit comments

Comments
 (0)