
Commit 1cd1d1d

kick CI, SEV seems to have stopped affecting executorch
[ghstack-poisoned]
2 parents: 3a2516a + b5a6362


43 files changed: +660 −288 lines

.ci/scripts/utils.sh

Lines changed: 1 addition & 2 deletions
@@ -158,8 +158,7 @@ build_executorch_runner() {
 cmake_install_executorch_lib() {
   echo "Installing libexecutorch.a and libportable_kernels.a"
   clean_executorch_install_folders
-  retry cmake -DBUCK2="$BUCK" \
-      -DCMAKE_INSTALL_PREFIX=cmake-out \
+  retry cmake -DCMAKE_INSTALL_PREFIX=cmake-out \
       -DCMAKE_BUILD_TYPE=Release \
       -DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \
       -Bcmake-out .

backends/vulkan/_passes/fuse_quantized_ops.py

Lines changed: 4 additions & 1 deletion
@@ -17,6 +17,7 @@
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.passes import dead_code_elimination_pass

 #################
 ## linear_qcnw ##
@@ -224,6 +225,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
             )

         graph_module.recompile()
-        graph_module = super().call(graph_module).graph_module
+        dead_code_elimination_pass(graph_module)

+        # Re-trace the graph since new nodes were (potentially) inserted
+        graph_module = super().call(graph_module).graph_module
         return PassResult(graph_module, True)
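
The change above runs dead-code elimination before re-tracing instead of re-tracing alone. A minimal sketch of that ordering, assuming a pass derived from ExportPass; the class name and the fusion body are placeholders, only the recompile / DCE / re-trace sequence comes from the diff:

import torch
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.passes import dead_code_elimination_pass


class FuseQuantizedOpsSketch(ExportPass):  # hypothetical name, not the file's class
    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        # ... fuse matching quantized patterns, inserting replacement nodes ...
        graph_module.recompile()
        # Drop the original nodes orphaned by the fusion.
        dead_code_elimination_pass(graph_module)
        # Re-trace the graph since new nodes were (potentially) inserted.
        graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, True)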

backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.

 import logging
-from copy import deepcopy
 from typing import Any, Optional, Set

 import executorch.backends.vulkan.utils as utils
@@ -22,6 +21,7 @@
 from executorch.exir.dialects._ops import ops as exir_ops

 from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.tensor import TensorSpec

 logger: logging.Logger = logging.getLogger("")
 logger.setLevel(logging.INFO)
@@ -52,7 +52,7 @@ def insert_transition_node(
         (arg,),
     )
     clone_node.meta["val"] = arg.meta["val"]
-    clone_node.meta["spec"] = deepcopy(arg.meta["spec"])
+    clone_node.meta["spec"] = TensorSpec.from_tensor(clone_node.meta["val"])
     clone_node.meta["spec"].const = False
     set_memory_metadata(clone_node, storage, layout)
     arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y)
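
For context, the transition node's spec is now rebuilt from its fake-tensor value rather than deep-copied from the producer node. A rough sketch of the resulting pattern; the helper name is illustrative and not part of the commit:

from executorch.exir.tensor import TensorSpec


def init_transition_node_meta(clone_node, arg):  # hypothetical helper
    clone_node.meta["val"] = arg.meta["val"]
    # Derive a fresh spec from the node's FakeTensor value (as in the diff above)
    # instead of deep-copying the producer's spec.
    spec = TensorSpec.from_tensor(clone_node.meta["val"])
    spec.const = False
    clone_node.meta["spec"] = spec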

backends/vulkan/op_registry.py

Lines changed: 25 additions & 8 deletions
@@ -230,6 +230,14 @@ def update_features_impl(op: OpKey):
         exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
         # Symbolic integer ops
         torch.ops.aten.sym_size.int,
+        operator.add,
+        operator.lt,
+        operator.gt,
+        operator.ge,
+        operator.le,
+        # Guard and assert ops
+        torch.ops.aten._assert_scalar.default,
+        torch.ops.aten.sym_constrain_range_for_size.default,
     ]
 )
 def register_ephemeral_op(features: OpFeatures):
@@ -500,7 +508,12 @@ def register_sdpa_with_kv_cache_op(features: OpFeatures):
     return features


-@update_features(["llama::update_cache", "llama::custom_sdpa"])
+@update_features(
+    [
+        "llama::update_cache",
+        "llama::custom_sdpa",
+    ]
+)
 def register_sdpa_ops(features: OpFeatures):
     features.resize_fn = False
     features.buffer_impl = False
@@ -520,8 +533,17 @@ def register_rotary_emb_op(features: OpFeatures):
     return features


-@update_features(exir_ops.edge.aten.view_copy.default)
-def register_view_op(features: OpFeatures):
+@update_features(
+    [
+        exir_ops.edge.aten.clone.default,
+        exir_ops.edge.aten.permute.default,
+        exir_ops.edge.aten.permute_copy.default,
+        exir_ops.edge.aten.select_copy.int,
+        exir_ops.edge.aten.slice_copy.Tensor,
+        exir_ops.edge.aten.view_copy.default,
+    ]
+)
+def register_view_ops(features: OpFeatures):
     features.texture_impl = TextureImplFeatures(
         valid_packed_dims=all_packed_dims,
     )
@@ -538,10 +560,8 @@ def register_view_op(features: OpFeatures):
         # Indexing and lookup
         exir_ops.edge.aten.flip.default,
         exir_ops.edge.aten.index_select.default,
-        exir_ops.edge.aten.select_copy.int,
         # Tensor creation
         exir_ops.edge.aten.arange.start_step,
-        exir_ops.edge.aten.clone.default,
         exir_ops.edge.aten.constant_pad_nd.default,
         exir_ops.edge.aten.full.default,
         exir_ops.edge.aten.full_like.default,
@@ -564,12 +584,9 @@ def register_ported_op(features: OpFeatures):
 # Ops ported from PyTorch Vulkan backend. These ops are in a separate registry becasue they support all packed dimensions
 @update_features(
     [
-        # Indexing and lookup
-        exir_ops.edge.aten.slice_copy.Tensor,
         # Shape Manipulation
         exir_ops.edge.aten.squeeze_copy.dims,
         exir_ops.edge.aten.unsqueeze_copy.default,
-        exir_ops.edge.aten.permute_copy.default,
         # Tensor combination
         exir_ops.edge.aten.cat.default,
         exir_ops.edge.aten.repeat.default,
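
As a quick illustration of the registration pattern this file uses, a single feature function is attached to one op or a list of ops via @update_features. The sketch below would live alongside the existing registrations in op_registry.py; the op list and function name are arbitrary examples, not part of the commit:

@update_features(
    [
        exir_ops.edge.aten.view_copy.default,
        exir_ops.edge.aten.permute_copy.default,
    ]
)
def register_example_ops(features: OpFeatures):
    # update_features, OpFeatures, TextureImplFeatures, and all_packed_dims are
    # defined earlier in op_registry.py; exir_ops comes from the module's imports.
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims=all_packed_dims,
    )
    return features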

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 4 additions & 3 deletions
@@ -146,10 +146,11 @@ def op_node_is_compatible(  # noqa: C901: Function is too complex
     def node_is_compatible(
         self, node: torch.fx.Node, features: Optional[OpFeatures] = None
     ) -> Tuple[bool, str]:
-        if utils.is_symint_node(node):
-            return node.target in vulkan_supported_ops, "Op is compatible"
-        elif utils.is_tensor_node(node):
+        if utils.is_tensor_node(node):
             return self.op_node_is_compatible(node, features=features)
+        # For non-tensor nodes, just check if the op is registered
+        elif hasattr(node, "target"):
+            return node.target in vulkan_supported_ops, "Op is compatible"

         return False, f"Unsupported node type: {node.format_node()}"
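
This ties into the op_registry.py change above: SymInt arithmetic and guard/assert nodes produce no tensor output, so they now fall through to the registry-membership branch. A small illustrative lookup, not part of the commit, with an assumed import path:

import operator

import torch
from executorch.backends.vulkan.op_registry import vulkan_supported_ops  # assumed path

for target in (
    operator.add,  # SymInt arithmetic emitted by dynamic shapes
    torch.ops.aten._assert_scalar.default,  # guard/assert op
):
    # Mirrors the new `node.target in vulkan_supported_ops` check.
    print(target, target in vulkan_supported_ops)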

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
+#define T ${buffer_scalar_type(DTYPE)}
+
+${define_active_storage_type(STORAGE)}
+
+#include "indexing_utils.h"
+
+${define_required_extensions(DTYPE)}
+
+layout(std430) buffer;
+
+${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
+${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)}
+$if STORAGE == "buffer":
+  ${layout_declare_ubo(2, "int", "numel")}
+$else:
+  ${layout_declare_ubo(2, "ivec3", "out_limits")}
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+#include "activations.h"
+
+#ifdef USING_BUFFER
+
+void main() {
+  const int i = int(gl_GlobalInvocationID.x);
+  if (i >= numel) {
+    return;
+  }
+
+  float in_val = float(t_in[i]);
+  t_out[i] = T(tan(in_val));
+}
+
+#else
+
+void main() {
+  const ivec3 pos = ivec3(gl_GlobalInvocationID);
+
+  if (any(greaterThanEqual(pos, out_limits))) {
+    return;
+  }
+
+  VEC4_T in_texel = texelFetch(t_in, pos, 0);
+  imageStore(t_out, pos, VEC4_T(tan(in_texel)));
+}
+
+#endif
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+tan:
+  parameter_names_with_default_values:
+    DTYPE: float
+    STORAGE: texture3d
+  generate_variant_forall:
+    DTYPE:
+      - VALUE: half
+      - VALUE: float
+    STORAGE:
+      - VALUE: texture3d
+      - VALUE: buffer
+  shader_variants:
+    - NAME: tan
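
The two new files above add a tan compute shader (buffer and texture3d variants) and its codegen YAML. For reference, the elementwise behavior both variants compute is plain aten.tan; the standalone check below is illustrative only and not part of the commit:

import torch

x = torch.linspace(-1.0, 1.0, steps=8)
# Each output element is tan(input element), matching t_out[i] = T(tan(in_val))
# in the buffer variant and VEC4_T(tan(in_texel)) in the texture variant.
expected = torch.tan(x)
assert torch.allclose(expected, x.tan())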

backends/vulkan/runtime/graph/ops/impl/Permute.cpp

Lines changed: 83 additions & 37 deletions
@@ -25,10 +25,12 @@ using utils::uvec4;
 namespace {

 void check_args(
-    const api::vTensor& in,
-    const std::vector<int64_t>& permute_dims,
-    const api::vTensor& out) {
-  VK_CHECK_COND(check_same_packed_dim(in, out));
+    ComputeGraph& graph,
+    const ValueRef in,
+    const ValueRef permute_dims,
+    const ValueRef out) {
+  (void)permute_dims;
+  VK_CHECK_COND(check_same_packed_dim(graph, in, out));

   // This implementation doesn't not requires the input tensor to have the same
   // dim size as the argument. The code will work as long as the input tensor's
@@ -38,40 +40,94 @@ void check_args(

 } // namespace

+void resize_permute_node(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  const ValueRef out = args[0].refs[0];
+  const ValueRef in = args[1].refs[0];
+
+  const std::vector<int64_t> in_sizes = graph->sizes_of(in);
+  const std::vector<int64_t> out_sizes = graph->sizes_of(out);
+
+  const std::vector<int64_t> permute_dims =
+      graph->extract_int_or_symint_list(resize_args[0]);
+
+  if (in_sizes.size() == out_sizes.size() &&
+      in_sizes.size() == permute_dims.size()) {
+    std::vector<int64_t> new_out_sizes(out_sizes.size(), 1);
+    const int64_t out_ndim = std::max(in_sizes.size(), out_sizes.size());
+    for (int i = 0; i < out_ndim; i++) {
+      const int64_t permute_dim = permute_dims.at(i);
+      new_out_sizes.at(i) = in_sizes.at(permute_dim);
+    }
+    graph->virtual_resize(out, new_out_sizes);
+  }
+  // Case where permute is being used to implement squeeze
+  else if (
+      in_sizes.size() > out_sizes.size() &&
+      in_sizes.size() == permute_dims.size()) {
+    std::vector<int64_t> new_out_sizes(out_sizes.size(), 1);
+    const size_t offset = in_sizes.size() - out_sizes.size();
+    for (int i = 0; i < out_sizes.size(); i++) {
+      const int64_t permute_dim = permute_dims.at(i + offset);
+      new_out_sizes.at(i) = in_sizes.at(permute_dim);
+    }
+    graph->virtual_resize(out, new_out_sizes);
+  }
+  // Case where Permute is being used to implement unsqueeze
+  else if (
+      in_sizes.size() < out_sizes.size() &&
+      out_sizes.size() == permute_dims.size()) {
+    std::vector<int64_t> new_out_sizes(out_sizes.size(), 1);
+    const size_t offset = out_sizes.size() - in_sizes.size();
+    for (int i = 0; i < out_sizes.size(); i++) {
+      int64_t permute_dim = permute_dims.at(i) - offset;
+      if (permute_dim >= 0) {
+        new_out_sizes.at(i) = in_sizes.at(permute_dim);
+      }
+    }
+    graph->virtual_resize(out, new_out_sizes);
+  } else {
+    VK_THROW("Invalid permute dims");
+  }
+}
+
 void add_permute_node(
     ComputeGraph& graph,
-    ValueRef in,
-    const std::vector<int64_t>& permute_dims,
-    ValueRef out) {
-  vTensorPtr t_in = graph.get_tensor(in);
-  vTensorPtr t_out = graph.get_tensor(out);
-
-  check_args(*t_in, permute_dims, *t_out);
+    const ValueRef in,
+    const ValueRef permute_dims,
+    const ValueRef out) {
+  check_args(graph, in, permute_dims, out);

   ivec4 out_dims{0, 1, 2, 3};

   // Special cases of squeeze/unsqueeze. Because the input dim size can be
-  // different with output dim size. So pick t_in->dim() if squeeze, and
-  // t_out->dim() if unsqueeze to create parameter for permute.
-  int64_t out_ndim = std::max(t_in->dim(), t_out->dim());
+  // different with output dim size. So pick graph.dim_of(in) if squeeze, and
+  // graph.dim_of(out) if unsqueeze to create parameter for permute.
+  const int64_t out_ndim = std::max(graph.dim_of(in), graph.dim_of(out));
   std::vector<bool> seen(out_ndim);
-  for (int i = 0; i < out_ndim; i++) {
-    int64_t permute_dim = permute_dims[i];
-    VK_CHECK_COND(
-        !seen[permute_dim], "Argument dim ", permute_dim, " is repeated");
-    seen[permute_dim] = true;
-
-    out_dims[(4u - out_ndim) + i] = permute_dim + (4 - out_ndim);
+  {
+    IntListPtr permute_dims_ptr = graph.get_int_list(permute_dims);
+    for (int i = 0; i < out_ndim; i++) {
+      int64_t permute_dim = permute_dims_ptr->at(i);
+      VK_CHECK_COND(
+          !seen[permute_dim], "Argument dim ", permute_dim, " is repeated");
+      seen[permute_dim] = true;
+
+      out_dims[(4u - out_ndim) + i] =
+          utils::safe_downcast<int32_t>(permute_dim + (4 - out_ndim));
+    }
   }

   std::string kernel_name = "permute";
   kernel_name.reserve(kShaderNameReserve);
-  add_dtype_suffix(kernel_name, *t_out);
+  add_dtype_suffix(kernel_name, graph.dtype_of(out));

-  int32_t out_channels = dim_at<kChannel4D>(t_out->sizes());
-  int32_t in_channels = dim_at<kChannel4D>(t_in->sizes());
+  const int32_t out_channels = dim_at<kChannel4D>(graph.sizes_of(out));
+  const int32_t in_channels = dim_at<kChannel4D>(graph.sizes_of(in));

-  const auto packed_dim = graph.packed_dim_of(in);
+  const int32_t packed_dim = graph.packed_dim_of(in);
   ivec2 channel_info = {out_channels, in_channels};
   if (packed_dim == WHCN::kChannelsDim) {
     channel_info[0] = utils::align_up_4(channel_info[0]);
@@ -95,19 +151,9 @@ void add_permute_node(
       // Specialization Constants
       spec_vars,
       // Resize Args
-      {},
+      {permute_dims},
       // Resizing Logic
-      nullptr));
-}
-
-void add_permute_node(
-    ComputeGraph& graph,
-    ValueRef in,
-    ValueRef permute_dims_ref,
-    ValueRef out) {
-  IntListPtr permute_dims = graph.get_int_list(permute_dims_ref);
-
-  add_permute_node(graph, in, *permute_dims, out);
+      resize_permute_node));
 }

 void permute(ComputeGraph& graph, const std::vector<ValueRef>& args) {
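
resize_permute_node recomputes the output sizes at resize time, including the cases where permute stands in for squeeze or unsqueeze. A Python mirror of that size computation, shown only to make the three branches concrete and not part of the commit:

from typing import List


def permuted_out_sizes(
    in_sizes: List[int], permute_dims: List[int], out_rank: int
) -> List[int]:
    in_rank = len(in_sizes)
    new_out_sizes = [1] * out_rank
    if in_rank == out_rank == len(permute_dims):
        # Plain permute: out[i] = in[permute_dims[i]]
        for i in range(out_rank):
            new_out_sizes[i] = in_sizes[permute_dims[i]]
    elif in_rank > out_rank and in_rank == len(permute_dims):
        # Permute used as squeeze: leading dims are dropped
        offset = in_rank - out_rank
        for i in range(out_rank):
            new_out_sizes[i] = in_sizes[permute_dims[i + offset]]
    elif in_rank < out_rank and out_rank == len(permute_dims):
        # Permute used as unsqueeze: leading dims become size 1
        offset = out_rank - in_rank
        for i in range(out_rank):
            d = permute_dims[i] - offset
            if d >= 0:
                new_out_sizes[i] = in_sizes[d]
    else:
        raise ValueError("Invalid permute dims")
    return new_out_sizes


# e.g. permuting a (2, 3, 4) tensor with dims (1, 2, 0) yields shape (3, 4, 2)
assert permuted_out_sizes([2, 3, 4], [1, 2, 0], out_rank=3) == [3, 4, 2]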

backends/vulkan/runtime/graph/ops/impl/Permute.h

Lines changed: 3 additions & 3 deletions
@@ -18,8 +18,8 @@ namespace vkcompute {

 void add_permute_node(
     ComputeGraph& graph,
-    ValueRef in,
-    const std::vector<int64_t>& permute_dims,
-    ValueRef out);
+    const ValueRef in,
+    const ValueRef permute_dims,
+    const ValueRef out);

 } // namespace vkcompute
