Skip to content

Commit 549f14b

Browse files
authored
Restore constant segment
Differential Revision: D62278416 Pull Request resolved: #5141
1 parent 657789e commit 549f14b

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

exir/_serialize/_program.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,24 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program:
553553
location=DataLocation.INLINE, index=data_index
554554
)
555555

556+
# Replace constants from constant_segment into constant_buffer.
557+
if program.constant_segment and len(program.constant_segment.offsets) > 0:
558+
buffers: List[Buffer] = []
559+
constant_segment = segments[program.constant_segment.segment_index]
560+
for i in range(len(program.constant_segment.offsets)):
561+
start_offset = program.constant_segment.offsets[i]
562+
# Note: this is the original end offset plus any padding between
563+
# it and the next start offset.
564+
end_offset = (
565+
program.constant_segment.offsets[i + 1]
566+
if i < len(program.constant_segment.offsets) - 1
567+
else len(constant_segment)
568+
)
569+
buffers.append(Buffer(storage=constant_segment[start_offset:end_offset]))
570+
program.constant_buffer = buffers
571+
program.constant_segment.segment_index = 0
572+
program.constant_segment.offsets = []
573+
556574
# Clear out the segments list since the original Program didn't have one.
557575
program.segments = []
558576
return program

exir/_serialize/test/test_program.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,15 @@ def constant_segment_with_tensor_alignment(
272272
f"{segment_table}",
273273
)
274274

275+
# Convert back.
276+
program2 = deserialize_pte_binary(pte_data)
277+
# Programs are the same besides constant_buffer, as deserialization
278+
# does not preserve constant segment; padding may be added
279+
# during serialization.
280+
self.assertEqual(program2.execution_plan, program.execution_plan)
281+
# Number of constant tensors should be the same.
282+
self.assertEqual(len(program2.constant_buffer), len(program.constant_buffer))
283+
275284
def test_canonicalize_delegate_indices(self) -> None:
276285
def make_execution_plan(
277286
name: str, delegates: List[BackendDelegate]
@@ -462,7 +471,6 @@ def gen_blob_data(size: int, pattern: bytes) -> bytes:
462471
assert len(ret) == size
463472
return ret
464473

465-
@unittest.skip("TODO(T181362263): Update restore segments to restore cords")
466474
def test_round_trip_with_segments(self) -> None:
467475
# Create a program with some delegate data blobs.
468476
program = get_test_program()
@@ -803,6 +811,15 @@ def test_constant_segment_and_delegate_segment(self) -> None:
803811
+ b"\x40\x44\x44",
804812
)
805813

814+
# Convert back.
815+
program2 = deserialize_pte_binary(pte_data)
816+
# Programs are the same besides constant_buffer, as deserialization
817+
# does not preserve constant segment; padding may be added
818+
# during serialization.
819+
self.assertEqual(program2.execution_plan, program.execution_plan)
820+
# Number of constant tensors should be the same.
821+
self.assertEqual(len(program2.constant_buffer), len(program.constant_buffer))
822+
806823

807824
# Common data for extended header tests. The two example values should produce
808825
# the example data.

0 commit comments

Comments
 (0)