Skip to content

Restore constant segment #5141

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions exir/_serialize/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,24 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program:
location=DataLocation.INLINE, index=data_index
)

# Replace constants from constant_segment into constant_buffer.
if program.constant_segment and len(program.constant_segment.offsets) > 0:
buffers: List[Buffer] = []
constant_segment = segments[program.constant_segment.segment_index]
for i in range(len(program.constant_segment.offsets)):
start_offset = program.constant_segment.offsets[i]
# Note: this is the original end offset plus any padding between
# it and the next start offset.
end_offset = (
program.constant_segment.offsets[i + 1]
if i < len(program.constant_segment.offsets) - 1
else len(constant_segment)
)
buffers.append(Buffer(storage=constant_segment[start_offset:end_offset]))
program.constant_buffer = buffers
program.constant_segment.segment_index = 0
program.constant_segment.offsets = []

# Clear out the segments list since the original Program didn't have one.
program.segments = []
return program
Expand Down
19 changes: 18 additions & 1 deletion exir/_serialize/test/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,15 @@ def constant_segment_with_tensor_alignment(
f"{segment_table}",
)

# Convert back.
program2 = deserialize_pte_binary(pte_data)
# Programs are the same besides constant_buffer, as deserialization
# does not preserve constant segment; padding may be added
# during serialization.
self.assertEqual(program2.execution_plan, program.execution_plan)
# Number of constant tensors should be the same.
self.assertEqual(len(program2.constant_buffer), len(program.constant_buffer))

def test_canonicalize_delegate_indices(self) -> None:
def make_execution_plan(
name: str, delegates: List[BackendDelegate]
Expand Down Expand Up @@ -462,7 +471,6 @@ def gen_blob_data(size: int, pattern: bytes) -> bytes:
assert len(ret) == size
return ret

@unittest.skip("TODO(T181362263): Update restore segments to restore cords")
def test_round_trip_with_segments(self) -> None:
# Create a program with some delegate data blobs.
program = get_test_program()
Expand Down Expand Up @@ -803,6 +811,15 @@ def test_constant_segment_and_delegate_segment(self) -> None:
+ b"\x40\x44\x44",
)

# Convert back.
program2 = deserialize_pte_binary(pte_data)
# Programs are the same besides constant_buffer, as deserialization
# does not preserve constant segment; padding may be added
# during serialization.
self.assertEqual(program2.execution_plan, program.execution_plan)
# Number of constant tensors should be the same.
self.assertEqual(len(program2.constant_buffer), len(program.constant_buffer))


# Common data for extended header tests. The two example values should produce
# the example data.
Expand Down
Loading