Skip to content

Commit 4431dd5

Browse files
committed
Update on "[ET-VK][8/n] Unsqueeze"
Exploit the fact that, we reduce the unsqueeze operation to permute. ``` torch.all(torch.permute(x.unsqueeze(0), [1, 0, 2, 3]) == x.unsqueeze(1)) torch.all(torch.permute(x.unsqueeze(0), [1, 2, 0, 3]) == x.unsqueeze(2)) torch.all(torch.permute(x.unsqueeze(0), [1, 2, 3, 0]) == x.unsqueeze(3)) ``` This diff introduce a minor change to the Permute implementation that it no longer requires the input dimension length to match the length of the permute array. This allows the `unsqueeze` operation to achieve a no-op `unsqueeze(0)` and then apply a permute. Differential Revision: [D56347734](https://our.internmc.facebook.com/intern/diff/D56347734/) [ghstack-poisoned]
2 parents aca7014 + 616c6ab commit 4431dd5

File tree

2 files changed

+2
-3
lines changed

2 files changed

+2
-3
lines changed

backends/vulkan/runtime/graph/ops/impl/Permute.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
#pragma once
10+
911
#include <executorch/backends/vulkan/runtime/api/api.h>
1012

1113
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,8 @@
88

99
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1010

11-
#include <executorch/backends/vulkan/runtime/api/api.h>
12-
1311
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Permute.h>
1412
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
15-
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
1613
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
1714

1815
namespace vkcompute {

0 commit comments

Comments
 (0)