Skip to content

Commit a8a49a9

Browse files
jorgep31415facebook-github-bot
authored andcommitted
aten.max_pool2d_with_indices (#2547)
Summary: Pull Request resolved: #2547 ## The Operator An `nn.Module` invocation of `torch.nn.MaxPool2d()` is represented as `aten.max_pool2d_with_indices.default` in the Edge Dialect, independent of `return_indices = True/False`. ``` # Return: (Tensor output, Tensor indices) - func: max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) ``` This is different from PT-VK where `torch.nn.MaxPool2d()` was represented as `aten.max_pool2d.default`. ``` - func: max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor ``` The difference is we now return an additional tensor for the max indices. Still, much of the core logic is taken from [`max_pool2d.glsl`](https://github.com/pytorch/pytorch/blob/cceabe873f11c6611f627a3bb0055994952ec6b8/aten/src/ATen/native/vulkan/glsl/max_pool2d.glsl) and [`Pool.cpp`](https://github.com/pytorch/pytorch/blob/cceabe873f11c6611f627a3bb0055994952ec6b8/aten/src/ATen/native/vulkan/ops/Pool.cpp). We provide only a `CHANNELS_PACKED` implementation. ## The Smoke Test Given any input and kernel sizes, we fill the input tensor with increasing values, e.g., ``` tensor([[[10., 11., 12., 13., 14., 15.], [16., 17., 18., 19., 20., 21.], [22., 23., 24., 25., 26., 27.], [28., 29., 30., 31., 32., 33.]]]) ``` With this setup, the max number for each pool is always in the lower-right. We use the kernel size to compute the size of the lower-right block and verify that 1. the output tensor values match the lower-right block values, and 2. the index tensor values match the lower-right block indices. 
``` tensor([[[ 18., 19., 20., 21.], [ 24., 25., 26., 27.], [ 30., 31., 32., 33.]]]) tensor([[[ 8, 9, 10, 11], [14, 15, 16, 17], [20, 21, 22, 23]]]) ``` ghstack-source-id: 219506264 exported-using-ghexport bypass-github-export-checks Reviewed By: SS-JIA Differential Revision: D54961929 fbshipit-source-id: 277a629e973dd72d49fc059ace2d7a9a9388ab36
1 parent 8cf4d1f commit a8a49a9

File tree

9 files changed

+423
-3
lines changed

9 files changed

+423
-3
lines changed

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
4141
exir_ops.edge.aten.relu.default,
4242
# Matrix multiplication operators
4343
exir_ops.edge.aten.mm.default,
44+
# Pooling operators
45+
exir_ops.edge.aten.max_pool2d_with_indices.default,
4446
# Other
4547
operator.getitem,
4648
]
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
#define FLT_MIN -3.402823466e+38
13+
14+
#include "indexing_utils.h"
15+
16+
layout(std430) buffer;
17+
18+
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
19+
layout(set = 0, binding = 1, ${IMAGE_FORMAT["int"]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM]["int"]} image_idx;
20+
layout(set = 0, binding = 2) uniform PRECISION sampler3D image_in;
21+
22+
layout(set = 0, binding = 3) uniform PRECISION restrict OutExtents {
23+
uvec4 data;
24+
}
25+
out_extents;
26+
27+
layout(set = 0, binding = 4) uniform PRECISION restrict InExtents {
28+
uvec4 data;
29+
}
30+
in_extents;
31+
32+
layout(set = 0, binding = 5) uniform PRECISION restrict Params {
33+
ivec2 kernel;
34+
ivec2 stride;
35+
ivec2 padding;
36+
ivec2 dilation;
37+
}
38+
params;
39+
40+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
41+
42+
/*
 * Computes one output texel of a 2D max pool together with the index of the
 * winning input element.
 *
 * Each invocation handles the output position given by gl_GlobalInvocationID.
 * The pooling window is walked in steps of `params.dilation`; taps that fall
 * outside the input extents (i.e. in the padding region) are skipped. The four
 * components of a texel are pooled independently of one another.
 *
 * The stored index of an input element at (x, y) is `x + in_extents.data.x * y`.
 * Because the comparison is strict, the first maximum encountered keeps its
 * index when several taps tie.
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  // Invocations outside the output extents have nothing to write.
  if (any(greaterThanEqual(pos, out_extents.data.xyz))) {
    return;
  }

  // Top-left corner of the pooling window in input coordinates; negative
  // values land in the padding region.
  const ivec2 ipos = pos.xy * params.stride - params.padding;

  const ivec2 start = ipos;
  const ivec2 end = ipos + params.kernel * params.dilation;

  vec4 out_texel = vec4(FLT_MIN);
  ivec4 idx_texel = ivec4(0);

  for (int y = start.y; y < end.y; y += params.dilation.y) {
    for (int x = start.x; x < end.x; x += params.dilation.x) {
      // Padding taps can never exceed the running maximum, so they are simply
      // skipped.
      if ((x >= 0 && x < in_extents.data.x) && (y >= 0 && y < in_extents.data.y)) {
        const vec4 cur_texel = texelFetch(image_in, ivec3(x, y, pos.z), 0);

        // Set idx if value is greatest in the pool; else, keep the existing
        // idx. The boolean selector picks the integer overload of mix(), so
        // the index never round-trips through floating point.
        const ivec4 cur_idx = ivec4(x + int(in_extents.data.x) * y);
        const bvec4 is_greater = greaterThan(cur_texel, out_texel);
        idx_texel = mix(idx_texel, cur_idx, is_greater);

        out_texel = max(cur_texel, out_texel);
      }
    }
  }

  imageStore(image_out, pos, out_texel);
  imageStore(image_idx, pos, idx_texel);
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
max_pool2d:
8+
parameter_names_with_default_values:
9+
NDIM: 3
10+
DTYPE: float
11+
generate_variant_forall:
12+
DTYPE:
13+
- VALUE: half
14+
SUFFIX: half
15+
- VALUE: float
16+
SUFFIX: float
17+
shader_variants:
18+
- NAME: max_pool2d
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
12+
13+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
14+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
15+
16+
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
17+
18+
namespace at {
19+
namespace native {
20+
namespace vulkan {
21+
22+
/*
 * Resize callback for the max_pool2d node: recomputes the output sizes from
 * the current input sizes and applies them to both outputs.
 *
 * args[0] carries the two outputs (pooled values, then indices) and args[1]
 * carries the input. extra_args holds, in order: kernel, stride, padding,
 * dilation and ceil_mode.
 */
void resize_max_pool2d_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  const auto& out_refs = args[0].refs;
  vTensor& out = graph->get_val(out_refs[0]).toTensor();
  vTensor& indices = graph->get_val(out_refs[1]).toTensor();
  vTensor& self = graph->get_val(args[1].refs[0]).toTensor();

  const size_t ndim = self.sizes().size();
  std::vector<int64_t> new_out_sizes(ndim);

  // Pooling leaves the batch (4-D inputs only) and channel extents as they are.
  if (ndim == 4) {
    new_out_sizes.at(ndim - 4) = self.sizes().at(ndim - 4);
  }
  new_out_sizes.at(ndim - 3) = self.sizes().at(ndim - 3);

  const auto kernel = normalize_wh(graph->get_val(extra_args[0]));
  const auto stride = normalize_wh(graph->get_val(extra_args[1]));
  const auto padding = normalize_wh(graph->get_val(extra_args[2]));
  const auto dilation = normalize_wh(graph->get_val(extra_args[3]));
  const bool ceil_mode = graph->get_val(extra_args[4]).toBool();

  // The normalized parameters are stored as (width, height), so component 1
  // pairs with the height dimension and component 0 with the width dimension.
  const auto pooled_extent = [&](const size_t dim, const size_t wh) {
    return calc_out_size(
        self.sizes().at(dim),
        kernel.data[wh],
        stride.data[wh],
        padding.data[wh],
        dilation.data[wh],
        ceil_mode);
  };

  new_out_sizes.at(ndim - 2) = pooled_extent(ndim - 2, 1); // Height
  new_out_sizes.at(ndim - 1) = pooled_extent(ndim - 1, 0); // Width

  // A pooled extent below one means the window does not fit the input.
  VK_CHECK_COND(new_out_sizes.at(ndim - 2) >= 1);
  VK_CHECK_COND(new_out_sizes.at(ndim - 1) >= 1);

  // Values and indices always share the same shape.
  out.virtual_resize(new_out_sizes);
  indices.virtual_resize(new_out_sizes);
}
69+
70+
// Verifies that both the input and the output tensor use the
// TENSOR_CHANNELS_PACKED GPU memory layout; VK_CHECK_COND reports a failure
// when either check does not hold.
void check_max_pool2d_args(const vTensor& in, const vTensor& out) {
71+
VK_CHECK_COND(
72+
check_memory_layout_is(in, api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED));
73+
VK_CHECK_COND(check_memory_layout_is(
74+
out, api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED));
75+
}
76+
77+
/*
 * Appends a max_pool2d execute node to the graph.
 *
 * `in` is the input value (prepacked first if it is a tensor ref), `out` is a
 * value list holding the pooled-values tensor followed by the indices tensor,
 * and the remaining refs carry the pooling hyper-parameters. `ceil_mode` is
 * only forwarded to the resize callback.
 */
void add_max_pool2d_node(
    ComputeGraph& graph,
    const ValueRef in,
    const ValueRef kernel,
    const ValueRef stride,
    const ValueRef padding,
    const ValueRef dilation,
    const ValueRef ceil_mode,
    const ValueRef out) {
  const ValueRef packed_in = prepack_if_tensor_ref(graph, in);
  vTensor& t_in = graph.get_val(packed_in).toTensor();

  const auto& out_val = graph.get_val(out).toValueList();
  vTensor& t_out = graph.get_val(out_val[0]).toTensor();

  check_max_pool2d_args(t_in, t_out);

  // One invocation per output texel.
  const api::utils::uvec3 global_size = t_out.virtual_extents();
  const api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

  // Shader variant is picked from the output tensor's dtype.
  std::stringstream shader_name;
  shader_name << "max_pool2d";
  apply_dtype_suffix(shader_name, t_out);

  KernelParams pool_params{
      normalize_wh(graph.get_val(kernel)),
      normalize_wh(graph.get_val(stride)),
      normalize_wh(graph.get_val(padding)),
      normalize_wh(graph.get_val(dilation)),
  };

  graph.execute_nodes().emplace_back(new ExecuteNode(
      graph,
      VK_KERNEL_FROM_STR(shader_name.str()),
      global_size,
      local_size,
      // Inputs and Outputs: both outputs are written, the input is read.
      {{{out_val[0], out_val[1]}, api::MemoryAccessType::WRITE},
       {packed_in, api::MemoryAccessType::READ}},
      // Shader params buffers: output extents, input extents, pooling params.
      {
          t_out.extents_ubo(),
          t_in.extents_ubo(),
          graph.create_params_buffer(pool_params),
      },
      // Resizing callback and the value refs it reads.
      resize_max_pool2d_node,
      {kernel, stride, padding, dilation, ceil_mode}));
}
126+
127+
/*
 * Operator entry point: unpacks the flat argument list and forwards it to
 * add_max_pool2d_node. The list is read as input, kernel, stride, padding,
 * dilation, ceil_mode and finally the output value list.
 */
void max_pool2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  const ValueRef in = args[0];
  const ValueRef kernel = args[1];
  const ValueRef stride = args[2];
  const ValueRef padding = args[3];
  const ValueRef dilation = args[4];
  const ValueRef ceil_mode = args[5];
  const ValueRef out = args[6];

  add_max_pool2d_node(
      graph, in, kernel, stride, padding, dilation, ceil_mode, out);
}
131+
132+
// Registers max_pool2d as the implementation of the
// aten.max_pool2d_with_indices.default operator.
REGISTER_OPERATORS {
133+
VK_REGISTER_OP(aten.max_pool2d_with_indices.default, max_pool2d);
134+
}
135+
136+
} // namespace vulkan
137+
} // namespace native
138+
} // namespace at
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
10+
11+
namespace at {
12+
namespace native {
13+
namespace vulkan {
14+
15+
/*
 * Computes the output extent of a pooling window along one spatial dimension:
 *
 *   floor((in_size + 2 * padding - dilation * (kernel - 1) - 1) / stride) + 1
 *
 * With ceil_mode the quotient is rounded up instead, and the result is then
 * reduced by one if the last window would start past the input plus its
 * leading padding.
 *
 * The division rounds toward negative infinity. Plain integer division
 * truncates toward zero, which would turn a negative numerator (window larger
 * than the padded input) into a bogus extent of 1 rather than a value below 1
 * that callers can reject.
 *
 * Precondition: stride must be non-zero (a positive stride is expected).
 */
int64_t calc_out_size(
    const int64_t in_size,
    const int64_t kernel,
    const int64_t stride,
    const int64_t padding,
    const int64_t dilation,
    const bool ceil_mode) {
  // Adding stride - 1 before a floor division yields a ceiling division.
  const int64_t round_up = ceil_mode ? stride - 1 : 0;
  const int64_t numerator =
      in_size + 2 * padding - dilation * (kernel - 1) - 1 + round_up;

  // Floor division: step the truncated quotient down when the operands have
  // opposite signs and the division is inexact.
  int64_t quotient = numerator / stride;
  if ((numerator % stride != 0) && ((numerator < 0) != (stride < 0))) {
    --quotient;
  }

  int64_t out_size = quotient + 1;
  // Ensure that the last pooling window starts inside the (left-padded) input.
  if (ceil_mode && (out_size - 1) * stride >= in_size + padding) {
    --out_size;
  }
  return out_size;
}
30+
31+
/*
 * Converts a pooling hyper-parameter value into a two-component vector.
 *
 * A scalar int is broadcast to both components. Otherwise the value is read as
 * an int list and its first two entries are swapped, so that component 0 holds
 * list[1] and component 1 holds list[0]. The list must therefore contain at
 * least two entries; `at()` rejects shorter lists.
 */
api::utils::ivec2 normalize_wh(Value& v) {
  if (v.isInt()) {
    return api::utils::make_ivec2({v.toInt(), v.toInt()});
  }

  // Bind by reference: only two elements are read, so copying the whole list
  // is unnecessary.
  const auto& l = v.toIntList();
  return api::utils::make_ivec2({l.at(1), l.at(0)});
}
39+
40+
} // namespace vulkan
41+
} // namespace native
42+
} // namespace at
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#ifdef USE_VULKAN_API
12+
13+
#include <ATen/native/vulkan/api/api.h>
14+
15+
#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>
16+
17+
namespace at {
18+
namespace native {
19+
namespace vulkan {
20+
21+
// Pooling hyper-parameters handed to the compute shader as a params buffer.
// The field order (kernel, stride, padding, dilation) matches the `Params`
// uniform block declared in the max_pool2d shader. Each field is a
// two-component vector; callers in this change fill them via normalize_wh().
struct KernelParams final {
22+
api::utils::ivec2 kernel;
23+
api::utils::ivec2 stride;
24+
api::utils::ivec2 padding;
25+
api::utils::ivec2 dilation;
26+
};
27+
28+
int64_t calc_out_size(
29+
const int64_t in_size,
30+
const int64_t kernel,
31+
const int64_t stride,
32+
const int64_t padding,
33+
const int64_t dilation,
34+
const bool ceil_mode);
35+
36+
api::utils::ivec2 normalize_wh(Value& v);
37+
38+
} // namespace vulkan
39+
} // namespace native
40+
} // namespace at
41+
42+
#endif /* USE_VULKAN_API */

backends/vulkan/test/utils/test_utils.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,9 +172,17 @@ void fill_vtensor(vTensor& vten, std::vector<float>& data) {
172172
}
173173
}
174174

175-
void fill_vtensor(ComputeGraph& graph, const IOValueRef idx, float val) {
175+
void fill_vtensor(
176+
ComputeGraph& graph,
177+
const IOValueRef idx,
178+
float val,
179+
bool iota) {
176180
std::vector<float> data(graph.get_val(idx.value).toTensor().gpu_numel());
177-
std::fill(data.begin(), data.end(), val);
181+
if (iota) {
182+
std::iota(data.begin(), data.end(), val);
183+
} else {
184+
std::fill(data.begin(), data.end(), val);
185+
}
178186

179187
graph.copy_into_staging(idx.staging, data.data(), data.size());
180188
}

backends/vulkan/test/utils/test_utils.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,11 @@ inline void fill_vtensor(vTensor& vten, float val) {
118118
fill_vtensor(vten, vten_data);
119119
}
120120

121-
void fill_vtensor(ComputeGraph& graph, const IOValueRef idx, float val);
121+
void fill_vtensor(
122+
ComputeGraph& graph,
123+
const IOValueRef idx,
124+
float val,
125+
bool iota = false);
122126

123127
void extract_vtensor(vTensor& vten, std::vector<float>& data);
124128

0 commit comments

Comments
 (0)