Skip to content

Commit fd76eea

Browse files
committed
Add pass to tag external constants for delegates
Pull Request resolved: #10328 generate pte+ptd file for a delegated linear example ghstack-source-id: 279396874 Differential Revision: [D73281924](https://our.internmc.facebook.com/intern/diff/D73281924/)
1 parent 08c5d93 commit fd76eea

File tree

5 files changed

+80
-3
lines changed

5 files changed

+80
-3
lines changed

backends/xnnpack/operators/node_visitor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -592,8 +592,13 @@ def get_serialized_buffer_index(
592592
xnn_graph.constant_data.append(
593593
ConstantDataOffset(offset=UINT64_MAX, size=size, named_key=named_key)
594594
)
595+
596+
external_tag = tensor.meta.get("xnnpack_constant_tag", None)
595597
self._named_data_store.add_named_data(
596-
named_key, bytes(array), alignment=CONSTANT_TENSOR_ALIGNMENT
598+
named_key,
599+
bytes(array),
600+
alignment=CONSTANT_TENSOR_ALIGNMENT,
601+
external_tag=external_tag,
597602
)
598603

599604
return buffer_idx

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,9 @@ const uint8_t* getConstantDataPtr(
204204
if (!buffer.ok()) {
205205
ET_LOG(
206206
Error,
207-
"Failed to get constant data for key %s",
208-
data_name.c_str());
207+
"Failed to get constant data for key %s from named_data_map. Error code: %u",
208+
data_name.c_str(),
209+
static_cast<uint32_t>(buffer.error()));
209210
return nullptr;
210211
}
211212
const uint8_t* data_ptr =

exir/passes/external_constants_pass.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
# pyre-strict
88

9+
from typing import Callable, Optional
10+
911
import torch
1012
from executorch.exir.pass_base import PassResult
1113
from executorch.exir.tensor import TensorSpec
@@ -74,3 +76,28 @@ def external_mutable_weights_pass(
7476
node.meta["constant_tag"] = "_default_external_constant"
7577
mutated = True
7678
return PassResult(gm, mutated)
79+
80+
81+
def xnnpack_external_constants_pass(
82+
gm: GraphModule,
83+
filter_fn: Optional[Callable[[torch.fx.Node], str]] = None,
84+
) -> PassResult:
85+
"""
86+
Tag external constants before to_backend. Tagged constants will be saved
87+
to an external file.
88+
89+
Args:
90+
gm: GraphModule to tag.
91+
filter_fn: node -> str callable indicating the file (str) that a node should be saved to.
92+
Returns:
93+
PassResult: The resulting gm, and if it was mutated or not.
94+
"""
95+
mutated = False
96+
for module in gm.modules():
97+
if not isinstance(module, torch.fx.GraphModule):
98+
continue
99+
for node in module.graph.nodes:
100+
if node.op == "placeholder" and filter_fn is not None:
101+
node.meta["xnnpack_constant_tag"] = filter_fn(node)
102+
mutated = True
103+
return PassResult(gm, mutated)

test/models/export_delegated_program.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import inspect
1111
import os
1212
import sys
13+
14+
from functools import partial
1315
from typing import Dict, final, Optional, Sequence, Type
1416

1517
import executorch.exir as exir
@@ -21,6 +23,9 @@
2123
from executorch.exir.backend.test.backend_with_compiler_demo import (
2224
BackendWithCompilerDemo,
2325
)
26+
from executorch.exir.passes.external_constants_pass import (
27+
xnnpack_external_constants_pass,
28+
)
2429
from executorch.exir.program import ExecutorchProgramManager
2530
from torch import nn
2631
from torch.export import export
@@ -129,6 +134,7 @@ def export_module_to_program(
129134
constant_tensor_alignment: Optional[int] = None,
130135
delegate_alignment: Optional[int] = None,
131136
method_name: str = "forward",
137+
external_constants: bool = False,
132138
) -> ExecutorchProgramManager:
133139
eager_module = module_class().eval()
134140
inputs = ()
@@ -158,8 +164,16 @@ def forward(self, *args, **kwargs):
158164
XnnpackPartitioner,
159165
)
160166

167+
transform_passes = []
168+
if external_constants:
169+
partial_function = partial(
170+
xnnpack_external_constants_pass,
171+
filter_fn=lambda x: module_class.__name__,
172+
)
173+
transform_passes.append(partial_function)
161174
executorch_program = to_edge_transform_and_lower(
162175
exported_program,
176+
transform_passes=transform_passes,
163177
compile_config=edge_config,
164178
partitioner=[XnnpackPartitioner()],
165179
).to_executorch(config=et_config)
@@ -221,6 +235,11 @@ def main() -> None:
221235
parser.add_argument(
222236
"--delegate_alignment", type=int, default=None, help="Delegate alignment."
223237
)
238+
parser.add_argument(
239+
"--external_constants",
240+
action="store_true",
241+
help="Export the model with all constants saved to an external file.",
242+
)
224243
parser.add_argument(
225244
"--outdir",
226245
type=str,
@@ -247,16 +266,22 @@ def main() -> None:
247266
suffix += "-nosegments"
248267
if args.delegate_alignment is not None:
249268
suffix += f"-da{args.delegate_alignment}"
269+
if args.external_constants:
270+
suffix += f"-e"
250271
outfile = os.path.join(args.outdir, f"{module_name}{suffix}.pte")
251272
executorch_program = export_module_to_program(
252273
module_class,
253274
backend_id=args.backend_id,
254275
extract_delegate_segments=not args.inline_delegate_segments,
255276
delegate_alignment=args.delegate_alignment,
277+
external_constants=args.external_constants,
256278
)
257279
with open(outfile, "wb") as fp:
258280
fp.write(executorch_program.buffer)
259281
print(f"Exported {module_name} and wrote program data to {outfile}")
282+
if args.external_constants:
283+
print(f"Saving external constants to {module_name}.ptd")
284+
executorch_program.write_tensor_data_to_file(args.outdir)
260285

261286

262287
if __name__ == "__main__":

test/models/targets.bzl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,3 +206,22 @@ def define_common_targets():
206206
],
207207
env = {"PYTORCH_DISABLE_JUSTKNOBS": "1",},
208208
)
209+
210+
runtime.genrule(
211+
name = "exported_program_data",
212+
cmd = "$(exe :export_delegated_program)" +
213+
" --modules ModuleLinear" +
214+
" --backend_id XnnpackBackend" +
215+
" --external_constants" +
216+
" --outdir $OUT",
217+
218+
outs = {
219+
"ModuleLinear-e.pte": ["ModuleLinear-e.pte"],
220+
"ModuleLinear.ptd": ["ModuleLinear.ptd"],
221+
},
222+
default_outs = ["."],
223+
visibility = [
224+
"//executorch/runtime/executor/test/...",
225+
"//executorch/test/...",
226+
],
227+
)

0 commit comments

Comments
 (0)