|
13 | 13 |
|
14 | 14 | import torch
|
15 | 15 | import torch._export
|
16 |
| -from executorch.exir._serialize import _serialize_pte_binary |
17 | 16 | from executorch.exir._serialize._cord import Cord
|
| 17 | +from executorch.exir._serialize._serialize import serialize |
| 18 | +from executorch.exir._serialize.data_serializer import DataSerializer |
18 | 19 | from executorch.exir._warnings import experimental
|
19 | 20 | from executorch.exir.backend.backend_api import to_backend
|
20 | 21 | from executorch.exir.backend.partitioner import Partitioner
|
|
56 | 57 | EXIREdgeDialectVerifier,
|
57 | 58 | get_aten_verifier,
|
58 | 59 | )
|
| 60 | +from executorch.extension.flat_tensor.serialize.serialize import FlatTensorSerializer |
59 | 61 | from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass
|
60 | 62 | from torch.export import ExportedProgram
|
61 | 63 | from torch.export._remove_auto_functionalized_pass import (
|
@@ -494,23 +496,23 @@ def __init__(
|
494 | 496 | )
|
495 | 497 | self.exported_program = exir_exported_program.exported_program
|
496 | 498 | self._pte_data: Optional[Cord] = None
|
| 499 | + self._data_files: Optional[Dict[str, Cord]] = None |
497 | 500 | self._buffer: Optional[bytes] = None
|
498 | 501 | self._emitter_output: Optional[EmitterOutput] = None
|
499 | 502 | self._emit_stacktrace: bool = emit_stacktrace
|
500 | 503 | self._extract_delegate_segments: bool = extract_delegate_segments
|
501 | 504 | self._segment_alignment: int = segment_alignment
|
502 | 505 | self._constant_tensor_alignment: Optional[int] = constant_tensor_alignment
|
503 | 506 | self._delegate_alignment: Optional[int] = delegate_alignment
|
| 507 | + self._data_serializer: DataSerializer = FlatTensorSerializer() |
504 | 508 |
|
505 | 509 | def _get_pte_data(self) -> Cord:
|
506 | 510 | if self._pte_data is None:
|
507 |
| - self._pte_data = _serialize_pte_binary( |
508 |
| - program=self.program, |
509 |
| - extract_delegate_segments=self._extract_delegate_segments, |
510 |
| - segment_alignment=self._segment_alignment, |
511 |
| - constant_tensor_alignment=self._constant_tensor_alignment, |
512 |
| - delegate_alignment=self._delegate_alignment, |
| 511 | + assert self._emitter_output is not None |
| 512 | + self._pte_data, self._data_files = serialize( |
| 513 | + self._emitter_output, ExecutorchBackendConfig(), self._data_serializer |
513 | 514 | )
|
| 515 | + assert self._pte_data is not None |
514 | 516 | return self._pte_data
|
515 | 517 |
|
516 | 518 | @property
|
@@ -1443,14 +1445,11 @@ def __init__(
|
1443 | 1445 | self._config_methods,
|
1444 | 1446 | )
|
1445 | 1447 |
|
| 1448 | + self._data_serializer = FlatTensorSerializer() |
| 1449 | + |
1446 | 1450 | # Serialize emitter output, ready to be written to a file.
|
1447 |
| - self._pte_data: Cord = _serialize_pte_binary( |
1448 |
| - program=self._emitter_output.program, |
1449 |
| - mutable_data=self._emitter_output.mutable_data, |
1450 |
| - extract_delegate_segments=backend_config.extract_delegate_segments, |
1451 |
| - segment_alignment=backend_config.segment_alignment, |
1452 |
| - constant_tensor_alignment=backend_config.constant_tensor_alignment, |
1453 |
| - delegate_alignment=backend_config.delegate_alignment, |
| 1451 | + self._pte_data, self._data_files = serialize( |
| 1452 | + self._emitter_output, ExecutorchBackendConfig(), self._data_serializer |
1454 | 1453 | )
|
1455 | 1454 | self._buffer: Optional[bytes] = None
|
1456 | 1455 |
|
@@ -1532,3 +1531,8 @@ def write_to_file(self, open_file: io.BufferedIOBase) -> None:
|
1532 | 1531 | reducing the peak memory usage.
|
1533 | 1532 | """
|
1534 | 1533 | self._pte_data.write_to_file(open_file)
|
| 1534 | + |
| 1535 | + for filename, cord in self._data_files.items(): |
| 1536 | + filename = filename + ".ptd" |
| 1537 | + with open(filename, "wb") as file: |
| 1538 | + cord.write_to_file(file) |
0 commit comments