Skip to content

Commit bb26fea

Browse files
committed
Update on "[ET-VK][Test] aten.max_pool2d_with_indices"
Due to the issues below, we only check equality of the output tensor and not the index tensor. 1. We can't verify index tensors, since VK-float16 vs. CPU-float32 deltas can change which element in a pool is the maximum. That can yield completely different integers in the index tensor, so we verify only the output tensor, not the index tensor. 2. To actually visualize the index tensor, we need to reconstruct the int32 values from the int64 values. The `torch.int64` index tensor is serialized as `int32` in Vulkan, so Python expects int64 while C++ writes to the buffer as though it holds int32. We must therefore apply some computation to reconstruct the tensor; see below for details. A helper function was included in an earlier version of this change, but it was removed for conciseness since we aren't checking the index tensor anyway. For example, if the first and second elements are 16 and 17, Python reads this value as the first element: ``` 73014444048 = 1000100000000000000000000000000010000 ``` We must split this int64 into two int32 values (upper 32 bits and lower 32 bits) and construct a new tensor accordingly. ``` 10001 | 00000000000000000000000000010000 10001 | 10000 17 | 16 ``` Differential Revision: [D54962492](https://our.internmc.facebook.com/intern/diff/D54962492/) [ghstack-poisoned]
2 parents 9e21fec + 76806f6 commit bb26fea

File tree

4 files changed

+50
-32
lines changed

4 files changed

+50
-32
lines changed

backends/vulkan/runtime/graph/ops/glsl/max_pool2d.glsl

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -59,19 +59,11 @@ void main() {
5959
if ((x >= 0 && x < in_extents.data.x) && (y >= 0 && y < in_extents.data.y)) {
6060
const vec4 cur_texel = texelFetch(image_in, ivec3(x, y, pos.z), 0);
6161

62-
const int cur_idx = x + int(in_extents.data.x) * y;
63-
if (cur_texel.x > out_texel.x) {
64-
idx_texel.x = cur_idx;
65-
}
66-
if (cur_texel.y > out_texel.y) {
67-
idx_texel.y = cur_idx;
68-
}
69-
if (cur_texel.z > out_texel.z) {
70-
idx_texel.z = cur_idx;
71-
}
72-
if (cur_texel.w > out_texel.w) {
73-
idx_texel.w = cur_idx;
74-
}
62+
// Set idx if value is greatest in the pool; else, keep the existing idx.
63+
ivec4 cur_idx = ivec4(x + int(in_extents.data.x) * y);
64+
ivec4 mask = ivec4(greaterThan(cur_texel, out_texel));
65+
idx_texel = ivec4(mix(idx_texel, cur_idx, mask));
66+
7567
out_texel = max(cur_texel, out_texel);
7668
}
7769
else {

backends/vulkan/runtime/graph/ops/impl/Pool.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
1212

1313
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
14-
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/ScalarUtils.h>
1514
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
1615

1716
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
10+
11+
namespace at {
12+
namespace native {
13+
namespace vulkan {
14+
15+
/*
 * Computes the output extent along one spatial dimension for a sliding-window
 * op, given the input extent and the window parameters.
 *
 * in_size   - input extent along the dimension
 * kernel    - window size along the dimension
 * stride    - step between consecutive windows
 * padding   - padding applied on each side (counted twice in the total)
 * dilation  - spacing between window elements
 * ceil_mode - when true, round the window count up instead of down
 *
 * Returns the number of output elements along the dimension.
 */
int64_t calc_out_size(
    const int64_t in_size,
    const int64_t kernel,
    const int64_t stride,
    const int64_t padding,
    const int64_t dilation,
    const bool ceil_mode) {
  // Adding (stride - 1) before the integer division rounds the quotient up.
  const int64_t rounding = ceil_mode ? stride - 1 : 0;
  const int64_t numerator =
      in_size + 2 * padding - dilation * (kernel - 1) - 1 + rounding;
  const int64_t result = numerator / stride + 1;

  if (!ceil_mode) {
    return result;
  }

  // In ceil mode, drop the last window when its start offset is at or beyond
  // in_size + padding.
  const bool last_window_out_of_range =
      (result - 1) * stride >= in_size + padding;
  return last_window_out_of_range ? result - 1 : result;
}
30+
31+
/*
 * Converts a scalar-or-list Value into a two-component integer vector.
 *
 * If v holds a single int, that int is used for both components. Otherwise v
 * is read as an int list and the result is {list[1], list[0]}, i.e. the first
 * two list entries in swapped order (presumably (H, W) -> (W, H), going by
 * the function name; confirm against callers).
 */
api::utils::ivec2 normalize_wh(Value& v) {
  // Scalar case: replicate the value into both components.
  if (v.isInt()) {
    return api::utils::make_ivec2({v.toInt(), v.toInt()});
  }

  // List case: emit element 1 first, then element 0.
  auto dims = v.toIntList();
  return api::utils::make_ivec2({dims.at(1), dims.at(0)});
}
39+
40+
} // namespace vulkan
41+
} // namespace native
42+
} // namespace at

backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -31,24 +31,9 @@ int64_t calc_out_size(
3131
const int64_t stride,
3232
const int64_t padding,
3333
const int64_t dilation,
34-
const bool ceil_mode) {
35-
int64_t c = ceil_mode ? stride - 1 : 0;
36-
int64_t out_size =
37-
(in_size + 2 * padding - dilation * (kernel - 1) - 1 + c) / stride + 1;
38-
if (ceil_mode && (out_size - 1) * stride >= in_size + padding) {
39-
--out_size;
40-
}
41-
return out_size;
42-
}
43-
44-
api::utils::ivec2 normalize_wh(Value& v) {
45-
if (v.isInt()) {
46-
return api::utils::make_ivec2({v.toInt(), v.toInt()});
47-
} else {
48-
auto l = v.toIntList();
49-
return api::utils::make_ivec2({l.at(1), l.at(0)});
50-
}
51-
}
34+
const bool ceil_mode);
35+
36+
// Builds an ivec2 from `v`: an int Value is replicated into both components;
// otherwise `v` is read as an int list and the result is {list[1], list[0]}.
api::utils::ivec2 normalize_wh(Value& v);
5237

5338
} // namespace vulkan
5439
} // namespace native

0 commit comments

Comments
 (0)