@@ -272,6 +272,15 @@ def constant_segment_with_tensor_alignment(
272
272
f"{ segment_table } " ,
273
273
)
274
274
275
+ # Convert back.
276
+ program2 = deserialize_pte_binary (pte_data )
277
+ # Programs are the same besides constant_buffer, as deserialization
278
+ # does not preserve constant segment; padding may be added
279
+ # during serialization.
280
+ self .assertEqual (program2 .execution_plan , program .execution_plan )
281
+ # Number of constant tensors should be the same.
282
+ self .assertEqual (len (program2 .constant_buffer ), len (program .constant_buffer ))
283
+
275
284
def test_canonicalize_delegate_indices (self ) -> None :
276
285
def make_execution_plan (
277
286
name : str , delegates : List [BackendDelegate ]
@@ -462,7 +471,6 @@ def gen_blob_data(size: int, pattern: bytes) -> bytes:
462
471
assert len (ret ) == size
463
472
return ret
464
473
465
- @unittest .skip ("TODO(T181362263): Update restore segments to restore cords" )
466
474
def test_round_trip_with_segments (self ) -> None :
467
475
# Create a program with some delegate data blobs.
468
476
program = get_test_program ()
@@ -803,6 +811,15 @@ def test_constant_segment_and_delegate_segment(self) -> None:
803
811
+ b"\x40 \x44 \x44 " ,
804
812
)
805
813
814
+ # Convert back.
815
+ program2 = deserialize_pte_binary (pte_data )
816
+ # Programs are the same besides constant_buffer, as deserialization
817
+ # does not preserve constant segment; padding may be added
818
+ # during serialization.
819
+ self .assertEqual (program2 .execution_plan , program .execution_plan )
820
+ # Number of constant tensors should be the same.
821
+ self .assertEqual (len (program2 .constant_buffer ), len (program .constant_buffer ))
822
+
806
823
807
824
# Common data for extended header tests. The two example values should produce
808
825
# the example data.
0 commit comments