[ET-VK] Support exporting graphs with symbolic shape ops + update view to accept sym_size args #10997

Merged · 2 commits · May 20, 2025
12 changes: 9 additions & 3 deletions backends/vulkan/_passes/tag_memory_meta_pass.py
@@ -6,7 +6,7 @@

import logging
from copy import deepcopy
from typing import Any, Set
from typing import Any, Optional, Set

import executorch.backends.vulkan.utils as utils

@@ -94,7 +94,7 @@ def __init__(
def propose_node_storage(
self,
node: torch.fx.Node,
) -> VkStorageType:
) -> Optional[VkStorageType]:
"""
Uses the operator registry to determine the storage type that should be used for
a given node. The storage type is determined with the following priorities:
@@ -114,6 +114,9 @@
opinionated user can be found, then proceed to the last step.
4. Use the default storage type setting.
"""
if not utils.is_tensor_node(node):
return None

# The node may have an input/output tensor that is too big to be stored in a
# texture. In this case, buffer storage must be used. Note that the partitioner
# has already checked for the fact that buffer storage is supported by the
@@ -154,12 +157,15 @@ def propose_node_layout(
self,
node: torch.fx.Node,
storage: VkStorageType,
) -> VkMemoryLayout:
) -> Optional[VkMemoryLayout]:
"""
Performs the same steps as propose_node_storage, but detects the memory layout
that should be used for the specific storage type. The same prioritization logic
is applied.
"""
if not utils.is_tensor_node(node):
return None

valid_layouts: Set[VkMemoryLayout] = utils.all_memory_layouts
# pyre-ignore
if has_impl(node.target):
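With `sym_size.int` nodes now reaching this pass, `propose_node_storage` and `propose_node_layout` return `None` for nodes that do not produce tensors, since a symbolic integer has no storage type or memory layout to tag. A minimal toy sketch of that contract (illustrative names only, not the actual pass code):

```python
from typing import Optional

# Stand-in for propose_node_storage: tensor-producing nodes get a storage proposal,
# symbolic-integer nodes get None and are simply skipped by the tagging logic.
def propose_storage(produces_tensor: bool) -> Optional[str]:
    return "TEXTURE_3D" if produces_tensor else None

for name, produces_tensor in [("aten.convolution", True), ("aten.sym_size.int", False)]:
    storage = propose_storage(produces_tensor)
    if storage is None:
        print(f"{name}: non-tensor node, nothing to tag")
    else:
        print(f"{name}: proposed storage {storage}")
```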
3 changes: 3 additions & 0 deletions backends/vulkan/op_registry.py
@@ -228,6 +228,8 @@ def update_features_impl(op: OpKey):
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
# Symbolic integer ops
torch.ops.aten.sym_size.int,
]
)
def register_ephemeral_op(features: OpFeatures):
@@ -505,6 +507,7 @@ def register_sdpa_ops(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
valid_packed_dims={PackedDim.WIDTH},
)
features.resize_fn = True
return features


5 changes: 2 additions & 3 deletions backends/vulkan/partitioner/vulkan_partitioner.py
@@ -146,9 +146,8 @@ def op_node_is_compatible(  # noqa: C901: Function is too complex
def node_is_compatible(
self, node: torch.fx.Node, features: Optional[OpFeatures] = None
) -> Tuple[bool, str]:
# TODO(ssjia) support symbolic ints
if utils.is_symint_node(node):
return False, "symint node not supported yet"
return node.target in vulkan_supported_ops, "Op is compatible"
elif utils.is_tensor_node(node):
return self.op_node_is_compatible(node, features=features)

@@ -258,7 +257,7 @@ def _is_node_supported(self, node: torch.fx.Node) -> bool:
if target not in vulkan_supported_ops:
# For some ops, i.e. custom ops the name is registered instead of the
# OpOverload object.
if not isinstance(target, str) and target.name() in vulkan_supported_ops:
if hasattr(target, "name") and target.name() in vulkan_supported_ops:
features = vulkan_supported_ops[target.name()]
else:
self.log_skip(node, "no operator implementation")
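With symint nodes now considered potentially compatible, graphs exported with dynamic shapes can keep their symbolic shape ops inside the Vulkan partition. A hedged sketch (toy module using the standard `torch.export` API; not taken from this PR) of where `aten.sym_size.int` nodes come from:

```python
import torch
from torch.export import Dim, export

class Reshape(torch.nn.Module):
    def forward(self, x):
        m = x.size(1)  # with a dynamic dim, this becomes torch.ops.aten.sym_size.int
        return x.reshape(m, 64)

ep = export(
    Reshape(),
    (torch.randn(1, 8, 64),),
    dynamic_shapes={"x": {1: Dim("m", min=1, max=128)}},
)
print(ep.graph)
# Expected to contain a node like:
#   sym_size_int = torch.ops.aten.sym_size.int(x, 1)
# feeding the reshape/view's size argument. node_is_compatible now checks such
# symint nodes against vulkan_supported_ops instead of rejecting them outright.
```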
32 changes: 32 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.cpp
@@ -156,6 +156,38 @@ ComputeGraph::~ComputeGraph() {
context_->flush();
}

std::vector<int64_t> ComputeGraph::extract_int_or_symint_list(
const ValueRef idx) {
const Value& val = values_.at(idx);
std::vector<int64_t> result;

if (val.isIntList()) {
// If it's an IntList, return a copy of the list
return val.toConstIntList();
} else if (val.isValueList()) {
// If it's a ValueList, extract each element as an Int or SymInt
const std::vector<ValueRef>& value_list = val.toConstValueList();
result.reserve(value_list.size());

for (const ValueRef& ref : value_list) {
const Value& element = values_.at(ref);
if (element.isInt()) {
result.push_back(element.toInt());
} else if (element.isSymInt()) {
result.push_back(read_symint(ref));
} else {
VK_THROW(
"ValueList element is neither Int nor SymInt, but has type ",
element.type());
}
}
return result;
}

VK_THROW(
"Cannot extract int or symint list from Value with type ", val.type());
}

utils::StorageType ComputeGraph::suggested_storage_type() {
if (config_.enable_storage_type_override) {
return config_.storage_type_override;
9 changes: 9 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.h
@@ -405,6 +405,15 @@ class ComputeGraph final {
return values_.at(idx).toString();
}

/*
* Utility function to extract a list of integers from a ValueRef.
* If the ValueRef is an IntList, returns a copy of the list.
* If the ValueRef is a ValueList, extracts each element as an Int or SymInt
* and returns the resulting list.
* Throws an error if the ValueRef is neither an IntList nor a ValueList.
*/
std::vector<int64_t> extract_int_or_symint_list(const ValueRef idx);

template <
typename T,
typename std::enable_if<
52 changes: 52 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp
@@ -0,0 +1,52 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

namespace vkcompute {

void resize_sym_size_node(
ComputeGraph* graph,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& extra_args) {
(void)args; // Unused parameter

ValueRef out_symint_ref = extra_args[0];
ValueRef in_tensor_ref = extra_args[1];

int64_t dim = graph->extract_scalar<int64_t>(extra_args[2]);
int64_t size_at_dim = graph->size_at<int64_t>(dim, in_tensor_ref);

graph->set_symint(out_symint_ref, static_cast<int32_t>(size_at_dim));
}

/*
* This operator takes a tensor and an integer dimension as inputs, and produces
* a symint as output. The symint's value is the size of the tensor at the
* specified dimension.
*/
void sym_size_int(ComputeGraph& graph, const std::vector<ValueRef>& args) {
ValueRef in_tensor = args[0];
ValueRef dim = args[1];
ValueRef out_symint = args[2];

int64_t dim_val = graph.extract_scalar<int64_t>(dim);

int64_t size_at_dim = graph.size_at<int64_t>(dim_val, in_tensor);
graph.set_symint(out_symint, static_cast<int32_t>(size_at_dim));

graph.execute_nodes().emplace_back(
new ExecuteNode(resize_sym_size_node, {out_symint, in_tensor, dim}));
}

REGISTER_OPERATORS {
VK_REGISTER_OP(sym_size.int, sym_size_int);
}

} // namespace vkcompute
6 changes: 3 additions & 3 deletions backends/vulkan/runtime/graph/ops/impl/View.cpp
@@ -48,9 +48,9 @@ void resize_view_node(
if (extra_args[0] == kDummyValueRef || graph->val_is_none(extra_args[0])) {
out->virtual_resize(in->sizes());
} else {
IntListPtr view_sizes = graph->get_int_list(extra_args[0]);
std::vector<int64_t> out_sizes =
compute_out_sizes(in->sizes(), *view_sizes);
std::vector<int64_t> view_sizes =
graph->extract_int_or_symint_list(extra_args[0]);
std::vector<int64_t> out_sizes = compute_out_sizes(in->sizes(), view_sizes);
out->virtual_resize(out_sizes);
}
}
55 changes: 47 additions & 8 deletions backends/vulkan/serialization/vulkan_graph_builder.py
@@ -21,6 +21,7 @@
is_constant,
is_get_attr_node,
is_param_node,
is_symint_node,
)
from executorch.exir.backend.utils import DelegateMappingBuilder

@@ -54,6 +55,8 @@ def __init__(

# Mapping from Node to VkValue id
self.node_to_value_ids = {}
# Mapping from const scalar value to created VkValue id
self.const_scalar_to_value_ids = {}

# For logging
self.seen_ops = set()
@@ -128,7 +131,7 @@ def maybe_add_constant_tensor(self, node: Node) -> int:

def create_node_value(self, node: Node) -> int:
# If the node has been marked as a scalar tensor, create a SymInt instead of a tensor
if node.meta.get("vkdg_is_scalar_tensor", False):
if is_symint_node(node) or node.meta.get("vkdg_is_scalar_tensor", False):
new_id = self.create_symint_value()
self.node_to_value_ids[node] = new_id
return new_id
@@ -146,21 +149,35 @@ def create_node_value(self, node: Node) -> int:
self.node_to_value_ids[node] = new_id
return new_id
else:
raise RuntimeError(f"Cannot create value for spec of type {type(spec)}")
raise RuntimeError(
f"Cannot create value for node {node} with spec of type {type(spec)}"
)

def create_null_value(self) -> int:
new_id = len(self.values)
self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Null()))
return new_id

def create_scalar_value(self, scalar: _ScalarType) -> int:
def get_or_create_scalar_value(self, scalar: _ScalarType) -> int:
scalar_key = scalar
# Since Python considers 1 and True to be "equivalent" (as well as 0 and False)
# to distinguish entries in the dictionary, if scalar is bool then convert it
# to a string representation to use as a key for the dictionary
if isinstance(scalar, bool):
scalar_key = str(scalar)

if scalar_key in self.const_scalar_to_value_ids:
return self.const_scalar_to_value_ids[scalar_key]

new_id = len(self.values)
if isinstance(scalar, bool):
self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Bool(scalar)))
elif isinstance(scalar, int):
self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Int(scalar)))
elif isinstance(scalar, float):
self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Double(scalar)))

self.const_scalar_to_value_ids[scalar_key] = new_id
return new_id

def create_symint_value(self) -> int:
@@ -200,28 +217,50 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:

def create_scalar_list_value(self, arg: List[_ScalarType]) -> int:
new_id = len(self.values)

if len(arg) == 0:
self.values.append(
vk_graph_schema.VkValue(vk_graph_schema.IntList(items=[]))
)
elif isinstance(arg[0], bool):

all_bool = True
all_int = True
all_float = True
all_int_or_symint = True

for val in arg:
    if not isinstance(val, bool):
        all_bool = False
    if not isinstance(val, int):
        all_int = False
        if not (isinstance(val, Node) and is_symint_node(val)):
            all_int_or_symint = False
    if not isinstance(val, float):
        all_float = False

if all_bool:
self.values.append(
vk_graph_schema.VkValue(
vk_graph_schema.BoolList(items=[cast(bool, e) for e in arg])
)
)
elif isinstance(arg[0], int):
if all_int:
self.values.append(
vk_graph_schema.VkValue(
vk_graph_schema.IntList(items=[cast(int, e) for e in arg])
)
)
elif isinstance(arg[0], float):
elif all_float:
self.values.append(
vk_graph_schema.VkValue(
vk_graph_schema.DoubleList(items=[cast(float, e) for e in arg])
)
)
elif all_int_or_symint:
return self.create_value_list_value(arg)
else:
raise NotImplementedError(f"Cannot add value for list {arg}")

return new_id

def create_value_list_value(self, arg: tuple | list) -> int:
@@ -256,11 +295,11 @@ def get_or_create_value_for(self, arg: _Argument):
):
return self.create_null_value()
elif isinstance(arg, _ScalarType):
return self.create_scalar_value(arg)
return self.get_or_create_scalar_value(arg)
elif isinstance(arg, TensorSpec):
return self.create_tensor_value(arg)
elif isinstance(arg, list) and (
len(arg) == 0 or isinstance(arg[0], _ScalarType)
len(arg) == 0 or any(isinstance(val, _ScalarType) for val in arg)
):
# pyre-ignore[6]
return self.create_scalar_list_value(arg)
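Two details of the serialization changes above are easy to miss. First, `get_or_create_scalar_value` keys booleans by their string form because Python treats `True`/`False` as equal (and equal-hashing) to `1`/`0`. A small standalone demonstration of the pitfall, not part of the PR:

```python
# Python considers True == 1 and False == 0 (and they hash identically), so a cache
# keyed on the raw scalar would silently conflate a Bool value with an Int value.
naive = {}
naive[1] = "Int(1)"
naive[True] = "Bool(True)"      # overwrites the entry stored under key 1
assert len(naive) == 1

# Keying bools by str(scalar), as get_or_create_scalar_value does, keeps them distinct.
deduped = {}
deduped[1] = "Int(1)"
deduped[str(True)] = "Bool(True)"
assert len(deduped) == 2
```

Second, size lists that mix plain ints with `sym_size` nodes no longer fall through to `NotImplementedError`: they are routed to `create_value_list_value` and serialized as a ValueList, which the runtime unpacks with `extract_int_or_symint_list` (used by the updated view op above).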
41 changes: 41 additions & 0 deletions backends/vulkan/test/test_vulkan_delegate.py
@@ -1801,3 +1801,44 @@ def forward(self, x: torch.Tensor):
LinearModel(n_pca_basis, n_sh_basis, n_gaussians),
(torch.ones(n_pca_basis),),
)

def test_vulkan_backend_sym_size_int(self):
"""
Test the sym_size.int operator with a model that:
1. Takes an input tensor with shape [1, M, K]
2. Reshapes it to [M, K]
3. Applies a linear layer
4. Reshapes the output back to [1, M, N]
"""
K = 64 # Input feature dimension
N = 32 # Output feature dimension

class SymSizeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(K, N)

def forward(self, x):
M = x.size(1)

reshaped = torch.reshape(x, [M, K])
output = self.linear(reshaped)
return torch.reshape(output, [1, M, N])

sample_inputs = (torch.randn(1, 64, K),)

batch = Dim("batch", min=1, max=128)
dynamic_shapes = {"x": {1: batch}}

test_inputs = [
(torch.randn(1, 32, K),),
(torch.randn(1, 96, K),),
(torch.randn(1, 128, K),),
]

self.lower_module_and_test_output(
SymSizeModel(),
sample_inputs,
dynamic_shapes=dynamic_shapes,
test_inputs=test_inputs,
)
4 changes: 0 additions & 4 deletions backends/vulkan/utils.py
@@ -101,10 +101,6 @@ def is_tensor_node(node: torch.fx.Node) -> bool:
"""
Returns true if the given node produces a tensor value, or a collection of tensor values
"""
# All nodes with tensor values are tagged by the SpecPropPass transform
if "spec" in node.meta:
return True

if "val" not in node.meta:
return False
