Skip to content

Commit b8fbc48

Browse files
[ET-VK] Deserialize VkGraph in ET-VK
Pull Request resolved: #7068 Add logic to deserialize a VkGraph blob back python object. This allows us to get a implement debugging / visualization directly on the vulkan-exported program. Still extra works need to be done: From the entire bundle, need to extract the specific vulkan delegate first. ghstack-source-id: 255454169 Differential Revision: [D66443780](https://our.internmc.facebook.com/intern/diff/D66443780/) Co-authored-by: Justin Yip <[email protected]>
1 parent 633057c commit b8fbc48

File tree

2 files changed

+65
-4
lines changed

2 files changed

+65
-4
lines changed

backends/vulkan/serialization/vulkan_graph_serialize.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
33
#
4+
# pyre-strict
5+
#
46
# This source code is licensed under the BSD-style license found in the
57
# LICENSE file in the root directory of this source tree.
68

@@ -19,9 +21,9 @@
1921
VkBytes,
2022
VkGraph,
2123
)
22-
from executorch.exir._serialize._dataclass import _DataclassEncoder
24+
from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass
2325

24-
from executorch.exir._serialize._flatbuffer import _flatc_compile
26+
from executorch.exir._serialize._flatbuffer import _flatc_compile, _flatc_decompile
2527

2628

2729
def convert_to_flatbuffer(vk_graph: VkGraph) -> bytes:
@@ -40,6 +42,25 @@ def convert_to_flatbuffer(vk_graph: VkGraph) -> bytes:
4042
return output_file.read()
4143

4244

45+
def flatbuffer_to_vk_graph(flatbuffers: bytes) -> VkGraph:
46+
# Following similar (de)serialization logic on other backends:
47+
# https://github.com/pytorch/executorch/blob/main/backends/qualcomm/serialization/qc_schema_serialize.py#L33
48+
with tempfile.TemporaryDirectory() as d:
49+
schema_path = os.path.join(d, "schema.fbs")
50+
with open(schema_path, "wb") as schema_file:
51+
schema_file.write(pkg_resources.resource_string(__name__, "schema.fbs"))
52+
53+
bin_path = os.path.join(d, "schema.bin")
54+
with open(bin_path, "wb") as bin_file:
55+
bin_file.write(flatbuffers)
56+
57+
_flatc_decompile(d, schema_path, bin_path, ["--raw-binary"])
58+
59+
json_path = os.path.join(d, "schema.json")
60+
with open(json_path, "rb") as output_file:
61+
return _json_to_dataclass(json.load(output_file), VkGraph)
62+
63+
4364
@dataclass
4465
class VulkanDelegateHeader:
4566
# Defines the byte region that each component of the header corresponds to

backends/vulkan/test/test_serialization.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
33
#
4+
# pyre-strict
5+
#
46
# This source code is licensed under the BSD-style license found in the
57
# LICENSE file in the root directory of this source tree.
68

@@ -11,9 +13,17 @@
1113

1214
import torch
1315

14-
from executorch.backends.vulkan.serialization.vulkan_graph_schema import VkGraph
16+
from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
17+
IntList,
18+
OperatorCall,
19+
String,
20+
VkGraph,
21+
VkValue,
22+
)
1523

1624
from executorch.backends.vulkan.serialization.vulkan_graph_serialize import (
25+
convert_to_flatbuffer,
26+
flatbuffer_to_vk_graph,
1727
serialize_vulkan_graph,
1828
VulkanDelegateHeader,
1929
)
@@ -36,7 +46,7 @@ def _generate_random_const_tensors(self, num_tensors: int) -> List[torch.Tensor]
3646

3747
return tensors
3848

39-
def test_serialize_vulkan_binary(self):
49+
def test_serialize_vulkan_binary(self) -> None:
4050
vk_graph = VkGraph(
4151
version="0",
4252
chain=[],
@@ -93,3 +103,33 @@ def test_serialize_vulkan_binary(self):
93103

94104
tensor_bytes = bytes(array)
95105
self.assertEqual(constant_data_bytes, tensor_bytes)
106+
107+
def test_serialize_deserialize_vkgraph(self) -> None:
108+
in_vk_graph = VkGraph(
109+
version="1",
110+
chain=[
111+
OperatorCall(node_id=1, name="foo", args=[1, 2, 3]),
112+
OperatorCall(node_id=2, name="bar", args=[]),
113+
],
114+
values=[
115+
VkValue(
116+
value=String(
117+
string_val="abc",
118+
),
119+
),
120+
VkValue(
121+
value=IntList(
122+
items=[-1, -4, 2],
123+
),
124+
),
125+
],
126+
input_ids=[],
127+
output_ids=[],
128+
constants=[],
129+
shaders=[],
130+
)
131+
132+
bs = convert_to_flatbuffer(in_vk_graph)
133+
out_vk_graph = flatbuffer_to_vk_graph(bs)
134+
135+
self.assertEqual(in_vk_graph, out_vk_graph)

0 commit comments

Comments
 (0)