Skip to content

Introduce write_to_file api #2307

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/vulkan/test/test_vulkan_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def lower_module_and_test_output(
VulkanBackend.__name__,
)

# TODO(T181494963): update pybinding when we remove buffer cache.
executorch_module = _load_for_executorch_from_buffer(executorch_program.buffer)
# pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
inputs_flattened, _ = tree_flatten(sample_inputs)
Expand Down
5 changes: 1 addition & 4 deletions examples/apple/mps/scripts/mps_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,10 @@
if not args.use_fp16:
extension = "fp32"
model_name = f"{model_name}_{extension}"
program_buffer = bundled_program_buffer
else:
program_buffer = executorch_program.buffer

if args.generate_etrecord:
etrecord_path = "etrecord.bin"
logging.info("generating etrecord.bin")
generate_etrecord(etrecord_path, edge_program_manager_copy, executorch_program)

save_pte_program(program_buffer, model_name)
save_pte_program(executorch_program, model_name)
2 changes: 1 addition & 1 deletion examples/models/llama2/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,4 +343,4 @@ def save_to_pte(self, output_name: str) -> None:
output_name (Optional[str]): The name of the .pte file.
"""
assert output_name, "Need a valid output name"
save_pte_program(self.export_program.buffer, output_name, self.output_dir)
save_pte_program(self.export_program, output_name, self.output_dir)
2 changes: 1 addition & 1 deletion examples/portable/custom_ops/custom_ops_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def main():
(input,),
edge_compile_config=EdgeCompileConfig(_check_ir_validity=False),
)
save_pte_program(prog.buffer, model_name)
save_pte_program(prog, model_name)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion examples/portable/custom_ops/custom_ops_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def main():
(input,),
edge_compile_config=EdgeCompileConfig(_check_ir_validity=False),
)
save_pte_program(prog.buffer, model_name)
save_pte_program(prog, model_name)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion examples/portable/scripts/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def main() -> None:
dynamic_shapes=dynamic_shapes,
backend_config=backend_config,
)
save_pte_program(prog.buffer, args.model_name, args.output_dir)
save_pte_program(prog, args.model_name, args.output_dir)


if __name__ == "__main__":
Expand Down
6 changes: 4 additions & 2 deletions examples/portable/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,13 @@ def export_to_exec_prog(
return exec_prog


def save_pte_program(buffer: bytes, model_name: str, output_dir: str = "") -> None:
def save_pte_program(
prog: ExecutorchProgramManager, model_name: str, output_dir: str = ""
) -> None:
filename = os.path.join(output_dir, f"{model_name}.pte")
try:
with open(filename, "wb") as file:
file.write(buffer)
prog.write_to_file(file)
logging.info(f"Saved exported program to {filename}")
except Exception as e:
logging.error(f"Error while saving to {filename}: {e}")
2 changes: 1 addition & 1 deletion examples/qualcomm/scripts/export_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,4 @@
executorch_program = delegated_program.to_executorch(
config=ExecutorchBackendConfig(extract_constant_segment=False)
)
save_pte_program(executorch_program.buffer, args.model_name)
save_pte_program(executorch_program, args.model_name)
2 changes: 1 addition & 1 deletion examples/xnnpack/aot_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,4 @@

quant_tag = "q8" if args.quantize else "fp32"
model_name = f"{args.model_name}_xnnpack_{quant_tag}"
save_pte_program(exec_prog.buffer, model_name, args.output_dir)
save_pte_program(exec_prog, model_name, args.output_dir)
2 changes: 1 addition & 1 deletion examples/xnnpack/quantization/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def main() -> None:
prog = edge_m.to_executorch(
config=ExecutorchBackendConfig(extract_constant_segment=False)
)
save_pte_program(prog.buffer, f"{args.model_name}_quantized")
save_pte_program(prog, f"{args.model_name}_quantized")
end = time.perf_counter()
logging.info(f"Save time: {end - start}s")
logging.info("finished")
Expand Down
2 changes: 1 addition & 1 deletion examples/xtensa/aot/export_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,4 @@ def forward(self, x: torch.Tensor):
logging.info(f"Final exported graph:\n{exec_prog.exported_program().graph}")

# Save the program as XtensaDemoModel.pte
save_pte_program(exec_prog.buffer, "XtensaDemoModel")
save_pte_program(exec_prog, "XtensaDemoModel")
8 changes: 3 additions & 5 deletions exir/_serialize/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def serialize_pte_binary(
segment_alignment: int = 4096,
constant_tensor_alignment: Optional[int] = None,
delegate_alignment: Optional[int] = None,
) -> bytes:
) -> Cord:
"""Returns the runtime binary representation of the given Program.

Args:
Expand Down Expand Up @@ -429,7 +429,7 @@ def serialize_pte_binary(

# If there are no segments present, do not insert the extended header.
if len(segments_data) == 0:
return result.data
return Cord(result.data)

# Size of the header to insert. Its size is padded to the largest
# force_align value present in the schema.
Expand Down Expand Up @@ -482,9 +482,7 @@ def serialize_pte_binary(
len(pte_data) == segment_base_offset
), f"Offset of first segment {len(pte_data)} != segment base offset {segment_base_offset}"
pte_data.append(segments_data)

# TODO(lfq): this creates a copy of all the data; once we update existing callsites this will change.
return bytes(pte_data)
return pte_data


def _restore_segments(program: Program, segment_data: bytes) -> Program:
Expand Down
52 changes: 33 additions & 19 deletions exir/_serialize/test/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,13 @@ def constant_segment_with_tensor_alignment(
add_constant_data(program, blobs)

# Extract blobs into constant segment during serialization.
pte_data = serialize_pte_binary(
program,
extract_constant_segment=True,
segment_alignment=SEGMENT_ALIGNMENT,
constant_tensor_alignment=constant_tensor_alignment,
pte_data = bytes(
serialize_pte_binary(
program,
extract_constant_segment=True,
segment_alignment=SEGMENT_ALIGNMENT,
constant_tensor_alignment=constant_tensor_alignment,
)
)

# The input Program should not be modified.
Expand Down Expand Up @@ -395,7 +397,7 @@ def test_round_trip_no_header_no_segments(self) -> None:
deserializing.
"""
program = get_test_program()
pte_data = serialize_pte_binary(program)
pte_data = bytes(serialize_pte_binary(program))
self.assertGreater(len(pte_data), 16)

# File magic should be present at the expected offset.
Expand All @@ -418,7 +420,7 @@ def test_round_trip_large_buffer_sizes(self) -> None:
"""
program = get_test_program()
program.execution_plan[0].non_const_buffer_sizes = [0, 2**48]
flatbuffer_from_py = serialize_pte_binary(program)
flatbuffer_from_py = bytes(serialize_pte_binary(program))
self.assert_programs_equal(program, deserialize_pte_binary(flatbuffer_from_py))

def test_round_trip_no_segments_and_no_header(self) -> None:
Expand All @@ -428,8 +430,10 @@ def test_round_trip_no_segments_and_no_header(self) -> None:
that a Program remains the same after serializing and deserializing.
"""
program = get_test_program()
pte_data = serialize_pte_binary(
program, extract_delegate_segments=True, extract_constant_segment=True
pte_data = bytes(
serialize_pte_binary(
program, extract_delegate_segments=True, extract_constant_segment=True
)
)
self.assertGreater(len(pte_data), 16)

Expand Down Expand Up @@ -477,8 +481,12 @@ def test_round_trip_with_segments(self) -> None:
add_delegate_data(program, program.execution_plan[0], blobs)

# Extract the blobs into segments during serialization.
pte_data = serialize_pte_binary(
program, extract_delegate_segments=True, segment_alignment=SEGMENT_ALIGNMENT
pte_data = bytes(
serialize_pte_binary(
program,
extract_delegate_segments=True,
segment_alignment=SEGMENT_ALIGNMENT,
)
)

# The input Program should not have been modified.
Expand Down Expand Up @@ -588,8 +596,12 @@ def test_unused_inline_delegate_blobs_with_segments(self) -> None:
add_delegate_data(program, program.execution_plan[0], blobs)

# Extract the blobs into segments should succeeed.
pte_data = serialize_pte_binary(
program, extract_delegate_segments=True, segment_alignment=SEGMENT_ALIGNMENT
pte_data = bytes(
serialize_pte_binary(
program,
extract_delegate_segments=True,
segment_alignment=SEGMENT_ALIGNMENT,
)
)
self.assertGreater(len(pte_data), 16)

Expand Down Expand Up @@ -644,12 +656,14 @@ def test_constant_segment_and_delegate_segment(self) -> None:
add_delegate_data(program, program.execution_plan[0], delegate_blobs)

# Extract the blobs into segments during serialization.
pte_data = serialize_pte_binary(
program,
extract_delegate_segments=True,
extract_constant_segment=True,
segment_alignment=SEGMENT_ALIGNMENT,
constant_tensor_alignment=CONSTANT_TENSOR_ALIGNMENT,
pte_data = bytes(
serialize_pte_binary(
program,
extract_delegate_segments=True,
extract_constant_segment=True,
segment_alignment=SEGMENT_ALIGNMENT,
constant_tensor_alignment=CONSTANT_TENSOR_ALIGNMENT,
)
)

# The input Program should not be modified.
Expand Down
15 changes: 9 additions & 6 deletions exir/lowered_backend_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def original_module(self) -> ExportedProgram:
return self._original_exported_program

# TODO(chenlai): consolidate the seriailization config with serialize_to_flatbuffer api
# TODO(T181463742): avoid calling bytes(..) which incurs large copies.
def buffer(
self,
extract_delegate_segments: bool = False,
Expand All @@ -141,12 +142,14 @@ def buffer(
"""
Returns a buffer containing the serialized ExecuTorch binary.
"""
out = _serialize_pte_binary(
program=self.program(),
extract_delegate_segments=extract_delegate_segments,
segment_alignment=segment_alignment,
constant_tensor_alignment=constant_tensor_alignment,
delegate_alignment=delegate_alignment,
out = bytes(
_serialize_pte_binary(
program=self.program(),
extract_delegate_segments=extract_delegate_segments,
segment_alignment=segment_alignment,
constant_tensor_alignment=constant_tensor_alignment,
delegate_alignment=delegate_alignment,
)
)
return out

Expand Down
Loading