Skip to content

Commit 60eb1bb

Browse files
jorgep31415facebook-github-bot
authored andcommitted
Clean up max_pool2d (#2575)
Summary: Pull Request resolved: #2575 In anticipation of convolution, I noticed some nits in the max_pool2d implementation: 1. Removed unneeded shader branch. 2. Renamed `kernel` to `kernel_size`. 3. Removed `Int` to `IntList` normalization since this is already performed by the compiler, so the serialized graph will always have `IntList`. The fact that the Python tests still pass when they specify `Int` for `padding=0` and `dilation=1` affirms this. Renamed the helper function appropriately and improve readability. ghstack-source-id: 219657438 bypass-github-export-checks Reviewed By: copyrightly Differential Revision: D55220203 fbshipit-source-id: f4d163e13826a7abb9ec07aaa4ef3c574826e948
1 parent a701d82 commit 60eb1bb

File tree

5 files changed

+22
-28
lines changed

5 files changed

+22
-28
lines changed

backends/vulkan/runtime/graph/ops/ExecuteNode.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ class ExecuteNode final {
7777
const api::utils::uvec3 global_workgroup_size_;
7878
const api::utils::uvec3 local_workgroup_size_;
7979
const std::vector<ArgGroup> args_;
80-
// TODO(T180906457): allow re-computing param buffers.
8180
std::vector<std::shared_ptr<api::UniformParamsBuffer>> params_;
8281
const ResizeFunction resize_fn_;
8382
const std::vector<ValueRef> resize_args_;

backends/vulkan/runtime/graph/ops/glsl/max_pool2d.glsl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ layout(set = 0, binding = 4) uniform PRECISION restrict InExtents {
3030
in_extents;
3131

3232
layout(set = 0, binding = 5) uniform PRECISION restrict Params {
33-
ivec2 kernel;
33+
ivec2 kernel_size;
3434
ivec2 stride;
3535
ivec2 padding;
3636
ivec2 dilation;
@@ -49,7 +49,7 @@ void main() {
4949
const ivec2 ipos = pos.xy * params.stride - params.padding;
5050

5151
const ivec2 start = ipos;
52-
const ivec2 end = ipos + params.kernel * params.dilation;
52+
const ivec2 end = ipos + params.kernel_size * params.dilation;
5353

5454
vec4 out_texel = vec4(FLT_MIN);
5555
ivec4 idx_texel = ivec4(0);
@@ -66,9 +66,6 @@ void main() {
6666

6767
out_texel = max(cur_texel, out_texel);
6868
}
69-
else {
70-
out_texel = max(vec4(FLT_MIN), out_texel);
71-
}
7269
}
7370
}
7471

backends/vulkan/runtime/graph/ops/impl/Pool.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,24 +37,24 @@ void resize_max_pool2d_node(
3737
// Channel
3838
new_out_sizes.at(ndim - 3) = self.sizes().at(ndim - 3);
3939

40-
const auto kernel = normalize_wh(graph->get_val(extra_args[0]));
41-
const auto stride = normalize_wh(graph->get_val(extra_args[1]));
42-
const auto padding = normalize_wh(graph->get_val(extra_args[2]));
43-
const auto dilation = normalize_wh(graph->get_val(extra_args[3]));
40+
const auto kernel_size = reverse(*graph, extra_args[0]);
41+
const auto stride = reverse(*graph, extra_args[1]);
42+
const auto padding = reverse(*graph, extra_args[2]);
43+
const auto dilation = reverse(*graph, extra_args[3]);
4444
const bool ceil_mode = graph->get_val(extra_args[4]).toBool();
4545

4646
// Height
4747
new_out_sizes.at(ndim - 2) = calc_out_size(
4848
self.sizes().at(ndim - 2),
49-
kernel.data[1],
49+
kernel_size.data[1],
5050
stride.data[1],
5151
padding.data[1],
5252
dilation.data[1],
5353
ceil_mode);
5454
// Width
5555
new_out_sizes.at(ndim - 1) = calc_out_size(
5656
self.sizes().at(ndim - 1),
57-
kernel.data[0],
57+
kernel_size.data[0],
5858
stride.data[0],
5959
padding.data[0],
6060
dilation.data[0],
@@ -77,7 +77,7 @@ void check_max_pool2d_args(const vTensor& in, const vTensor& out) {
7777
void add_max_pool2d_node(
7878
ComputeGraph& graph,
7979
const ValueRef in,
80-
const ValueRef kernel,
80+
const ValueRef kernel_size,
8181
const ValueRef stride,
8282
const ValueRef padding,
8383
const ValueRef dilation,
@@ -99,10 +99,10 @@ void add_max_pool2d_node(
9999
apply_dtype_suffix(kernel_name, t_out);
100100

101101
KernelParams kernel_params{
102-
normalize_wh(graph.get_val(kernel)),
103-
normalize_wh(graph.get_val(stride)),
104-
normalize_wh(graph.get_val(padding)),
105-
normalize_wh(graph.get_val(dilation)),
102+
reverse(graph, kernel_size),
103+
reverse(graph, stride),
104+
reverse(graph, padding),
105+
reverse(graph, dilation),
106106
};
107107

108108
graph.execute_nodes().emplace_back(new ExecuteNode(
@@ -121,7 +121,7 @@ void add_max_pool2d_node(
121121
},
122122
// Resizing
123123
resize_max_pool2d_node,
124-
{kernel, stride, padding, dilation, ceil_mode}));
124+
{kernel_size, stride, padding, dilation, ceil_mode}));
125125
}
126126

127127
void max_pool2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {

backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,9 @@ int64_t calc_out_size(
2828
return out_size;
2929
}
3030

31-
api::utils::ivec2 normalize_wh(Value& v) {
32-
if (v.isInt()) {
33-
return api::utils::make_ivec2({v.toInt(), v.toInt()});
34-
} else {
35-
auto l = v.toIntList();
36-
return api::utils::make_ivec2({l.at(1), l.at(0)});
37-
}
31+
api::utils::ivec2 reverse(ComputeGraph& graph, ValueRef vref) {
32+
return api::utils::make_ivec2(
33+
graph.get_val(vref).toIntList(), /*reverse=*/true);
3834
}
3935

4036
} // namespace vulkan

backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,28 +12,30 @@
1212

1313
#include <ATen/native/vulkan/api/api.h>
1414

15+
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
16+
1517
#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>
1618

1719
namespace at {
1820
namespace native {
1921
namespace vulkan {
2022

2123
struct KernelParams final {
22-
api::utils::ivec2 kernel;
24+
api::utils::ivec2 kernel_size;
2325
api::utils::ivec2 stride;
2426
api::utils::ivec2 padding;
2527
api::utils::ivec2 dilation;
2628
};
2729

2830
int64_t calc_out_size(
2931
const int64_t in_size,
30-
const int64_t kernel,
32+
const int64_t kernel_size,
3133
const int64_t stride,
3234
const int64_t padding,
3335
const int64_t dilation,
3436
const bool ceil_mode);
3537

36-
api::utils::ivec2 normalize_wh(Value& v);
38+
api::utils::ivec2 reverse(ComputeGraph& graph, ValueRef vref);
3739

3840
} // namespace vulkan
3941
} // namespace native

0 commit comments

Comments
 (0)