Skip to content

Commit 72a67e2

Browse files
committed
[ET-VK] Add reshape functions for transformers related operators
Pull Request resolved: #11256 ## Changes * Implement resize functions for several operators used in Transformers models ## Motivation Be able to support batched prefill for llama models. ghstack-source-id: 287935585 @exported-using-ghexport Differential Revision: [D75686049](https://our.internmc.facebook.com/intern/diff/D75686049/)
1 parent 5ef38d3 commit 72a67e2

File tree

8 files changed

+154
-76
lines changed

8 files changed

+154
-76
lines changed

backends/vulkan/op_registry.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,12 @@ def register_sdpa_with_kv_cache_op(features: OpFeatures):
500500
return features
501501

502502

503-
@update_features(["llama::update_cache", "llama::custom_sdpa"])
503+
@update_features(
504+
[
505+
"llama::update_cache",
506+
"llama::custom_sdpa",
507+
]
508+
)
504509
def register_sdpa_ops(features: OpFeatures):
505510
features.resize_fn = False
506511
features.buffer_impl = False
@@ -520,8 +525,17 @@ def register_rotary_emb_op(features: OpFeatures):
520525
return features
521526

522527

523-
@update_features(exir_ops.edge.aten.view_copy.default)
524-
def register_view_op(features: OpFeatures):
528+
@update_features(
529+
[
530+
exir_ops.edge.aten.clone.default,
531+
exir_ops.edge.aten.permute.default,
532+
exir_ops.edge.aten.permute_copy.default,
533+
exir_ops.edge.aten.select_copy.int,
534+
exir_ops.edge.aten.slice_copy.Tensor,
535+
exir_ops.edge.aten.view_copy.default,
536+
]
537+
)
538+
def register_view_ops(features: OpFeatures):
525539
features.texture_impl = TextureImplFeatures(
526540
valid_packed_dims=all_packed_dims,
527541
)
@@ -538,10 +552,8 @@ def register_view_op(features: OpFeatures):
538552
# Indexing and lookup
539553
exir_ops.edge.aten.flip.default,
540554
exir_ops.edge.aten.index_select.default,
541-
exir_ops.edge.aten.select_copy.int,
542555
# Tensor creation
543556
exir_ops.edge.aten.arange.start_step,
544-
exir_ops.edge.aten.clone.default,
545557
exir_ops.edge.aten.constant_pad_nd.default,
546558
exir_ops.edge.aten.full.default,
547559
exir_ops.edge.aten.full_like.default,
@@ -564,12 +576,9 @@ def register_ported_op(features: OpFeatures):
564576
# Ops ported from PyTorch Vulkan backend. These ops are in a separate registry because they support all packed dimensions
565577
@update_features(
566578
[
567-
# Indexing and lookup
568-
exir_ops.edge.aten.slice_copy.Tensor,
569579
# Shape Manipulation
570580
exir_ops.edge.aten.squeeze_copy.dims,
571581
exir_ops.edge.aten.unsqueeze_copy.default,
572-
exir_ops.edge.aten.permute_copy.default,
573582
# Tensor combination
574583
exir_ops.edge.aten.cat.default,
575584
exir_ops.edge.aten.repeat.default,

backends/vulkan/runtime/graph/ops/impl/Permute.cpp

Lines changed: 83 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,12 @@ using utils::uvec4;
2525
namespace {
2626

2727
void check_args(
28-
const api::vTensor& in,
29-
const std::vector<int64_t>& permute_dims,
30-
const api::vTensor& out) {
31-
VK_CHECK_COND(check_same_packed_dim(in, out));
28+
ComputeGraph& graph,
29+
const ValueRef in,
30+
const ValueRef permute_dims,
31+
const ValueRef out) {
32+
(void)permute_dims;
33+
VK_CHECK_COND(check_same_packed_dim(graph, in, out));
3234

3335
// This implementation does not require the input tensor to have the same
3436
// dim size as the argument. The code will work as long as the input tensor's
@@ -38,40 +40,94 @@ void check_args(
3840

3941
} // namespace
4042

43+
void resize_permute_node(
44+
ComputeGraph* graph,
45+
const std::vector<ArgGroup>& args,
46+
const std::vector<ValueRef>& resize_args) {
47+
const ValueRef out = args[0].refs[0];
48+
const ValueRef in = args[1].refs[0];
49+
50+
const std::vector<int64_t> in_sizes = graph->sizes_of(in);
51+
const std::vector<int64_t> out_sizes = graph->sizes_of(out);
52+
53+
const std::vector<int64_t> permute_dims =
54+
graph->extract_int_or_symint_list(resize_args[0]);
55+
56+
if (in_sizes.size() == out_sizes.size() &&
57+
in_sizes.size() == permute_dims.size()) {
58+
std::vector<int64_t> new_out_sizes(out_sizes.size(), 1);
59+
const int64_t out_ndim = std::max(in_sizes.size(), out_sizes.size());
60+
for (int i = 0; i < out_ndim; i++) {
61+
const int64_t permute_dim = permute_dims.at(i);
62+
new_out_sizes.at(i) = in_sizes.at(permute_dim);
63+
}
64+
graph->virtual_resize(out, new_out_sizes);
65+
}
66+
// Case where permute is being used to implement squeeze
67+
else if (
68+
in_sizes.size() > out_sizes.size() &&
69+
in_sizes.size() == permute_dims.size()) {
70+
std::vector<int64_t> new_out_sizes(out_sizes.size(), 1);
71+
const size_t offset = in_sizes.size() - out_sizes.size();
72+
for (int i = 0; i < out_sizes.size(); i++) {
73+
const int64_t permute_dim = permute_dims.at(i + offset);
74+
new_out_sizes.at(i) = in_sizes.at(permute_dim);
75+
}
76+
graph->virtual_resize(out, new_out_sizes);
77+
}
78+
// Case where Permute is being used to implement unsqueeze
79+
else if (
80+
in_sizes.size() < out_sizes.size() &&
81+
out_sizes.size() == permute_dims.size()) {
82+
std::vector<int64_t> new_out_sizes(out_sizes.size(), 1);
83+
const size_t offset = out_sizes.size() - in_sizes.size();
84+
for (int i = 0; i < out_sizes.size(); i++) {
85+
int64_t permute_dim = permute_dims.at(i) - offset;
86+
if (permute_dim >= 0) {
87+
new_out_sizes.at(i) = in_sizes.at(permute_dim);
88+
}
89+
}
90+
graph->virtual_resize(out, new_out_sizes);
91+
} else {
92+
VK_THROW("Invalid permute dims");
93+
}
94+
}
95+
4196
void add_permute_node(
4297
ComputeGraph& graph,
43-
ValueRef in,
44-
const std::vector<int64_t>& permute_dims,
45-
ValueRef out) {
46-
vTensorPtr t_in = graph.get_tensor(in);
47-
vTensorPtr t_out = graph.get_tensor(out);
48-
49-
check_args(*t_in, permute_dims, *t_out);
98+
const ValueRef in,
99+
const ValueRef permute_dims,
100+
const ValueRef out) {
101+
check_args(graph, in, permute_dims, out);
50102

51103
ivec4 out_dims{0, 1, 2, 3};
52104

53105
// Special cases of squeeze/unsqueeze. Because the input dim size can be
54-
// different with output dim size. So pick t_in->dim() if squeeze, and
55-
// t_out->dim() if unsqueeze to create parameter for permute.
56-
int64_t out_ndim = std::max(t_in->dim(), t_out->dim());
106+
// different with output dim size. So pick graph.dim_of(in) if squeeze, and
107+
// graph.dim_of(out) if unsqueeze to create parameter for permute.
108+
const int64_t out_ndim = std::max(graph.dim_of(in), graph.dim_of(out));
57109
std::vector<bool> seen(out_ndim);
58-
for (int i = 0; i < out_ndim; i++) {
59-
int64_t permute_dim = permute_dims[i];
60-
VK_CHECK_COND(
61-
!seen[permute_dim], "Argument dim ", permute_dim, " is repeated");
62-
seen[permute_dim] = true;
63-
64-
out_dims[(4u - out_ndim) + i] = permute_dim + (4 - out_ndim);
110+
{
111+
IntListPtr permute_dims_ptr = graph.get_int_list(permute_dims);
112+
for (int i = 0; i < out_ndim; i++) {
113+
int64_t permute_dim = permute_dims_ptr->at(i);
114+
VK_CHECK_COND(
115+
!seen[permute_dim], "Argument dim ", permute_dim, " is repeated");
116+
seen[permute_dim] = true;
117+
118+
out_dims[(4u - out_ndim) + i] =
119+
utils::safe_downcast<int32_t>(permute_dim + (4 - out_ndim));
120+
}
65121
}
66122

67123
std::string kernel_name = "permute";
68124
kernel_name.reserve(kShaderNameReserve);
69-
add_dtype_suffix(kernel_name, *t_out);
125+
add_dtype_suffix(kernel_name, graph.dtype_of(out));
70126

71-
int32_t out_channels = dim_at<kChannel4D>(t_out->sizes());
72-
int32_t in_channels = dim_at<kChannel4D>(t_in->sizes());
127+
const int32_t out_channels = dim_at<kChannel4D>(graph.sizes_of(out));
128+
const int32_t in_channels = dim_at<kChannel4D>(graph.sizes_of(in));
73129

74-
const auto packed_dim = graph.packed_dim_of(in);
130+
const int32_t packed_dim = graph.packed_dim_of(in);
75131
ivec2 channel_info = {out_channels, in_channels};
76132
if (packed_dim == WHCN::kChannelsDim) {
77133
channel_info[0] = utils::align_up_4(channel_info[0]);
@@ -95,19 +151,9 @@ void add_permute_node(
95151
// Specialization Constants
96152
spec_vars,
97153
// Resize Args
98-
{},
154+
{permute_dims},
99155
// Resizing Logic
100-
nullptr));
101-
}
102-
103-
void add_permute_node(
104-
ComputeGraph& graph,
105-
ValueRef in,
106-
ValueRef permute_dims_ref,
107-
ValueRef out) {
108-
IntListPtr permute_dims = graph.get_int_list(permute_dims_ref);
109-
110-
add_permute_node(graph, in, *permute_dims, out);
156+
resize_permute_node));
111157
}
112158

113159
void permute(ComputeGraph& graph, const std::vector<ValueRef>& args) {

backends/vulkan/runtime/graph/ops/impl/Permute.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ namespace vkcompute {
1818

1919
void add_permute_node(
2020
ComputeGraph& graph,
21-
ValueRef in,
22-
const std::vector<int64_t>& permute_dims,
23-
ValueRef out);
21+
const ValueRef in,
22+
const ValueRef permute_dims,
23+
const ValueRef out);
2424

2525
} // namespace vkcompute

backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,20 @@ namespace vkcompute {
1515
void resize_rotary_embedding_node(
1616
ComputeGraph* graph,
1717
const std::vector<ArgGroup>& args,
18-
const std::vector<ValueRef>& extra_args) {
19-
(void)extra_args;
20-
vTensorPtr out = graph->get_tensor(args[0].refs[0]);
21-
vTensorPtr in = graph->get_tensor(args[1].refs[0]);
22-
23-
std::vector<int64_t> in_sizes = in->sizes();
24-
// UNCOMMENT BELOW IF NEEDED
25-
// out->virtual_resize(in_sizes);
18+
const std::vector<ValueRef>& resize_args) {
19+
(void)resize_args;
20+
21+
const ValueRef xq_out = args.at(0).refs.at(0);
22+
const ValueRef xk_out = args.at(0).refs.at(1);
23+
24+
const ValueRef xq = args.at(1).refs.at(0);
25+
const ValueRef xk = args.at(1).refs.at(1);
26+
27+
const std::vector<int64_t> xq_sizes = graph->sizes_of(xq);
28+
const std::vector<int64_t> xk_sizes = graph->sizes_of(xk);
29+
30+
graph->virtual_resize(xq_out, xq_sizes);
31+
graph->virtual_resize(xk_out, xk_sizes);
2632
}
2733

2834
void add_rotary_embedding_node(

backends/vulkan/runtime/graph/ops/impl/Squeeze.cpp

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,28 +17,29 @@ namespace vkcompute {
1717

1818
void add_squeeze_copy_dims_node(
1919
ComputeGraph& graph,
20-
ValueRef in,
21-
ValueRef dims_ref,
22-
ValueRef out) {
23-
vTensorPtr t_in = graph.get_tensor(in);
24-
vTensorPtr t_out = graph.get_tensor(out);
20+
const ValueRef in,
21+
const ValueRef dims_ref,
22+
const ValueRef out) {
23+
const int64_t in_dim = graph.dim_of(in);
24+
const std::vector<int64_t> in_sizes = graph.sizes_of(in);
25+
const std::vector<int64_t> out_sizes = graph.sizes_of(in);
2526

26-
IntListPtr dims = graph.get_int_list(dims_ref);
27+
const std::vector<int64_t> dims = graph.extract_int_or_symint_list(dims_ref);
2728
std::vector<int64_t> squeeze_dims;
2829
// Filter out edge cases where we don't need to squeeze:
2930
// 1. The size of squeeze dim is larger than 1.
3031
// 2. Squeeze the outermost dim
3132
// For these cases, just pass input to output via clone.
32-
for (int i = 0; i < dims->size(); ++i) {
33-
if (dims->at(i) != 0 && t_in->sizes().at(dims->at(i)) == 1) {
34-
squeeze_dims.push_back(dims->at(i));
33+
for (int i = 0; i < dims.size(); ++i) {
34+
if (dims.at(i) != 0 && in_sizes.at(dims.at(i)) == 1) {
35+
squeeze_dims.push_back(dims.at(i));
3536
}
3637
}
3738
if (squeeze_dims.size() == 0) {
3839
add_clone_node(graph, in, out);
3940
} else {
40-
std::vector<int64_t> permute_dims(t_in->dim());
41-
for (int i = 0; i < t_in->dim(); ++i) {
41+
std::vector<int64_t> permute_dims(in_dim);
42+
for (int i = 0; i < in_dim; ++i) {
4243
permute_dims.at(i) = i;
4344
}
4445
for (auto& elem : squeeze_dims) {
@@ -48,7 +49,9 @@ void add_squeeze_copy_dims_node(
4849
std::rotate(permute_dims.begin(), it, it + 1);
4950
}
5051

51-
add_permute_node(graph, in, permute_dims, out);
52+
const ValueRef permute_dims_ref =
53+
graph.add_scalar_list<int64_t>(std::vector<int64_t>(permute_dims));
54+
add_permute_node(graph, in, permute_dims_ref, out);
5255
}
5356
}
5457

backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,16 @@ namespace vkcompute {
1616

1717
void add_unsqueeze_node(
1818
ComputeGraph& graph,
19-
ValueRef in,
20-
ValueRef dim_ref,
21-
ValueRef out) {
22-
vTensorPtr t_in = graph.get_tensor(in);
23-
vTensorPtr t_out = graph.get_tensor(out);
19+
const ValueRef in,
20+
const ValueRef dim_ref,
21+
const ValueRef out) {
22+
const int64_t in_dim = graph.dim_of(in);
23+
const int64_t out_dim = graph.dim_of(out);
2424

2525
VK_CHECK_COND(
26-
t_in->dim() < 4, "Cannot unsqueeze a tensor with more than 3 dimensions");
26+
in_dim < 4, "Cannot unsqueeze a tensor with more than 3 dimensions");
2727

2828
int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
29-
int64_t out_dim = t_out->dim();
3029

3130
std::vector<int64_t> permute_dims(out_dim);
3231
for (int i = 1; i <= dim; i++) {
@@ -38,7 +37,9 @@ void add_unsqueeze_node(
3837
permute_dims[i] = i;
3938
}
4039

41-
add_permute_node(graph, in, permute_dims, out);
40+
const ValueRef permute_dims_ref =
41+
graph.add_scalar_list<int64_t>(std::vector<int64_t>(permute_dims));
42+
add_permute_node(graph, in, permute_dims_ref, out);
4243
}
4344

4445
void unsqueeze(ComputeGraph& graph, const std::vector<ValueRef>& args) {

backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,13 @@ bool check_same_packed_dim(const api::vTensor& t1, const api::vTensor& t2) {
5757
return t1.packed_dim() == t2.packed_dim();
5858
}
5959

60+
bool check_same_packed_dim(
61+
ComputeGraph& graph,
62+
const ValueRef in,
63+
const ValueRef out) {
64+
return graph.packed_dim_of(in) == graph.packed_dim_of(out);
65+
}
66+
6067
bool check_same_packed_dim(
6168
const api::vTensor& t1,
6269
const api::vTensor& t2,

backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#pragma once
1010

1111
#include <executorch/backends/vulkan/runtime/api/api.h>
12+
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
1213

1314
namespace vkcompute {
1415

@@ -38,6 +39,11 @@ bool check_packed_dim_is(const api::vTensor& t, const int32_t packed_dim);
3839

3940
bool check_same_packed_dim(const api::vTensor& t1, const api::vTensor& t2);
4041

42+
bool check_same_packed_dim(
43+
ComputeGraph& graph,
44+
const ValueRef in,
45+
const ValueRef out);
46+
4147
bool check_same_packed_dim(
4248
const api::vTensor& t1,
4349
const api::vTensor& t2,

0 commit comments

Comments
 (0)