Skip to content

Commit 5316e69

Browse files
lucylq authored and facebook-github-bot committed
Introduce write_to_file api (#2307)
Summary: Pull Request resolved: #2307 Update callsites that save to file to use write_to_file api instead of .buffer bypass-github-export-checks Reviewed By: dbort Differential Revision: D54526788 fbshipit-source-id: 6b4975f3fd7fd6c74b97a486a2f58aa62a7b2a71
1 parent e44d5b2 commit 5316e69

File tree

11 files changed

+44
-17
lines changed

11 files changed

+44
-17
lines changed

examples/apple/mps/scripts/mps_example.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,13 +143,10 @@
143143
if not args.use_fp16:
144144
extension = "fp32"
145145
model_name = f"{model_name}_{extension}"
146-
program_buffer = bundled_program_buffer
147-
else:
148-
program_buffer = executorch_program.buffer
149146

150147
if args.generate_etrecord:
151148
etrecord_path = "etrecord.bin"
152149
logging.info("generating etrecord.bin")
153150
generate_etrecord(etrecord_path, edge_program_manager_copy, executorch_program)
154151

155-
save_pte_program(program_buffer, model_name)
152+
save_pte_program(executorch_program, model_name)

examples/models/llama2/builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,4 +343,4 @@ def save_to_pte(self, output_name: str) -> None:
343343
output_name (Optional[str]): The name of the .pte file.
344344
"""
345345
assert output_name, "Need a valid output name"
346-
save_pte_program(self.export_program.buffer, output_name, self.output_dir)
346+
save_pte_program(self.export_program, output_name, self.output_dir)

examples/portable/custom_ops/custom_ops_1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def main():
5151
(input,),
5252
edge_compile_config=EdgeCompileConfig(_check_ir_validity=False),
5353
)
54-
save_pte_program(prog.buffer, model_name)
54+
save_pte_program(prog, model_name)
5555

5656

5757
if __name__ == "__main__":

examples/portable/custom_ops/custom_ops_2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def main():
3232
(input,),
3333
edge_compile_config=EdgeCompileConfig(_check_ir_validity=False),
3434
)
35-
save_pte_program(prog.buffer, model_name)
35+
save_pte_program(prog, model_name)
3636

3737

3838
if __name__ == "__main__":

examples/portable/scripts/export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def main() -> None:
7171
dynamic_shapes=dynamic_shapes,
7272
backend_config=backend_config,
7373
)
74-
save_pte_program(prog.buffer, args.model_name, args.output_dir)
74+
save_pte_program(prog, args.model_name, args.output_dir)
7575

7676

7777
if __name__ == "__main__":

examples/portable/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,13 @@ def export_to_exec_prog(
9898
return exec_prog
9999

100100

101-
def save_pte_program(buffer: bytes, model_name: str, output_dir: str = "") -> None:
101+
def save_pte_program(
102+
prog: ExecutorchProgramManager, model_name: str, output_dir: str = ""
103+
) -> None:
102104
filename = os.path.join(output_dir, f"{model_name}.pte")
103105
try:
104106
with open(filename, "wb") as file:
105-
file.write(buffer)
107+
prog.write_to_file(file)
106108
logging.info(f"Saved exported program to {filename}")
107109
except Exception as e:
108110
logging.error(f"Error while saving to {filename}: {e}")

examples/qualcomm/scripts/export_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,4 +74,4 @@
7474
executorch_program = delegated_program.to_executorch(
7575
config=ExecutorchBackendConfig(extract_constant_segment=False)
7676
)
77-
save_pte_program(executorch_program.buffer, args.model_name)
77+
save_pte_program(executorch_program, args.model_name)

examples/xnnpack/aot_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,4 +111,4 @@
111111

112112
quant_tag = "q8" if args.quantize else "fp32"
113113
model_name = f"{args.model_name}_xnnpack_{quant_tag}"
114-
save_pte_program(exec_prog.buffer, model_name, args.output_dir)
114+
save_pte_program(exec_prog, model_name, args.output_dir)

examples/xnnpack/quantization/example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def main() -> None:
193193
prog = edge_m.to_executorch(
194194
config=ExecutorchBackendConfig(extract_constant_segment=False)
195195
)
196-
save_pte_program(prog.buffer, f"{args.model_name}_quantized")
196+
save_pte_program(prog, f"{args.model_name}_quantized")
197197
end = time.perf_counter()
198198
logging.info(f"Save time: {end - start}s")
199199
logging.info("finished")

examples/xtensa/aot/export_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,4 +90,4 @@ def forward(self, x: torch.Tensor):
9090
logging.info(f"Final exported graph:\n{exec_prog.exported_program().graph}")
9191

9292
# Save the program as XtensaDemoModel.pte
93-
save_pte_program(exec_prog.buffer, "XtensaDemoModel")
93+
save_pte_program(exec_prog, "XtensaDemoModel")

exir/program/_program.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import copy
8+
import io
89
import logging
910
from typing import Any, Dict, List, Optional, Sequence, Set, Union
1011

@@ -437,7 +438,8 @@ def buffer(self) -> bytes:
437438
"""Returns the serialized ExecuTorch binary as a byte string.
438439
439440
Note that the call to `buffer` may allocate a very large amount of
440-
contiguous memory, depending on the model size.
441+
contiguous memory, depending on the model size. If writing to a file,
442+
use `write_to_file` which won't incur additional copies.
441443
"""
442444
# TODO(T181494963): update pybinding to remove buffer cache, which can consume large
443445
# amounts of memory longer than necessary.
@@ -478,6 +480,14 @@ def dump_graph_module(self) -> torch.fx.GraphModule:
478480
def dump_exported_program(self) -> ExportedProgram:
479481
return self.exported_program
480482

483+
def write_to_file(self, open_file: io.BufferedIOBase) -> None:
484+
"""
485+
Writes the serialized ExecuTorch binary to the file at `open_file`. Prefer to use this over
486+
`buffer`, as it writes to file without copying into a contiguous block of memory first,
487+
reducing the peak memory usage.
488+
"""
489+
self._get_pte_data().write_to_file(open_file)
490+
481491

482492
def _get_aten_to_edge_passes(config: EdgeCompileConfig):
483493
# TODO: the last two passes for aten_to_edge need to be eliminated_dead_code -> debug_handle_generator. After enable
@@ -769,7 +779,8 @@ def buffer(self) -> bytes:
769779
"""Returns the serialized ExecuTorch binary as a byte string.
770780
771781
Note that the call to `buffer` may allocate a very large amount of
772-
contiguous memory, depending on the model size.
782+
contiguous memory, depending on the model size. If writing to a file,
783+
use `write_to_file` which won't incur additional copies.
773784
"""
774785
# TODO(T181494963): update pybinding to remove buffer cache, which can consume large
775786
# amounts of memory longer than necessary.
@@ -800,6 +811,14 @@ def dump_graph_module(self) -> torch.fx.GraphModule:
800811
def get_multi_method_graph_module(self) -> "MultiMethodExirExportedProgram":
801812
return self._executorch_dialect_ir_program
802813

814+
def write_to_file(self, open_file: io.BufferedIOBase) -> None:
815+
"""
816+
Writes the serialized ExecuTorch binary to the file at `open_file`. Prefer to use this over
817+
`buffer`, as it writes to file without copying into a contiguous block of memory first,
818+
reducing the peak memory usage.
819+
"""
820+
self._get_pte_data().write_to_file(open_file)
821+
803822

804823
# TODO(T152006915): Merge this into to_executorch and then delete it.
805824
def multi_method_program_to_executorch(
@@ -1210,10 +1229,19 @@ def buffer(self) -> bytes:
12101229
"""Returns the serialized ExecuTorch binary as a byte string.
12111230
12121231
Note that the call to `buffer` may allocate a very large amount of
1213-
contiguous memory, depending on the model size.
1232+
contiguous memory, depending on the model size. If writing to a file,
1233+
use `write_to_file` which won't incur additional copies.
12141234
"""
12151235
# TODO(T181494963): update pybinding to remove buffer cache, which can consume large
12161236
# amounts of memory longer than necessary.
12171237
if self._buffer is None:
12181238
self._buffer = bytes(self._pte_data)
12191239
return self._buffer
1240+
1241+
def write_to_file(self, open_file: io.BufferedIOBase) -> None:
1242+
"""
1243+
Writes the serialized ExecuTorch binary to the file at `open_file`. Prefer to use this over
1244+
`buffer`, as it writes to file without copying into a contiguous block of memory first,
1245+
reducing the peak memory usage.
1246+
"""
1247+
self._pte_data.write_to_file(open_file)

0 commit comments

Comments
 (0)