Skip to content

Commit b9e4fb5

Browse files
committed
[ET-VK][ez] Introduce GraphBuilder abstraction
## Context This changeset introduces `VkGraphBuilder`, which parses an `ExportedProgram` and constructs a `VkGraph` for serialization. Most of the graph parsing functionality previously implemented in `vulkan_preprocess` is now in `VkGraphBuilder`. The main motivation for this refactor is to simplify `vulkan_preprocess`. Differential Revision: [D54128091](https://our.internmc.facebook.com/intern/diff/D54128091/) [ghstack-poisoned]
1 parent 33ba563 commit b9e4fb5

File tree

3 files changed

+227
-105
lines changed

3 files changed

+227
-105
lines changed

backends/vulkan/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ define_common_targets()
88
runtime.python_library(
99
name = "vulkan_preprocess",
1010
srcs = [
11+
"serialization/vulkan_graph_builder.py",
1112
"serialization/vulkan_graph_schema.py",
1213
"serialization/vulkan_graph_serialize.py",
1314
"vulkan_preprocess.py",
Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Optional
8+
9+
import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema
10+
11+
import torch
12+
13+
from executorch.exir.tensor import TensorSpec
14+
from torch._export.utils import get_buffer, get_param, is_buffer, is_param
15+
from torch.export import ExportedProgram
16+
from torch.fx import Node
17+
18+
19+
class VkGraphBuilder:
    """
    Walks the nodes of an ExportedProgram's graph module and accumulates the
    operator calls, values, input/output ids and constant tensors needed to
    construct a VkGraph for serialization.
    """

    def __init__(self, program: ExportedProgram) -> None:
        self.program = program

        # OperatorCall entries, appended as call_function nodes are processed
        self.chain = []
        # VkValue entries; the id of a value is its index in this list
        self.values = []
        # Value ids of placeholder nodes that are not params/buffers/constants
        self.input_ids = []
        # Value ids of the nodes feeding the graph's output node
        self.output_ids = []
        # Tensors backing constant values; VkTensor.constant_id indexes this list
        self.const_tensors = []

        # Mapping from torch.fx.Node to VkValue id
        self.node_to_value_ids = {}

    @staticmethod
    def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType:
        """
        Translate a torch dtype to the corresponding VkDataType. Only
        torch.float32 is handled; any other dtype raises AssertionError.
        """
        if torch_dtype == torch.float32:
            return vk_graph_schema.VkDataType.fp32

        raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})")

    def is_constant(self, node: torch.fx.Node) -> bool:
        """
        Returns true if the node's name is listed among the lifted tensor constants
        of the exported program's graph signature
        """
        return (
            node.name in self.program.graph_signature.inputs_to_lifted_tensor_constants
        )

    def is_get_attr_node(self, node: torch.fx.Node) -> bool:
        """
        Returns true if the given node is a get attr node for a tensor of the model
        """
        return isinstance(node, torch.fx.Node) and node.op == "get_attr"

    def is_param_node(self, node: torch.fx.Node) -> bool:
        """
        Check if the given node is backed by a tensor stored with the program, i.e.
        it is a get_attr node, a parameter, a buffer, or a lifted tensor constant
        """
        return (
            self.is_get_attr_node(node)
            or is_param(self.program, node)
            or is_buffer(self.program, node)
            or self.is_constant(node)
        )

    def get_constant(self, node: torch.fx.Node) -> Optional[torch.Tensor]:
        """
        Returns the constant associated with the given node in the exported program.
        Returns None if the node is not a constant within the exported program
        """
        if not self.is_constant(node):
            return None

        constant_name = self.program.graph_signature.inputs_to_lifted_tensor_constants[
            node.name
        ]
        if constant_name in self.program.constants:
            return self.program.constants[constant_name]

        return None

    def get_param_tensor(self, node: torch.fx.Node) -> torch.Tensor:
        """
        Returns the tensor backing a param, buffer, lifted constant or get_attr node.

        Raises RuntimeError if the node is None or is none of the supported kinds,
        and AssertionError if no tensor could be resolved for a supported node.
        """
        tensor = None
        if node is None:
            raise RuntimeError("node is None")
        elif is_param(self.program, node):
            tensor = get_param(self.program, node)
        elif is_buffer(self.program, node):
            tensor = get_buffer(self.program, node)
        elif self.is_constant(node):
            tensor = self.get_constant(node)
        elif self.is_get_attr_node(node):
            # This is a hack to support both lifted and unlifted graph
            try:
                tensor = getattr(node.graph.owning_module, node.target)
            except AttributeError:
                tensor = getattr(self.program.graph_module, node.target)
        else:
            raise RuntimeError(f"unsupported param type, {node.op}.")

        assert tensor is not None, f"no tensor found for node {node.name}"
        return tensor

    def maybe_add_constant_tensor(self, node: Node) -> int:
        """
        If the node is a param node, append its tensor to const_tensors and return
        the index it was stored at. Otherwise return -1.
        """
        const_buffer_idx = -1
        if self.is_param_node(node):
            const_buffer_idx = len(self.const_tensors)
            self.const_tensors.append(self.get_param_tensor(node))

        return const_buffer_idx

    def create_single_vk_value(self, node: Node) -> int:
        """
        Append a VkValue wrapping a VkTensor for the node, built from the TensorSpec
        stored in node.meta["spec"], and record the new value id for the node.
        Returns the id of the newly created value.
        """
        constant_id = self.maybe_add_constant_tensor(node)

        spec = node.meta.get("spec")
        assert isinstance(spec, TensorSpec)
        new_id = len(self.values)
        if node not in self.node_to_value_ids:
            self.node_to_value_ids[node] = new_id
        else:
            current_ids = self.node_to_value_ids[node]
            if isinstance(current_ids, int):
                # Store the promoted list back into the mapping; rebinding only a
                # local name would silently drop new_id.
                self.node_to_value_ids[node] = [current_ids, new_id]
            else:
                current_ids.append(new_id)

        # Negative id indicates that this tensor will have its own dedicated memory.
        mem_obj_id = -1
        if spec.mem_obj_id is not None:
            mem_obj_id = spec.mem_obj_id

        self.values.append(
            vk_graph_schema.VkValue(
                value=vk_graph_schema.VkTensor(
                    datatype=self.get_vk_datatype(spec.dtype),
                    dims=spec.shape,
                    constant_id=constant_id,
                    mem_obj_id=mem_obj_id,
                )
            )
        )
        return new_id

    def create_vk_values_for(self, node: Node):
        """
        Create the value(s) for a node and return the resulting id. Only nodes whose
        spec is a single TensorSpec are supported; anything else raises RuntimeError.
        """
        spec = node.meta.get("spec")
        if isinstance(spec, TensorSpec):
            return self.create_single_vk_value(node)
        else:
            raise RuntimeError(
                "Creating values for nodes with collection types is not supported yet."
            )

    def process_placeholder_node(self, node: Node) -> None:
        """
        Create values for a placeholder node. Placeholders that are not param nodes
        are additionally registered as graph inputs.
        """
        ids = self.create_vk_values_for(node)
        if not self.is_param_node(node):
            if isinstance(ids, int):
                self.input_ids.append(ids)
            else:
                self.input_ids += ids

    def process_call_function_node(self, node: Node) -> None:
        """
        Append an OperatorCall for the node to the chain. The call's args are the
        value ids of all of the node's input nodes followed by the id of the value
        created for the node's own output.
        """
        args = []
        # Add input nodes
        for inp_node in node.all_input_nodes:
            if inp_node not in self.node_to_value_ids:
                raise AssertionError(
                    "Cannot find input to current node in node_to_value_ids. This means "
                    "this node is being serialized before its input which is not allowed."
                )
            args.append(self.node_to_value_ids[inp_node])
        # Add output node
        args.append(self.create_vk_values_for(node))

        self.chain.append(
            vk_graph_schema.OperatorCall(
                name=node.target.__name__,
                args=args,
            ),
        )

    def process_getattr_node(self, node: Node) -> None:
        """
        Create the value (and, through it, the constant tensor entry) for a
        get_attr node.
        """
        self.create_vk_values_for(node)

    def process_output_node(self, node: Node) -> None:
        """
        Register the value id of the output node's first input node as a graph
        output. Only the first input node is considered.
        """
        producer = node.all_input_nodes[0]
        if producer not in self.node_to_value_ids:
            raise AssertionError(
                "Cannot find input to output node in node_to_value_ids. This means the "
                "output node is being serialized before its corresponding internal node "
                "which is not allowed."
            )
        self.output_ids.append(self.node_to_value_ids[producer])

    def process_node(self, node: Node) -> None:
        """
        Dispatch a node to the handler for its op. Raises AssertionError for ops
        other than placeholder, call_function, get_attr and output.
        """
        if node.op == "placeholder":
            self.process_placeholder_node(node)
        elif node.op == "call_function":
            self.process_call_function_node(node)
        elif node.op == "get_attr":
            self.process_getattr_node(node)
        elif node.op == "output":
            self.process_output_node(node)
        else:
            raise AssertionError(f"Unsupported node op: {node.op}")

    def build_graph(self) -> vk_graph_schema.VkGraph:
        """
        Process every node of the program's graph in order and return the resulting
        VkGraph. Every call appends to the builder's lists, so this is meant to be
        called once per builder instance.
        """
        for node in self.program.graph_module.graph.nodes:
            self.process_node(node)

        # constants and shaders are left empty here; presumably they are filled in
        # by serialize_vulkan_graph from const_tensors -- TODO confirm.
        return vk_graph_schema.VkGraph(
            version="0",
            chain=self.chain,
            values=self.values,
            input_ids=self.input_ids,
            output_ids=self.output_ids,
            constants=[],
            shaders=[],
        )

backends/vulkan/vulkan_preprocess.py

Lines changed: 12 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from typing import final, List
7+
from typing import final, List, Optional
88

99
import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema
10+
11+
import torch
12+
from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder
1013
from executorch.backends.vulkan.serialization.vulkan_graph_serialize import (
1114
serialize_vulkan_graph,
1215
)
@@ -23,6 +26,7 @@
2326
from executorch.exir.program._program import _copy_module
2427
from executorch.exir.tensor import TensorSpec
2528
from torch import dtype, float32
29+
from torch._export.utils import get_buffer, get_param, is_buffer, is_param
2630
from torch.fx import Node
2731

2832
DEFAULT_DEBUG_HANDLE = 65535
@@ -44,62 +48,13 @@ def preprocess( # noqa: C901
4448
program: ExportedProgram,
4549
module_compile_spec: List[CompileSpec],
4650
) -> PreprocessResult:
47-
vk_chain = []
48-
vk_values = []
49-
vk_input_ids = []
50-
vk_output_ids = []
51-
const_tensors = []
52-
53-
# Mapping from graph Node to schema VkValue.
54-
node_to_value_ids = {}
55-
56-
def create_single_vk_value(node: Node, constant_id: int = -1) -> int:
57-
spec = node.meta.get("spec")
58-
assert isinstance(spec, TensorSpec)
59-
new_id = len(vk_values)
60-
if node not in node_to_value_ids:
61-
node_to_value_ids[node] = new_id
62-
else:
63-
current_ids = node_to_value_ids[node]
64-
if isinstance(current_ids, int):
65-
current_ids = [current_ids, new_id]
66-
else:
67-
current_ids.append(new_id)
68-
69-
# Negative id indicates that this tensor will have its own dedicated memory.
70-
mem_obj_id = -1
71-
if spec.mem_obj_id is not None:
72-
mem_obj_id = spec.mem_obj_id
73-
74-
vk_values.append(
75-
vk_graph_schema.VkValue(
76-
value=vk_graph_schema.VkTensor(
77-
datatype=VulkanBackend.get_vk_datatype(spec.dtype),
78-
dims=spec.shape,
79-
constant_id=constant_id,
80-
mem_obj_id=mem_obj_id,
81-
)
82-
)
83-
)
84-
return new_id
85-
86-
def create_vk_values_for(node: Node, constant_id: int = -1):
87-
spec = node.meta.get("spec")
88-
89-
if isinstance(spec, TensorSpec):
90-
return create_single_vk_value(node, constant_id)
91-
else:
92-
ids = []
93-
for _ in spec:
94-
ids.append(create_single_vk_value(node, constant_id))
95-
return ids
96-
9751
passes = [
9852
SpecPropPass(),
9953
MemoryPlanningPass("greedy"),
10054
]
10155

10256
new_gm = program.graph_module
57+
10358
for p in passes:
10459
# This is a workaround to allow the memory planning pass to work without
10560
# having to first apply ToOutVarPass(). See the `greedy()` function in
@@ -110,62 +65,14 @@ def create_vk_values_for(node: Node, constant_id: int = -1):
11065
new_gm_res = p(new_gm)
11166
assert new_gm_res is not None
11267
new_gm = new_gm_res.graph_module
113-
_copy_module(program.graph_module, new_gm)
11468

115-
for node in program.graph_module.graph.nodes:
116-
if node.op == "placeholder":
117-
# Input
118-
ids = create_vk_values_for(node)
119-
if isinstance(ids, int):
120-
vk_input_ids.append(ids)
121-
else:
122-
vk_input_ids += ids
123-
elif node.op == "call_function":
124-
# Op
125-
if (
126-
node.all_input_nodes[0] not in node_to_value_ids
127-
or node.all_input_nodes[1] not in node_to_value_ids
128-
):
129-
raise AssertionError(
130-
"Cannot find input(s) for current node in node_to_value_ids. This means this node is being serialized before its input(s) which is not allowed."
131-
)
132-
vk_chain.append(
133-
vk_graph_schema.OperatorCall(
134-
name=node.target.__name__,
135-
args=[
136-
node_to_value_ids[node.all_input_nodes[0]],
137-
node_to_value_ids[node.all_input_nodes[1]],
138-
create_vk_values_for(node),
139-
],
140-
),
141-
)
142-
elif node.op == "get_attr":
143-
constant_id = len(const_tensors)
144-
const_tensors.append(
145-
getattr(node.graph.owning_module, node.target).contiguous()
146-
)
147-
148-
create_vk_values_for(node, constant_id)
69+
_copy_module(program.graph_module, new_gm)
14970

150-
elif node.op == "output":
151-
if node.all_input_nodes[0] not in node_to_value_ids:
152-
raise AssertionError(
153-
"Cannot find input to output node in node_to_value_ids. This means the output node is being serialized before its corresponding internal node which is not allowed."
154-
)
155-
vk_output_ids.append(node_to_value_ids[node.all_input_nodes[0]])
156-
else:
157-
raise RuntimeError(f"Unsupported op, {node.op}, in Vulkan Preprocess")
71+
graph_builder = VkGraphBuilder(program)
72+
vk_graph = graph_builder.build_graph()
15873

159-
# Raw objects (constants and shaders) are populated in the next line's method.
160-
vk_graph = vk_graph_schema.VkGraph(
161-
version="0",
162-
chain=vk_chain,
163-
values=vk_values,
164-
input_ids=vk_input_ids,
165-
output_ids=vk_output_ids,
166-
constants=[],
167-
shaders=[],
168-
)
16974
return PreprocessResult(
170-
processed_bytes=serialize_vulkan_graph(vk_graph, const_tensors, []),
75+
processed_bytes=serialize_vulkan_graph(
76+
vk_graph, graph_builder.const_tensors, []
77+
),
17178
)

0 commit comments

Comments
 (0)