
Commit f5be29a

[ET-VK][Ops] aten.max_pool2d_with_indices
## The Operator

An `nn.Module` invocation of `torch.nn.MaxPool2d()` is represented as `aten.max_pool2d_with_indices.default` in the Edge Dialect, independent of `return_indices = True/False`.

```
# Return: (Tensor output, Tensor indices)
- func: max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
```

This differs from PT-VK, where `torch.nn.MaxPool2d()` was represented as `aten.max_pool2d.default`:

```
- func: max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
```

The difference is that we now return an additional tensor holding the max indices. Still, much of the core logic is taken from [`max_pool2d.glsl`](https://github.com/pytorch/pytorch/blob/cceabe873f11c6611f627a3bb0055994952ec6b8/aten/src/ATen/native/vulkan/glsl/max_pool2d.glsl) and [`Pool.cpp`](https://github.com/pytorch/pytorch/blob/cceabe873f11c6611f627a3bb0055994952ec6b8/aten/src/ATen/native/vulkan/ops/Pool.cpp). We provide only a `CHANNELS_PACKED` implementation.

## The Smoke Test

Given any input and kernel sizes, we fill the input tensor with increasing values, e.g.,

```
tensor([[[10., 11., 12., 13., 14., 15.],
         [16., 17., 18., 19., 20., 21.],
         [22., 23., 24., 25., 26., 27.],
         [28., 29., 30., 31., 32., 33.]]])
```

With this setup, the maximum of each pooling window is always its lower-right element. We use the kernel size to compute the size of the lower-right block and verify that

1. the output tensor values match the lower-right block values, and
2. the index tensor values match the lower-right block indices.

```
tensor([[[18., 19., 20., 21.],
         [24., 25., 26., 27.],
         [30., 31., 32., 33.]]])
tensor([[[ 8,  9, 10, 11],
         [14, 15, 16, 17],
         [20, 21, 22, 23]]])
```

An eager-mode PyTorch sketch reproducing this expectation is shown after the commit message.

Differential Revision: [D54961929](https://our.internmc.facebook.com/intern/diff/D54961929/)

ghstack-source-id: 219490378

Pull Request resolved: #2547
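A minimal eager-PyTorch sketch of the smoke-test expectation above, purely for reference. The `(2, 3)` kernel size and `stride=1` are assumptions chosen so that the 1x4x6 input reproduces the 3x4 output and index tensors shown; they are not read from the test code itself.

```
import torch

# Input matching the example above: values 10..33 over a 1x4x6 tensor.
x = torch.arange(10.0, 34.0).reshape(1, 4, 6)

# Assumed pooling config: a (2, 3) kernel with stride 1 yields the 3x4 result shown.
pool = torch.nn.MaxPool2d(kernel_size=(2, 3), stride=1, return_indices=True)
out, idx = pool(x)

print(out)  # 18..21 / 24..27 / 30..33 -- each window's lower-right value
print(idx)  # 8..11 / 14..17 / 20..23 -- flattened h * W + w positions within the channel
```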

8 files changed: +405, -3 lines

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 2 additions & 0 deletions
```
@@ -41,6 +41,8 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten.relu.default,
             # Matrix multiplication operators
             exir_ops.edge.aten.mm.default,
+            # Pooling operators
+            exir_ops.edge.aten.max_pool2d_with_indices.default,
             # Other
             operator.getitem,
         ]
```
New file (85 additions & 0 deletions): the `max_pool2d` compute shader (GLSL template)

```
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}
#define FLT_MIN -3.402823466e+38

#include "indexing_utils.h"

layout(std430) buffer;

layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
layout(set = 0, binding = 1, ${IMAGE_FORMAT["int"]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM]["int"]} image_idx;
layout(set = 0, binding = 2) uniform PRECISION sampler3D image_in;

layout(set = 0, binding = 3) uniform PRECISION restrict OutExtents {
  uvec4 data;
}
out_extents;

layout(set = 0, binding = 4) uniform PRECISION restrict InExtents {
  uvec4 data;
}
in_extents;

layout(set = 0, binding = 5) uniform PRECISION restrict Params {
  ivec2 kernel;
  ivec2 stride;
  ivec2 padding;
  ivec2 dilation;
}
params;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  if (any(greaterThanEqual(pos, out_extents.data.xyz))) {
    return;
  }

  const ivec2 ipos = pos.xy * params.stride - params.padding;

  const ivec2 start = ipos;
  const ivec2 end = ipos + params.kernel * params.dilation;

  vec4 out_texel = vec4(FLT_MIN);
  ivec4 idx_texel = ivec4(0);

  for (int y = start.y; y < end.y; y += params.dilation.y) {
    for (int x = start.x; x < end.x; x += params.dilation.x) {
      if ((x >= 0 && x < in_extents.data.x) && (y >= 0 && y < in_extents.data.y)) {
        const vec4 cur_texel = texelFetch(image_in, ivec3(x, y, pos.z), 0);

        const int cur_idx = x + int(in_extents.data.x) * y;
        if (cur_texel.x > out_texel.x) {
          idx_texel.x = cur_idx;
        }
        if (cur_texel.y > out_texel.y) {
          idx_texel.y = cur_idx;
        }
        if (cur_texel.z > out_texel.z) {
          idx_texel.z = cur_idx;
        }
        if (cur_texel.w > out_texel.w) {
          idx_texel.w = cur_idx;
        }
        out_texel = max(cur_texel, out_texel);
      }
      else {
        out_texel = max(vec4(FLT_MIN), out_texel);
      }
    }
  }

  imageStore(image_out, pos, out_texel);
  imageStore(image_idx, pos, idx_texel);
}
```
New file (18 additions & 0 deletions): `max_pool2d` shader codegen YAML

```
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

max_pool2d:
  parameter_names_with_default_values:
    NDIM: 3
    DTYPE: float
  generate_variant_forall:
    DTYPE:
      - VALUE: half
        SUFFIX: half
      - VALUE: float
        SUFFIX: float
  shader_variants:
    - NAME: max_pool2d
```
New file (139 additions & 0 deletions): C++ implementation registering `aten.max_pool2d_with_indices.default` with the Vulkan compute graph

```
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/ScalarUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace at {
namespace native {
namespace vulkan {

void resize_max_pool2d_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  vTensor& out = graph->get_val(args[0].refs[0]).toTensor();
  vTensor& indices = graph->get_val(args[0].refs[1]).toTensor();
  vTensor& self = graph->get_val(args[1].refs[0]).toTensor();

  size_t ndim = self.sizes().size();
  std::vector<int64_t> new_out_sizes(ndim);

  // Batch
  if (ndim == 4) {
    new_out_sizes.at(ndim - 4) = self.sizes().at(ndim - 4);
  }
  // Channel
  new_out_sizes.at(ndim - 3) = self.sizes().at(ndim - 3);

  const auto kernel = normalize_wh(graph->get_val(extra_args[0]));
  const auto stride = normalize_wh(graph->get_val(extra_args[1]));
  const auto padding = normalize_wh(graph->get_val(extra_args[2]));
  const auto dilation = normalize_wh(graph->get_val(extra_args[3]));
  const bool ceil_mode = graph->get_val(extra_args[4]).toBool();

  // Height
  new_out_sizes.at(ndim - 2) = calc_out_size(
      self.sizes().at(ndim - 2),
      kernel.data[1],
      stride.data[1],
      padding.data[1],
      dilation.data[1],
      ceil_mode);
  // Width
  new_out_sizes.at(ndim - 1) = calc_out_size(
      self.sizes().at(ndim - 1),
      kernel.data[0],
      stride.data[0],
      padding.data[0],
      dilation.data[0],
      ceil_mode);

  VK_CHECK_COND(new_out_sizes.at(ndim - 2) >= 1);
  VK_CHECK_COND(new_out_sizes.at(ndim - 1) >= 1);

  out.virtual_resize(new_out_sizes);
  indices.virtual_resize(new_out_sizes);
}

void check_max_pool2d_args(const vTensor& in, const vTensor& out) {
  VK_CHECK_COND(
      check_memory_layout_is(in, api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED));
  VK_CHECK_COND(check_memory_layout_is(
      out, api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED));
}

void add_max_pool2d_node(
    ComputeGraph& graph,
    const ValueRef in,
    const ValueRef kernel,
    const ValueRef stride,
    const ValueRef padding,
    const ValueRef dilation,
    const ValueRef ceil_mode,
    const ValueRef out) {
  ValueRef arg = prepack_if_tensor_ref(graph, in);
  vTensor& t_in = graph.get_val(arg).toTensor();

  const auto& out_val = graph.get_val(out).toValueList();
  vTensor& t_out = graph.get_val(out_val[0]).toTensor();

  check_max_pool2d_args(t_in, t_out);

  api::utils::uvec3 global_size = t_out.virtual_extents();
  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

  std::stringstream kernel_name;
  kernel_name << "max_pool2d";
  apply_dtype_suffix(kernel_name, t_out);

  KernelParams kernel_params{
      normalize_wh(graph.get_val(kernel)),
      normalize_wh(graph.get_val(stride)),
      normalize_wh(graph.get_val(padding)),
      normalize_wh(graph.get_val(dilation)),
  };

  graph.execute_nodes().emplace_back(new ExecuteNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name.str()),
      global_size,
      local_size,
      // Inputs and Outputs
      {{{out_val[0], out_val[1]}, api::MemoryAccessType::WRITE},
       {arg, api::MemoryAccessType::READ}},
      // Shader params buffers
      {
          t_out.extents_ubo(),
          t_in.extents_ubo(),
          graph.create_params_buffer(kernel_params),
      },
      // Resizing
      resize_max_pool2d_node,
      {kernel, stride, padding, dilation, ceil_mode}));
}

void max_pool2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  return add_max_pool2d_node(
      graph, args[0], args[1], args[2], args[3], args[4], args[5], args[6]);
}

REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.max_pool2d_with_indices.default, max_pool2d);
}

} // namespace vulkan
} // namespace native
} // namespace at
```
New file (57 additions & 0 deletions): shared pooling kernel utilities (`KernelParams`, `calc_out_size`, `normalize_wh`)

```
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#ifdef USE_VULKAN_API

#include <ATen/native/vulkan/api/api.h>

#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>

namespace at {
namespace native {
namespace vulkan {

struct KernelParams final {
  api::utils::ivec2 kernel;
  api::utils::ivec2 stride;
  api::utils::ivec2 padding;
  api::utils::ivec2 dilation;
};

int64_t calc_out_size(
    const int64_t in_size,
    const int64_t kernel,
    const int64_t stride,
    const int64_t padding,
    const int64_t dilation,
    const bool ceil_mode) {
  int64_t c = ceil_mode ? stride - 1 : 0;
  int64_t out_size =
      (in_size + 2 * padding - dilation * (kernel - 1) - 1 + c) / stride + 1;
  if (ceil_mode && (out_size - 1) * stride >= in_size + padding) {
    --out_size;
  }
  return out_size;
}

api::utils::ivec2 normalize_wh(Value& v) {
  if (v.isInt()) {
    return api::utils::make_ivec2({v.toInt(), v.toInt()});
  } else {
    auto l = v.toIntList();
    return api::utils::make_ivec2({l.at(1), l.at(0)});
  }
}

} // namespace vulkan
} // namespace native
} // namespace at

#endif /* USE_VULKAN_API */
```
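As a quick sanity check of `calc_out_size` against the smoke-test shapes above (assuming the same hypothetical (2, 3) kernel, stride 1, zero padding, unit dilation, and `ceil_mode = false`): the height works out to (4 + 2*0 - 1*(2 - 1) - 1) / 1 + 1 = 3 and the width to (6 + 2*0 - 1*(3 - 1) - 1) / 1 + 1 = 4, matching the 3x4 output and index tensors in the commit message.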

backends/vulkan/test/utils/test_utils.cpp

Lines changed: 10 additions & 2 deletions
```
@@ -172,9 +172,17 @@ void fill_vtensor(vTensor& vten, std::vector<float>& data) {
   }
 }

-void fill_vtensor(ComputeGraph& graph, const IOValueRef idx, float val) {
+void fill_vtensor(
+    ComputeGraph& graph,
+    const IOValueRef idx,
+    float val,
+    bool iota) {
   std::vector<float> data(graph.get_val(idx.value).toTensor().gpu_numel());
-  std::fill(data.begin(), data.end(), val);
+  if (iota) {
+    std::iota(data.begin(), data.end(), val);
+  } else {
+    std::fill(data.begin(), data.end(), val);
+  }

   graph.copy_into_staging(idx.staging, data.data(), data.size());
 }
```

backends/vulkan/test/utils/test_utils.h

Lines changed: 5 additions & 1 deletion
```
@@ -118,7 +118,11 @@ inline void fill_vtensor(vTensor& vten, float val) {
   fill_vtensor(vten, vten_data);
 }

-void fill_vtensor(ComputeGraph& graph, const IOValueRef idx, float val);
+void fill_vtensor(
+    ComputeGraph& graph,
+    const IOValueRef idx,
+    float val,
+    bool iota = false);

 void extract_vtensor(vTensor& vten, std::vector<float>& data);
```
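A hypothetical call site for the new `iota` flag, matching the smoke test's increasing-value fill; the `graph` and `in_ioref` names are placeholders for a previously built `ComputeGraph` and its input `IOValueRef`, not identifiers taken from this diff:

```
// Hypothetical sketch: fill the pooling input with 10.0, 11.0, 12.0, ...
// so each window's maximum lands in its lower-right corner.
fill_vtensor(graph, in_ioref, /*val=*/10.0f, /*iota=*/true);
```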
