Skip to content

Commit 7961cce

Browse files
committed
Add optional Arrow deserialization support
1 parent b434f18 commit 7961cce

File tree

7 files changed

+65
-0
lines changed

7 files changed

+65
-0
lines changed

elastic_transport/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,13 @@
108108
except ImportError:
109109
pass
110110

111+
try:
112+
from elastic_transport._serializer import PyArrowSerializer # noqa: F401
113+
114+
__all__.append("PyArrowSerializer")
115+
except ImportError:
116+
pass
117+
111118
_logger = logging.getLogger("elastic_transport")
112119
_logger.addHandler(logging.NullHandler())
113120
del _logger

elastic_transport/_serializer.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@
2929
except ModuleNotFoundError:
3030
orjson = None # type: ignore[assignment]
3131

32+
try:
33+
import pyarrow as pa
34+
except ModuleNotFoundError:
35+
pa = None # type: ignore[assignment]
36+
3237

3338
class Serializer:
3439
"""Serializer interface."""
@@ -192,12 +197,38 @@ def dumps(self, data: Any) -> bytes:
192197
return bytes(buffer)
193198

194199

200+
if pa is not None:

    class PyArrowSerializer(Serializer):
        """PyArrow serializer for deserializing Arrow Stream data.

        Only defined when the optional ``pyarrow`` import succeeded. It is
        read-only: :meth:`loads` decodes an Arrow stream, while :meth:`dumps`
        always raises.
        """

        mimetype: ClassVar[str] = "application/vnd.apache.arrow.stream"

        def loads(self, data: bytes) -> pa.Table:
            """Decode an Arrow IPC stream into a :class:`pyarrow.Table`.

            :param data: Raw bytes of an Arrow stream.
            :returns: Everything the stream reader yields, via ``read_all()``.
            :raises SerializationError: If pyarrow raises ``ArrowException``
                while opening or reading the stream. The original error is
                available both in ``errors`` and as ``__cause__``.
            """
            try:
                with pa.ipc.open_stream(data) as reader:
                    return reader.read_all()
            except pa.ArrowException as e:
                # Chain explicitly so tracebacks show the Arrow failure as the
                # direct cause rather than an error raised during handling.
                raise SerializationError(
                    message=f"Unable to deserialize as Arrow stream: {data!r}",
                    errors=(e,),
                ) from e

        def dumps(self, data: Any) -> bytes:
            """Always raise: serializing to Arrow is not supported.

            :param data: Ignored.
            :raises SerializationError: Unconditionally.
            """
            raise SerializationError(
                message="Elasticsearch does not accept Arrow input data"
            )
222+
195223
# Serializer instances registered by default, keyed by their mimetype.
DEFAULT_SERIALIZERS = {
    JsonSerializer.mimetype: JsonSerializer(),
    TextSerializer.mimetype: TextSerializer(),
    NdjsonSerializer.mimetype: NdjsonSerializer(),
}

# The Arrow serializer is only registered when pyarrow could be imported.
if pa is not None:
    DEFAULT_SERIALIZERS.update({PyArrowSerializer.mimetype: PyArrowSerializer()})
231+
201232

202233
class SerializerCollection:
203234
"""Collection of serializers that can be fetched by mimetype. Used by

noxfile.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def lint(session):
4646
"mypy==1.7.1",
4747
"types-requests",
4848
"types-certifi",
49+
"pyarrow-stubs",
4950
)
5051
# https://github.com/python/typeshed/issues/10786
5152
session.run(

requirements-min.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ requests==2.26.0
22
urllib3==1.26.2
33
aiohttp==3.8.0
44
httpx==0.27.0
5+
pyarrow==1.0.0

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
"opentelemetry-api",
7171
"opentelemetry-sdk",
7272
"orjson",
73+
"pyarrow",
7374
# Override Read the Docs default (sphinx<2)
7475
"sphinx>2",
7576
"furo",

tests/test_package.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ def test__all__sorted(module):
2929
# Optional dependencies are added at the end
3030
if "OrjsonSerializer" in module_all:
3131
module_all.remove("OrjsonSerializer")
32+
if "PyArrowSerializer" in module_all:
33+
module_all.remove("PyArrowSerializer")
34+
3235
assert module_all == sorted(module_all)
3336

3437

tests/test_serializer.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,14 @@
1919
from datetime import date
2020
from decimal import Decimal
2121

22+
import pyarrow as pa
2223
import pytest
2324

2425
from elastic_transport import (
2526
JsonSerializer,
2627
NdjsonSerializer,
2728
OrjsonSerializer,
29+
PyArrowSerializer,
2830
SerializationError,
2931
SerializerCollection,
3032
TextSerializer,
@@ -191,3 +193,22 @@ def test_ndjson_dumps():
191193
b'{"key:"value"}\n'
192194
b'{"bytes":"too"}\n'
193195
)
196+
197+
198+
def test_pyarrow_loads():
    """Write one record batch to an Arrow IPC stream and decode it back."""
    expected = {
        "f0": [1, 2, 3, 4],
        "f1": ["foo", "bar", "baz", None],
        "f2": [True, None, False, True],
    }

    # Build the batch column by column from the expected mapping.
    columns = [pa.array(values) for values in expected.values()]
    record_batch = pa.record_batch(columns, names=list(expected))

    # Serialize the batch into an in-memory Arrow stream.
    buffer_stream = pa.BufferOutputStream()
    with pa.ipc.new_stream(buffer_stream, record_batch.schema) as stream_writer:
        stream_writer.write_batch(record_batch)

    table = PyArrowSerializer().loads(buffer_stream.getvalue())
    assert table.to_pydict() == expected

0 commit comments

Comments
 (0)