Skip to content

Commit 4583941

Browse files
lucylqfacebook-github-bot
authored andcommitted
Add cord data structure (#2273)
Summary: Introduce cord data structure to store bytes/bytearrays during serialization. This allows us to manipulate bytes/bytearrays without copying data. Reviewed By: dbort Differential Revision: D54514244
1 parent 0de3a97 commit 4583941

File tree

4 files changed

+111
-0
lines changed

4 files changed

+111
-0
lines changed

exir/_serialize/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ runtime.python_library(
2929
name = "lib",
3030
srcs = [
3131
"__init__.py",
32+
"_cord.py",
3233
"_dataclass.py",
3334
"_flatbuffer.py",
3435
"_program.py",

exir/_serialize/_cord.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import io
8+
from typing import List, Union
9+
10+
11+
class Cord:
12+
"""A `bytes`-like sequence of bytes, stored non-contiguously.
13+
14+
Users can use a Cord to assemble large files and data blobs using references
15+
to and slices of other data, instead of copying and appending that data to a
16+
`bytes` or `bytearray` object.
17+
"""
18+
19+
def __init__(
20+
self,
21+
) -> None:
22+
"""Initialize Cord data structure."""
23+
self._buffers: List[bytes] = []
24+
self._byte_size: int = 0
25+
26+
def __len__(self):
27+
"""Number of bytes in the Cord."""
28+
return self._byte_size
29+
30+
def __bytes__(self) -> bytes:
31+
"""Return the contents of the Cord as a single `bytes` object."""
32+
return b"".join(self._buffers)
33+
34+
def append(self, data: Union[bytes, "Cord"]) -> None:
35+
"""Append a byte or Cord to the current Cord."""
36+
if isinstance(data, bytes):
37+
self._buffers.append(data)
38+
self._byte_size += len(data)
39+
elif isinstance(data, Cord):
40+
self._buffers.extend(data._buffers)
41+
self._byte_size += len(data)
42+
else:
43+
raise TypeError(f"Can only append bytes or Cords, received {type(data)}")
44+
45+
def write_to_file(self, outfile: io.BufferedIOBase) -> None:
46+
"""Write the Cord to a file."""
47+
for item in self._buffers:
48+
outfile.write(item)

exir/_serialize/test/TARGETS

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,13 @@ python_unittest(
2323
"//executorch/exir/_serialize:lib",
2424
],
2525
)
26+
27+
python_unittest(
28+
name = "cord",
29+
srcs = [
30+
"test_cord.py",
31+
],
32+
deps = [
33+
"//executorch/exir/_serialize:lib",
34+
],
35+
)

exir/_serialize/test/test_cord.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
import io
9+
import unittest
10+
11+
from executorch.exir._serialize._cord import Cord
12+
13+
14+
class TestCord(unittest.TestCase):
15+
def test_cord_init(self) -> None:
16+
cord = Cord()
17+
self.assertEqual(0, len(cord))
18+
19+
def test_cord_append(self) -> None:
20+
cord = Cord()
21+
cord.append(b"Hello")
22+
self.assertEqual(5, len(cord))
23+
self.assertEqual(b"Hello", bytes(cord))
24+
25+
cord.append(b"World")
26+
self.assertEqual(10, len(cord))
27+
self.assertEqual(b"HelloWorld", bytes(cord))
28+
29+
def test_cord_append_cord(self) -> None:
30+
cord = Cord()
31+
cord.append(b"Hello")
32+
cord.append((b"World"))
33+
34+
cord2 = Cord()
35+
cord2.append(b"Prefix")
36+
cord2.append(cord)
37+
38+
self.assertEqual(16, len(cord2))
39+
self.assertEqual(b"PrefixHelloWorld", bytes(cord2))
40+
41+
# Confirm that no copies were made when appending a Cord.
42+
self.assertEqual(id(cord2._buffers[1]), id(cord._buffers[0]))
43+
self.assertEqual(id(cord2._buffers[2]), id(cord._buffers[1]))
44+
45+
def test_cord_write_to_file(self) -> None:
46+
cord = Cord()
47+
cord.append(b"Hello")
48+
cord.append(b"World")
49+
50+
outfile = io.BytesIO()
51+
cord.write_to_file(outfile)
52+
self.assertEqual(b"HelloWorld", outfile.getvalue())

0 commit comments

Comments
 (0)