Commit 0c8cfa9

[ET-VK] Support exporting graphs with symbolic shape ops + update view to accept sym_size args (#10997)
## Context

The ultimate goal is to be able to export transformer models with dynamic shapes enabled so that batched prefill can be done. With transformer models, when dynamic shapes are turned on, `sym_size` operators appear in the graph; they are used to determine the `seq_len` of the inputs, i.e. how many tokens are in the input sequence. The `sym_size` operator accepts a tensor and a dim, and extracts the size of the tensor at the specified dim as a symbolic integer. In the transformer model, the `seq_len` symint is used as an argument to `view` operators.

This PR enables exporting graphs with symbolic integer nodes, in particular the `sym_size` operator, and handles the case where symints appear inside a list of ints.

## Changes

* Miscellaneous fixes for errors that occur when symint nodes are encountered
* Add a C++ implementation of symint nodes and register the `sym_size.int` operator
* Enable the view operator to handle a sizes arg that includes symints

Differential Revision: [D75019798](https://our.internmc.facebook.com/intern/diff/D75019798/)
1 parent 6b48e89 commit 0c8cfa9

File tree: 10 files changed (+198 −21 lines)

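For context, a minimal sketch (not part of the commit) of how `sym_size` nodes arise: tracing a module with a dynamic dimension through `torch.export` materializes `torch.ops.aten.sym_size.int` nodes whose outputs feed the size lists of `view`/`reshape` ops. The module and dimension names below are illustrative.

```python
import torch
from torch.export import Dim, export


class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.size(1)            # traced as torch.ops.aten.sym_size.int(x, 1)
        return x.reshape(seq_len, 64)  # the symint feeds a view/reshape size list


ep = export(
    Toy(),
    (torch.randn(1, 8, 64),),
    dynamic_shapes={"x": {1: Dim("seq_len", min=1, max=128)}},
)
# The printed graph contains a node like:
#   sym_size_int: Sym(s0) = torch.ops.aten.sym_size.int(x, 1)
print(ep.graph_module.graph)
```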

backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 9 additions & 3 deletions

```diff
@@ -6,7 +6,7 @@
 
 import logging
 from copy import deepcopy
-from typing import Any, Set
+from typing import Any, Optional, Set
 
 import executorch.backends.vulkan.utils as utils
 
@@ -94,7 +94,7 @@ def __init__(
     def propose_node_storage(
         self,
         node: torch.fx.Node,
-    ) -> VkStorageType:
+    ) -> Optional[VkStorageType]:
         """
         Uses the operator registry to determine the storage type that should be used for
         a given node. The storage type is determined with the following priorities:
@@ -114,6 +114,9 @@ def propose_node_storage(
         opinionated user can be found, then proceed to the last step.
         4. Use the default storage type setting.
         """
+        if not utils.is_tensor_node(node):
+            return None
+
         # The node may have an input/output tensor that is too big to be stored in a
         # texture. In this case, buffer storage must be used. Note that the partitioner
         # has already checked for the fact that buffer storage is supported by the
@@ -154,12 +157,15 @@ def propose_node_layout(
         self,
         node: torch.fx.Node,
         storage: VkStorageType,
-    ) -> VkMemoryLayout:
+    ) -> Optional[VkMemoryLayout]:
         """
         Performs the same steps as propose_node_storage, but detects the memory layout
         that should be used for the specific storage type. The same prioritization logic
         is applied.
         """
+        if not utils.is_tensor_node(node):
+            return None
+
         valid_layouts: Set[VkMemoryLayout] = utils.all_memory_layouts
         # pyre-ignore
         if has_impl(node.target):
```

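In miniature, the new contract of the two proposal methods (a toy sketch with hypothetical stand-ins, not the pass itself): non-tensor nodes such as symints now yield `None`, and callers can skip tagging them instead of crashing on missing tensor metadata.

```python
from enum import Enum
from typing import Optional


class VkStorageType(Enum):
    BUFFER = 0
    TEXTURE_3D = 1


def is_tensor_node(node: dict) -> bool:
    # Stand-in for utils.is_tensor_node.
    return node["kind"] == "tensor"


def propose_node_storage(node: dict) -> Optional[VkStorageType]:
    if not is_tensor_node(node):
        return None  # e.g. the output of aten.sym_size.int
    return VkStorageType.TEXTURE_3D  # normally chosen via the op registry


for node in ({"kind": "tensor"}, {"kind": "symint"}):
    storage = propose_node_storage(node)
    if storage is not None:
        print(f"tagging {node['kind']} node with {storage.name}")
```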
backends/vulkan/op_registry.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -228,6 +228,8 @@ def update_features_impl(op: OpKey):
         exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
         exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
         exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
+        # Symbolic integer ops
+        torch.ops.aten.sym_size.int,
     ]
 )
 def register_ephemeral_op(features: OpFeatures):
@@ -505,6 +507,7 @@ def register_sdpa_ops(features: OpFeatures):
     features.texture_impl = TextureImplFeatures(
         valid_packed_dims={PackedDim.WIDTH},
     )
+    features.resize_fn = True
     return features
 
 
```
backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 2 additions & 3 deletions

```diff
@@ -146,9 +146,8 @@ def op_node_is_compatible(  # noqa: C901: Function is too complex
     def node_is_compatible(
         self, node: torch.fx.Node, features: Optional[OpFeatures] = None
     ) -> Tuple[bool, str]:
-        # TODO(ssjia) support symbolic ints
         if utils.is_symint_node(node):
-            return False, "symint node not supported yet"
+            return node.target in vulkan_supported_ops, "Op is compatible"
         elif utils.is_tensor_node(node):
             return self.op_node_is_compatible(node, features=features)
 
@@ -258,7 +257,7 @@ def _is_node_supported(self, node: torch.fx.Node) -> bool:
         if target not in vulkan_supported_ops:
             # For some ops, i.e. custom ops the name is registered instead of the
             # OpOverload object.
-            if not isinstance(target, str) and target.name() in vulkan_supported_ops:
+            if hasattr(target, "name") and target.name() in vulkan_supported_ops:
                 features = vulkan_supported_ops[target.name()]
             else:
                 self.log_skip(node, "no operator implementation")
```
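A small sketch (not from the commit) of why `hasattr(target, "name")` is the safer guard: a partitioned node's `target` may be an `OpOverload` (which exposes a `name()` method), a plain string for custom ops registered by name, or an ordinary Python callable with no `name` attribute at all. The custom-op name below is hypothetical.

```python
import operator

import torch

# Three kinds of targets a partitioner may encounter:
targets = [
    torch.ops.aten.sym_size.int,  # OpOverload: name() -> "aten::sym_size.int"
    "my_custom_op",               # hypothetical: custom ops can be registered as strings
    operator.getitem,             # plain callable: has no name() method
]

for target in targets:
    if hasattr(target, "name"):
        print(target.name())
    else:
        print(target)
```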

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 32 additions & 0 deletions

```diff
@@ -156,6 +156,38 @@ ComputeGraph::~ComputeGraph() {
   context_->flush();
 }
 
+std::vector<int64_t> ComputeGraph::extract_int_or_symint_list(
+    const ValueRef idx) {
+  const Value& val = values_.at(idx);
+  std::vector<int64_t> result;
+
+  if (val.isIntList()) {
+    // If it's an IntList, return a copy of the list
+    return val.toConstIntList();
+  } else if (val.isValueList()) {
+    // If it's a ValueList, extract each element as an Int or SymInt
+    const std::vector<ValueRef>& value_list = val.toConstValueList();
+    result.reserve(value_list.size());
+
+    for (const ValueRef& ref : value_list) {
+      const Value& element = values_.at(ref);
+      if (element.isInt()) {
+        result.push_back(element.toInt());
+      } else if (element.isSymInt()) {
+        result.push_back(read_symint(ref));
+      } else {
+        VK_THROW(
+            "ValueList element is neither Int nor SymInt, but has type ",
+            element.type());
+      }
+    }
+    return result;
+  }
+
+  VK_THROW(
+      "Cannot extract int or symint list from Value with type ", val.type());
+}
+
 utils::StorageType ComputeGraph::suggested_storage_type() {
   if (config_.enable_storage_type_override) {
     return config_.storage_type_override;
```

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 9 additions & 0 deletions

```diff
@@ -405,6 +405,15 @@
     return values_.at(idx).toString();
   }
 
+  /*
+   * Utility function to extract a list of integers from a ValueRef.
+   * If the ValueRef is an IntList, returns a copy of the list.
+   * If the ValueRef is a ValueList, extracts each element as an Int or SymInt
+   * and returns the resulting list.
+   * Throws an error if the ValueRef is neither an IntList nor a ValueList.
+   */
+  std::vector<int64_t> extract_int_or_symint_list(const ValueRef idx);
+
   template <
       typename T,
       typename std::enable_if<
```
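A Python analogue of the extraction logic documented above (illustration only; the authoritative implementation is the C++ in ComputeGraph.cpp). The `Int`/`SymInt` classes are stand-ins for the runtime's value kinds:

```python
from typing import List, Sequence, Union


class Int:
    def __init__(self, v: int) -> None:
        self.v = v


class SymInt:
    def __init__(self, v: int) -> None:
        self.v = v  # refreshed as input shapes change between invocations


def extract_int_or_symint_list(values: Sequence[Union[Int, SymInt]]) -> List[int]:
    result: List[int] = []
    for element in values:
        if isinstance(element, (Int, SymInt)):
            result.append(element.v)  # the SymInt read mirrors read_symint(ref)
        else:
            raise TypeError(f"element is neither Int nor SymInt: {element!r}")
    return result


# The sizes arg of view(x, [seq_len, 64]) while seq_len currently reads 96:
print(extract_int_or_symint_list([SymInt(96), Int(64)]))  # [96, 64]
```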
Lines changed: 52 additions & 0 deletions (new file)

```diff
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
+
+namespace vkcompute {
+
+void resize_sym_size_node(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& extra_args) {
+  (void)args; // Unused parameter
+
+  ValueRef out_symint_ref = extra_args[0];
+  ValueRef in_tensor_ref = extra_args[1];
+
+  int64_t dim = graph->extract_scalar<int64_t>(extra_args[2]);
+  int64_t size_at_dim = graph->size_at<int64_t>(dim, in_tensor_ref);
+
+  graph->set_symint(out_symint_ref, static_cast<int32_t>(size_at_dim));
+}
+
+/*
+ * This operator takes a tensor and an integer dimension as inputs, and produces
+ * a symint as output. The symint's value is the size of the tensor at the
+ * specified dimension.
+ */
+void sym_size_int(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  ValueRef in_tensor = args[0];
+  ValueRef dim = args[1];
+  ValueRef out_symint = args[2];
+
+  int64_t dim_val = graph.extract_scalar<int64_t>(dim);
+
+  int64_t size_at_dim = graph.size_at<int64_t>(dim_val, in_tensor);
+  graph.set_symint(out_symint, static_cast<int32_t>(size_at_dim));
+
+  graph.execute_nodes().emplace_back(
+      new ExecuteNode(resize_sym_size_node, {out_symint, in_tensor, dim}));
+}
+
+REGISTER_OPERATORS {
+  VK_REGISTER_OP(sym_size.int, sym_size_int);
+}
+
+} // namespace vkcompute
```

backends/vulkan/runtime/graph/ops/impl/View.cpp

Lines changed: 3 additions & 3 deletions

```diff
@@ -48,9 +48,9 @@ void resize_view_node(
   if (extra_args[0] == kDummyValueRef || graph->val_is_none(extra_args[0])) {
     out->virtual_resize(in->sizes());
   } else {
-    IntListPtr view_sizes = graph->get_int_list(extra_args[0]);
-    std::vector<int64_t> out_sizes =
-        compute_out_sizes(in->sizes(), *view_sizes);
+    std::vector<int64_t> view_sizes =
+        graph->extract_int_or_symint_list(extra_args[0]);
+    std::vector<int64_t> out_sizes = compute_out_sizes(in->sizes(), view_sizes);
     out->virtual_resize(out_sizes);
   }
 }
```
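For intuition, a sketch of what a `compute_out_sizes` helper typically does (an assumption based on standard `aten.view` semantics, not the commit's actual implementation): once `extract_int_or_symint_list` has resolved any symints to concrete ints, the helper only needs to handle an optional `-1` wildcard.

```python
import math
from typing import List


def compute_out_sizes(in_sizes: List[int], view_sizes: List[int]) -> List[int]:
    # Standard view-size resolution: at most one -1 entry, inferred so that
    # the total number of elements is preserved.
    numel = math.prod(in_sizes)
    out = list(view_sizes)
    if -1 in out:
        i = out.index(-1)
        known = math.prod(d for d in out if d != -1)
        out[i] = numel // known
    return out


# A dynamic view(x, [seq_len, 64]) whose seq_len symint currently reads 96:
print(compute_out_sizes([1, 96, 64], [96, 64]))  # [96, 64]
print(compute_out_sizes([1, 96, 64], [-1, 64]))  # [96, 64]
```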

backends/vulkan/serialization/vulkan_graph_builder.py

Lines changed: 47 additions & 8 deletions

```diff
@@ -21,6 +21,7 @@
     is_constant,
     is_get_attr_node,
     is_param_node,
+    is_symint_node,
 )
 from executorch.exir.backend.utils import DelegateMappingBuilder
 
@@ -54,6 +55,8 @@ def __init__(
 
         # Mapping from Node to VkValue id
         self.node_to_value_ids = {}
+        # Mapping from const scalar value to created VkValue id
+        self.const_scalar_to_value_ids = {}
 
         # For logging
         self.seen_ops = set()
@@ -128,7 +131,7 @@ def maybe_add_constant_tensor(self, node: Node) -> int:
 
     def create_node_value(self, node: Node) -> int:
         # If the node has been marked as a scalar tensor, create a SymInt instead of a tensor
-        if node.meta.get("vkdg_is_scalar_tensor", False):
+        if is_symint_node(node) or node.meta.get("vkdg_is_scalar_tensor", False):
             new_id = self.create_symint_value()
             self.node_to_value_ids[node] = new_id
             return new_id
@@ -146,21 +149,35 @@ def create_node_value(self, node: Node) -> int:
             self.node_to_value_ids[node] = new_id
             return new_id
         else:
-            raise RuntimeError(f"Cannot create value for spec of type {type(spec)}")
+            raise RuntimeError(
+                f"Cannot create value for node {node} with spec of type {type(spec)}"
+            )
 
     def create_null_value(self) -> int:
         new_id = len(self.values)
         self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Null()))
         return new_id
 
-    def create_scalar_value(self, scalar: _ScalarType) -> int:
+    def get_or_create_scalar_value(self, scalar: _ScalarType) -> int:
+        scalar_key = scalar
+        # Since Python considers 1 and True to be "equivalent" (as well as 0 and False)
+        # to distinguish entries in the dictionary, if scalar is bool then convert it
+        # to a string representation to use as a key for the dictionary
+        if isinstance(scalar, bool):
+            scalar_key = str(scalar)
+
+        if scalar_key in self.const_scalar_to_value_ids:
+            return self.const_scalar_to_value_ids[scalar_key]
+
         new_id = len(self.values)
         if isinstance(scalar, bool):
             self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Bool(scalar)))
         elif isinstance(scalar, int):
             self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Int(scalar)))
         elif isinstance(scalar, float):
             self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Double(scalar)))
+
+        self.const_scalar_to_value_ids[scalar_key] = new_id
         return new_id
 
     def create_symint_value(self) -> int:
@@ -200,28 +217,50 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:
 
     def create_scalar_list_value(self, arg: List[_ScalarType]) -> int:
         new_id = len(self.values)
+
         if len(arg) == 0:
             self.values.append(
                 vk_graph_schema.VkValue(vk_graph_schema.IntList(items=[]))
             )
-        elif isinstance(arg[0], bool):
+
+        all_bool = True
+        all_int = True
+        all_float = True
+        all_int_or_symint = True
+
+        for val in arg:
+            if not isinstance(val, bool):
+                all_bool = False
+            if not isinstance(val, int):
+                all_int = False
+            if not (isinstance(val, Node) and is_symint_node(val)):
+                all_int_or_symint = False
+            if not isinstance(val, float):
+                all_float = False
+
+        if all_bool:
             self.values.append(
                 vk_graph_schema.VkValue(
                     vk_graph_schema.BoolList(items=[cast(bool, e) for e in arg])
                 )
             )
-        elif isinstance(arg[0], int):
+        if all_int:
             self.values.append(
                 vk_graph_schema.VkValue(
                     vk_graph_schema.IntList(items=[cast(int, e) for e in arg])
                 )
             )
-        elif isinstance(arg[0], float):
+        elif all_float:
             self.values.append(
                 vk_graph_schema.VkValue(
                     vk_graph_schema.DoubleList(items=[cast(float, e) for e in arg])
                 )
             )
+        elif all_int_or_symint:
+            return self.create_value_list_value(arg)
+        else:
+            raise NotImplementedError(f"Cannot add value for list {arg}")
+
         return new_id
 
     def create_value_list_value(self, arg: tuple | list) -> int:
@@ -256,11 +295,11 @@ def get_or_create_value_for(self, arg: _Argument):
         ):
             return self.create_null_value()
         elif isinstance(arg, _ScalarType):
-            return self.create_scalar_value(arg)
+            return self.get_or_create_scalar_value(arg)
         elif isinstance(arg, TensorSpec):
             return self.create_tensor_value(arg)
         elif isinstance(arg, list) and (
-            len(arg) == 0 or isinstance(arg[0], _ScalarType)
+            len(arg) == 0 or any(isinstance(val, _ScalarType) for val in arg)
         ):
             # pyre-ignore[6]
             return self.create_scalar_list_value(arg)
```
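The bool-to-string key trick in `get_or_create_scalar_value` exists because Python's `bool` is a subclass of `int`: `True == 1` and `hash(True) == hash(1)`, so a plain dict would silently merge the two entries. A quick illustration (not from the commit):

```python
# bool is a subclass of int, so True/1 and False/0 collide as dict keys.
cache = {}
cache[1] = "int one"
cache[True] = "bool true"   # overwrites the entry keyed by 1
print(cache)                # {1: 'bool true'}

# Keying bools by their string form keeps the entries distinct.
cache = {}
cache[1] = "int one"
cache[str(True)] = "bool true"
print(cache)                # {1: 'int one', 'True': 'bool true'}
```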

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 41 additions & 0 deletions

```diff
@@ -1801,3 +1801,44 @@ def forward(self, x: torch.Tensor):
             LinearModel(n_pca_basis, n_sh_basis, n_gaussians),
             (torch.ones(n_pca_basis),),
         )
+
+    def test_vulkan_backend_sym_size_int(self):
+        """
+        Test the sym_size.int operator with a model that:
+        1. Takes an input tensor with shape [1, M, K]
+        2. Reshapes it to [M, K]
+        3. Applies a linear layer
+        4. Reshapes the output back to [1, M, N]
+        """
+        K = 64  # Input feature dimension
+        N = 32  # Output feature dimension
+
+        class SymSizeModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(K, N)
+
+            def forward(self, x):
+                M = x.size(1)
+
+                reshaped = torch.reshape(x, [M, K])
+                output = self.linear(reshaped)
+                return torch.reshape(output, [1, M, N])
+
+        sample_inputs = (torch.randn(1, 64, K),)
+
+        batch = Dim("batch", min=1, max=128)
+        dynamic_shapes = {"x": {1: batch}}
+
+        test_inputs = [
+            (torch.randn(1, 32, K),),
+            (torch.randn(1, 96, K),),
+            (torch.randn(1, 128, K),),
+        ]
+
+        self.lower_module_and_test_output(
+            SymSizeModel(),
+            sample_inputs,
+            dynamic_shapes=dynamic_shapes,
+            test_inputs=test_inputs,
+        )
```

backends/vulkan/utils.py

Lines changed: 0 additions & 4 deletions

```diff
@@ -101,10 +101,6 @@ def is_tensor_node(node: torch.fx.Node) -> bool:
     """
     Returns true if the given node produces a tensor value, or a collection of tensor values
     """
-    # All nodes with tensor values are tagged by the SpecPropPass transform
-    if "spec" in node.meta:
-        return True
-
     if "val" not in node.meta:
         return False
```

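The removed shortcut became misleading once symint nodes entered the picture: SpecPropPass attaches a "spec" entry to nodes regardless of whether they produce tensors, so its presence cannot distinguish tensor nodes from symint nodes, while the "val" metadata can. A sketch of that distinction (an assumption about how such a check is typically written, not a copy of the utility):

```python
import torch


def is_symint_node(node: torch.fx.Node) -> bool:
    # Symbolic-integer nodes such as aten.sym_size.int carry a SymInt "val".
    return isinstance(node.meta.get("val"), torch.SymInt)


def produces_tensor(node: torch.fx.Node) -> bool:
    # Tensor-producing nodes carry a FakeTensor (a torch.Tensor subclass) "val".
    return isinstance(node.meta.get("val"), torch.Tensor)
```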