Skip to content

Commit 36aea65

Browse files
lucylqfacebook-github-bot
authored andcommitted
Restore constant segment (#5141)
Summary: Pull Request resolved: #5141 Restore constant segment in deserialize_pte_binary. Note that programs are not identical afterwards, as we do not store the size of the constant buffer. Instead, the restored program will contain tensor+padding in each buffer. Reviewed By: dbort Differential Revision: D62278416
1 parent cd9d536 commit 36aea65

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

exir/_serialize/_program.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,24 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program:
553553
location=DataLocation.INLINE, index=data_index
554554
)
555555

556+
# Replace constants from constant_segment into constant_buffer.
557+
if program.constant_segment and len(program.constant_segment.offsets) > 0:
558+
buffers: List[Buffer] = []
559+
constant_segment = segments[program.constant_segment.segment_index]
560+
for i in range(len(program.constant_segment.offsets)):
561+
start_offset = program.constant_segment.offsets[i]
562+
# Note: this is the original end off set plus any padding between
563+
# it and the next start offset.
564+
end_offset = (
565+
program.constant_segment.offsets[i + 1]
566+
if i < len(program.constant_segment.offsets) - 1
567+
else len(constant_segment)
568+
)
569+
buffers.append(Buffer(storage=constant_segment[start_offset:end_offset]))
570+
program.constant_buffer = buffers
571+
program.constant_segment.segment_index = 0
572+
program.constant_segment.offsets = []
573+
556574
# Clear out the segments list since the original Program didn't have one.
557575
program.segments = []
558576
return program

exir/_serialize/test/test_program.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,14 @@ def constant_segment_with_tensor_alignment(
272272
f"{segment_table}",
273273
)
274274

275+
# Convert back.
276+
program2 = deserialize_pte_binary(pte_data)
277+
# Programs are the same besides constant_buffer, as deserialization
278+
# does not preserve constant segment.
279+
self.assertEqual(program2.execution_plan, program.execution_plan)
280+
# Number of constant tensors should be the same.
281+
self.assertEqual(len(program2.constant_buffer), len(program.constant_buffer))
282+
275283
def test_canonicalize_delegate_indices(self) -> None:
276284
def make_execution_plan(
277285
name: str, delegates: List[BackendDelegate]
@@ -462,7 +470,6 @@ def gen_blob_data(size: int, pattern: bytes) -> bytes:
462470
assert len(ret) == size
463471
return ret
464472

465-
@unittest.skip("TODO(T181362263): Update restore segments to restore cords")
466473
def test_round_trip_with_segments(self) -> None:
467474
# Create a program with some delegate data blobs.
468475
program = get_test_program()
@@ -803,6 +810,11 @@ def test_constant_segment_and_delegate_segment(self) -> None:
803810
+ b"\x40\x44\x44",
804811
)
805812

813+
# Convert back.
814+
program2 = deserialize_pte_binary(pte_data)
815+
self.assertEqual(program2.execution_plan, program.execution_plan)
816+
self.assertEqual(len(program2.constant_buffer), len(program.constant_buffer))
817+
806818

807819
# Common data for extended header tests. The two example values should produce
808820
# the example data.

0 commit comments

Comments
 (0)