6
6
* LICENSE file in the root directory of this source tree.
7
7
*/
8
8
9
+ #include < executorch/backends/vulkan/runtime/graph/ops/impl/Permute.h>
10
+
9
11
#include < executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10
12
13
+ #include < executorch/backends/vulkan/runtime/api/api.h>
14
+
15
+ #include < executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
11
16
#include < executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
12
17
#include < executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
13
18
#include < executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
@@ -20,53 +25,51 @@ using api::utils::uvec4;
20
25
21
26
void check_args (
22
27
const vTensor& in,
23
- const IntListPtr & permute_dims,
28
+ const std::vector< int64_t > & permute_dims,
24
29
const vTensor& out) {
25
30
VK_CHECK_COND (check_memory_layout_is (in, api::kChannelsPacked ));
26
31
VK_CHECK_COND (check_memory_layout_is (out, api::kChannelsPacked ));
27
32
28
- int64_t in_dim = in.dim ();
33
+ // This implementation doesn't not requires the input tensor to have the same
34
+ // dim size as the argument. The code will work as long as the input tensor's
35
+ // dim size is shorter than the permute dim array. In this case, the code
36
+ // assume size of 1 at the higher dimensions.
37
+
38
+ int64_t out_dim = out.dim ();
29
39
VK_CHECK_COND (
30
- in_dim == permute_dims-> size (),
31
- " Input tensor dim size must match argument" );
40
+ out_dim == permute_dims. size (),
41
+ " Output tensor dim size must match argument" );
32
42
}
33
43
34
44
void add_permute_node (
35
45
ComputeGraph& graph,
36
46
ValueRef in,
37
- ValueRef permute_dims_ref ,
47
+ const std::vector< int64_t >& permute_dims ,
38
48
ValueRef out) {
39
49
vTensorPtr t_in = graph.get_tensor (in);
40
50
vTensorPtr t_out = graph.get_tensor (out);
41
51
42
- IntListPtr permute_dims = graph.get_int_list (permute_dims_ref);
43
-
44
52
check_args (*t_in, permute_dims, *t_out);
45
53
46
- uvec4 in_size{1u , 1u , 1u , 1u }, out_size{1u , 1u , 1u , 1u };
47
54
uvec4 out_dims{0u , 1u , 2u , 3u };
48
55
49
- int64_t in_dim = t_in->dim ();
50
-
51
- std::vector<bool > seen (in_dim);
52
- for (int i = 0 ; i < in_dim; i++) {
53
- int64_t permute_dim = (*permute_dims)[i];
56
+ int64_t out_dim = t_out->dim ();
57
+ std::vector<bool > seen (out_dim);
58
+ for (int i = 0 ; i < t_out->dim (); i++) {
59
+ int64_t permute_dim = permute_dims[i];
54
60
VK_CHECK_COND (
55
61
!seen[permute_dim], " Argument dim " , permute_dim, " is repeated" );
56
62
seen[permute_dim] = true ;
57
63
58
- // Map to 4D tensor dims.
59
- in_size.data [(4u - in_dim) + i] = t_in->size (i);
60
- out_size.data [(4u - in_dim) + i] = t_in->size (permute_dim);
61
- out_dims.data [(4u - in_dim) + i] = permute_dim + (4u - in_dim);
64
+ out_dims.data [(4u - out_dim) + i] = permute_dim + (4u - out_dim);
62
65
}
63
66
64
67
std::string kernel_name = " permute" ;
65
68
kernel_name.reserve (kShaderNameReserve );
66
69
add_dtype_suffix (kernel_name, *t_out);
67
70
68
- uint32_t out_channels = out_size. data [ 1u ] ;
69
- uint32_t in_channels = in_size. data [ 1u ] ;
71
+ uint32_t out_channels = dim_at<Dim4D::Channel>(t_out-> sizes ()) ;
72
+ uint32_t in_channels = dim_at<Dim4D::Channel>(t_in-> sizes ()) ;
70
73
71
74
uint32_t out_c_aligned = api::utils::align_up (out_channels, 4u );
72
75
uint32_t in_c_aligned = api::utils::align_up (in_channels, 4u );
@@ -91,6 +94,16 @@ void add_permute_node(
91
94
{t_out->gpu_sizes_ubo (), graph.create_params_buffer (params)}));
92
95
}
93
96
97
+ void add_permute_node (
98
+ ComputeGraph& graph,
99
+ ValueRef in,
100
+ ValueRef permute_dims_ref,
101
+ ValueRef out) {
102
+ IntListPtr permute_dims = graph.get_int_list (permute_dims_ref);
103
+
104
+ add_permute_node (graph, in, *permute_dims, out);
105
+ }
106
+
94
107
// Registry entry point for the permute op.
// args layout (as forwarded below): [0] input tensor ref, [1] permute dims
// int-list ref, [2] output tensor ref.
void permute(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  const ValueRef in = args[0];
  const ValueRef permute_dims = args[1];
  const ValueRef out = args[2];
  add_permute_node(graph, in, permute_dims, out);
}
0 commit comments