
Commit 2dc56dd

angelayi authored and facebook-github-bot committed
Serialize delegate
Summary: Since call_delegate is an executorch concept, we serialize the LoweredModules as string arguments; under the hood each string is a serialized JSON payload.

Reviewed By: tarun292

Differential Revision: D47252889

fbshipit-source-id: 4f017fb6bda8a56f21f734a102c04a78a5853fc0
1 parent 5719d28 commit 2dc56dd
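
Concretely, on the serialized side the first input to an executorch_call_delegate node is carried as an as_string argument whose content is JSON following the LoweredBackendModule schema added in exir/serde/schema.py. A rough sketch of that payload shape (the field values below are made up for illustration; all bytes fields are base64-encoded strings):

import json

# Illustrative only: the real payload is built by serialize_lowered_module()
# in exir/serde/serialize.py.
lowered_module_payload = {
    "backend_id": "BackendWithCompilerDemo",
    "processed_bytes": "AAEC",  # base64 of the backend's compiled blob
    "compile_specs": [{"key": "max_value", "value": "AQ=="}],
    "original_module": {},  # serialized ExportedProgram, elided here
    "original_state_dict": "",  # base64 of the serialized state dict, elided
}
serialized_arg = json.dumps(lowered_module_payload)  # stored as a string argument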


5 files changed, +244 -0 lines changed


exir/serde/TARGETS

Lines changed: 13 additions & 0 deletions
@@ -8,8 +8,21 @@ python_library(
         "serialize.py",
     ],
     deps = [
+        ":schema",
         "//caffe2:torch",
+        "//executorch/backends:compile_spec_schema",
+        "//executorch/exir:delegate",
         "//executorch/exir:lib",
         "//executorch/exir:memory",
     ],
 )
+
+python_library(
+    name = "schema",
+    srcs = [
+        "schema.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+    ],
+)

exir/serde/schema.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+# Additional schema from torch._export.serde.schema that is edge specific
+
+from dataclasses import dataclass
+from typing import List
+
+import torch._export.serde.schema as export_schema
+
+
+@dataclass
+class CompileSpec:
+    key: str
+    value: str
+
+
+@dataclass
+class LoweredBackendModule:
+    backend_id: str
+    processed_bytes: str
+    compile_specs: List[CompileSpec]
+    original_module: export_schema.ExportedProgram
+    original_state_dict: str
exir/serde/serialize.py

Lines changed: 143 additions & 0 deletions
@@ -1,5 +1,6 @@
 # pyre-strict
 
+import base64
 import copy
 import dataclasses
 import json
@@ -8,11 +9,14 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import executorch.exir as exir
+import executorch.exir.delegate as delegate
 import executorch.exir.memory as memory
 import torch
 import torch._export.exported_program as ep
 import torch._export.serde.schema as schema
 import torch._export.serde.serialize as export_serialize
+from executorch.backends.compile_spec_schema import CompileSpec as delegate_CompileSpec
+from executorch.exir.serde.schema import CompileSpec, LoweredBackendModule
 from torch.fx.experimental import symbolic_shapes
 
 
@@ -39,6 +43,16 @@ def handle_call_function(self, node: torch.fx.Node) -> None:
             self.graph_state.nodes.append(ex_node)
             return
 
+        elif node.target is delegate.executorch_call_delegate:
+            ex_node = schema.Node(
+                target=export_serialize.serialize_operator(node.target),
+                inputs=self.serialize_call_delegate_inputs(node.args),
+                outputs=self.serialize_arbitrary_outputs(node),
+                metadata=self.serialize_metadata(node),
+            )
+            self.graph_state.nodes.append(ex_node)
+            return
+
         super().handle_call_function(node)
 
     def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]:
@@ -138,6 +152,71 @@ def serialize_graph(self, graph_module: torch.fx.GraphModule) -> schema.Graph:
         self.original_graph_module: torch.fx.GraphModule = graph_module  # pyre-ignore
         return super().serialize_graph(graph_module)
 
+    def serialize_call_delegate_inputs(
+        self, args  # pyre-ignore
+    ) -> List[schema.NamedArgument]:
+        lowered_module_arg = args[0]
+        delegate_args = args[1:]
+
+        serialized_lowered_module = self.serialize_lowered_module(lowered_module_arg)
+        serialized_lowered_module_arg = schema.NamedArgument(
+            name=lowered_module_arg.target,
+            arg=schema.Argument.create(as_string=serialized_lowered_module),
+        )
+
+        serialized_args = [serialized_lowered_module_arg]
+        for i, arg in enumerate(delegate_args):
+            serialized_args.append(
+                schema.NamedArgument(
+                    name=f"delegate_arg_{i}", arg=self.serialize_input(arg)
+                )
+            )
+        return serialized_args
+
+    def serialize_lowered_module(self, lowered_module_arg: torch.fx.Node) -> str:
+        assert lowered_module_arg.op == "get_attr"
+        assert isinstance(lowered_module_arg.target, str)
+
+        def serialize_bytes(b: bytes) -> str:
+            # We want to serialize the bytes to a string because JSON cannot
+            # serialize bytes.
+            # Since the given bytes may be serialized with any encoding, we
+            # want to first encode with base64, and then decode it with
+            # ascii. During deserialization we can just directly decode with b64
+            # to get the original encoded bytes.
+            return base64.b64encode(b).decode("ascii")
+
+        lowered_module = getattr(
+            lowered_module_arg.graph.owning_module, lowered_module_arg.target
+        )
+        assert isinstance(lowered_module, delegate.LoweredBackendModule)
+
+        serialized_compile_spec = [
+            CompileSpec(cs.key, serialize_bytes(cs.value))
+            for cs in lowered_module.compile_specs
+        ]
+
+        (
+            serialized_original_module,
+            serialized_original_state_dict,
+        ) = ExportedProgramSerializer().serialize(lowered_module.original_module)
+
+        serialized_processed_bytes = serialize_bytes(lowered_module.processed_bytes)
+
+        serialized_lowered_module = LoweredBackendModule(
+            original_module=serialized_original_module,
+            original_state_dict=serialize_bytes(serialized_original_state_dict),
+            processed_bytes=serialized_processed_bytes,
+            compile_specs=serialized_compile_spec,
+            backend_id=lowered_module.backend_id,
+        )
+
+        json_lowered_module = json.dumps(
+            dataclasses.asdict(serialized_lowered_module),
+            cls=export_serialize.EnumEncoder,
+        )
+        return json_lowered_module
+
 
 class ExportedProgramSerializer(export_serialize.ExportedProgramSerializer):
     def serialize(
@@ -186,6 +265,27 @@ def deserialize_node(self, serialized_node: schema.Node, target: Callable) -> None:
             fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
             return
 
+        elif target is delegate.executorch_call_delegate:
+            if (
+                len(serialized_node.outputs) == 1
+                and serialized_node.outputs[0].type == "as_tensor"
+            ):
+                # If it's a single tensor return then we can use the name of the
+                # node itself
+                name = serialized_node.outputs[0].value.name
+            else:
+                # Otherwise FX will make a name for us, and we'll have `getitem`
+                # nodes pointed to that
+                name = None
+
+            args = self.deserialize_call_delegate_inputs(serialized_node.inputs)
+            fx_node = self.graph.create_node("call_function", target, args, {}, name)
+
+            self.deserialize_arbitrary_outputs(serialized_node, fx_node)
+
+            fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
+            return
+
         elif isinstance(target, str):
             # Create a dummy fake op if the target does not exist
             # because we cannot create a call_function node w/o a
@@ -267,6 +367,49 @@ def deserialize_input(self, inp: schema.Argument) -> Any:
 
         return super().deserialize_input(inp)
 
+    # pyre-ignore
+    def deserialize_call_delegate_inputs(
+        self, serialized_inputs: List[schema.NamedArgument]
+    ):
+        serialized_lowered_module = serialized_inputs[0]
+        lowered_module_node = self.deserialize_lowered_module(serialized_lowered_module)
+        serialized_delegate_inputs = serialized_inputs[1:]
+        args = tuple(
+            self.deserialize_input(input.arg) for input in serialized_delegate_inputs
+        )
+        return (lowered_module_node,) + args
+
+    def deserialize_lowered_module(
+        self, serialized_lowered_module_arg: schema.NamedArgument
+    ) -> torch.fx.Node:
+        assert serialized_lowered_module_arg.arg.type == "as_string"
+        lowered_module_str = serialized_lowered_module_arg.arg.value
+        json_lowered_module = json.loads(lowered_module_str)
+        serialized_lowered_module = export_serialize._dict_to_dataclass(
+            LoweredBackendModule, json_lowered_module
+        )
+
+        backend_id = serialized_lowered_module.backend_id
+        processed_bytes = base64.b64decode(serialized_lowered_module.processed_bytes)
+        compile_specs = [
+            delegate_CompileSpec(key=cs.key, value=base64.b64decode(cs.value))
+            for cs in serialized_lowered_module.compile_specs
+        ]
+
+        original_module = ExportedProgramDeserializer().deserialize(
+            serialized_lowered_module.original_module,
+            base64.b64decode(serialized_lowered_module.original_state_dict),
+        )
+
+        lowered_module = delegate.LoweredBackendModule(
+            original_module,
+            backend_id,
+            processed_bytes,
+            compile_specs,
+        )
+        self.module.register_module(serialized_lowered_module_arg.name, lowered_module)
+        return self.graph.get_attr(serialized_lowered_module_arg.name)
+
 
 class ExportedProgramDeserializer(export_serialize.ExportedProgramDeserializer):
     def deserialize(
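
The bytes-to-string convention used by serialize_bytes above, and reversed in deserialize_lowered_module, as a standalone sketch:

import base64

# JSON cannot carry raw bytes, so the serializer base64-encodes them and
# decodes the result as ASCII; the deserializer simply base64-decodes the
# string to recover the original bytes.
processed_bytes = b"\x00\x01\x02compiled-blob"
as_string = base64.b64encode(processed_bytes).decode("ascii")  # JSON-safe
assert base64.b64decode(as_string) == processed_bytes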

exir/tests/TARGETS

Lines changed: 3 additions & 0 deletions
@@ -96,6 +96,9 @@ python_unittest(
     ],
     deps = [
         "//caffe2:torch",
+        "//executorch/backends:backend_api",
+        "//executorch/backends/test:backend_with_compiler_demo",
+        "//executorch/backends/test:op_partitioner_demo",
         "//executorch/exir:lib",
         "//executorch/exir/serde:serialize",
     ],

exir/tests/test_serde.py

Lines changed: 64 additions & 0 deletions
@@ -6,6 +6,11 @@
 import executorch.exir as exir
 
 import torch
+from executorch.backends.backend_api import CompileSpec, to_backend
+from executorch.backends.test.backend_with_compiler_demo import (  # noqa
+    BackendWithCompilerDemo,
+)
+from executorch.backends.test.op_partitioner_demo import AddMulPartitionerDemo
 from executorch.exir.serde.serialize import deserialize, serialize
 from torch._export.exported_program import ExportedProgram as TorchExportedProgram
 from torch.utils import _pytree as pytree
@@ -89,3 +94,62 @@ def get_random_inputs(self):
         model = MyModel()
         inputs = model.get_random_inputs()
         self.check_serde(model, inputs)
+
+    def test_delegate(self) -> None:
+        class SinModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.sin(x)
+
+        sin_module = SinModule()
+        model_inputs = (torch.ones(1),)
+        edgeir_m = exir.capture(
+            sin_module, model_inputs, exir.CaptureConfig(pt2_mode=True)
+        ).to_edge()
+        max_value = model_inputs[0].shape[0]
+        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
+        lowered_sin_module = to_backend(
+            "BackendWithCompilerDemo", edgeir_m, compile_specs
+        )
+
+        class CompositeModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lowered_linear_sin = lowered_sin_module
+
+            def forward(self, x):
+                return self.lowered_linear_sin(x)
+
+        composite_model = CompositeModule()
+        model_inputs = (torch.ones(1),)
+
+        composite_model(*model_inputs)
+
+        aten = exir.capture(
+            composite_model, model_inputs, exir.CaptureConfig(pt2_mode=True)
+        )
+        aten_new = deserialize(*serialize(aten))
+        self.check_ep(aten, aten_new, model_inputs)
+
+    def test_delegate_partitioner(self) -> None:
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, a, x, b):
+                y = torch.mm(a, x)
+                z = y + b
+                a = z - a
+                y = torch.mm(a, x)
+                z = y + b
+                return z
+
+        m = Model()
+        inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
+
+        ep = exir.capture(m, inputs, exir.CaptureConfig(pt2_mode=True)).to_edge()
+        edge = to_backend(ep, AddMulPartitionerDemo)
+        edge_new = deserialize(*serialize(edge))
+        self.check_ep(edge, edge_new, inputs)
