Skip to content

Commit dedb1e0

Browse files
Add Cython bindings for arrow::ArrayBuilder classes for Int32, Int64, Double and Timestamp types (#2)
1 parent 0bb2434 commit dedb1e0

File tree

6 files changed

+346
-5
lines changed

6 files changed

+346
-5
lines changed
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# Copyright 2021-present MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Cython compiler directives
16+
# distutils: language=c++
17+
# cython: language_level=3
18+
19+
cdef class _BuilderBase:
20+
def append_values(self, values):
21+
for value in values:
22+
self.append(value)
23+
24+
@property
25+
def null_count(self):
26+
return self.builder.get().null_count()
27+
28+
def __len__(self):
29+
return self.builder.get().length()
30+
31+
32+
cdef class Int32Builder(_BuilderBase):
33+
cdef:
34+
shared_ptr[CInt32Builder] builder
35+
36+
def __cinit__(self, MemoryPool memory_pool=None):
37+
cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
38+
self.builder.reset(new CInt32Builder(pool))
39+
40+
def append(self, value):
41+
if value is None or value is np.nan:
42+
self.builder.get().AppendNull()
43+
elif isinstance(value, int):
44+
self.builder.get().Append(value)
45+
else:
46+
raise TypeError('Int32Builder only accepts integer objects')
47+
48+
def finish(self):
49+
cdef shared_ptr[CArray] out
50+
with nogil:
51+
self.builder.get().Finish(&out)
52+
return pyarrow_wrap_array(out)
53+
54+
cdef shared_ptr[CInt32Builder] unwrap(self):
55+
return self.builder
56+
57+
58+
cdef class Int64Builder(_BuilderBase):
59+
cdef:
60+
shared_ptr[CInt64Builder] builder
61+
62+
def __cinit__(self, MemoryPool memory_pool=None):
63+
cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
64+
self.builder.reset(new CInt64Builder(pool))
65+
66+
def append(self, value):
67+
if value is None or value is np.nan:
68+
self.builder.get().AppendNull()
69+
elif isinstance(value, int):
70+
self.builder.get().Append(value)
71+
else:
72+
raise TypeError('Int64Builder only accepts integer objects')
73+
74+
def finish(self):
75+
cdef shared_ptr[CArray] out
76+
with nogil:
77+
self.builder.get().Finish(&out)
78+
return pyarrow_wrap_array(out)
79+
80+
cdef shared_ptr[CInt64Builder] unwrap(self):
81+
return self.builder
82+
83+
84+
cdef class DoubleBuilder(_BuilderBase):
85+
cdef:
86+
shared_ptr[CDoubleBuilder] builder
87+
88+
def __cinit__(self, MemoryPool memory_pool=None):
89+
cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
90+
self.builder.reset(new CDoubleBuilder(pool))
91+
92+
def append(self, value):
93+
if value is None or value is np.nan:
94+
self.builder.get().AppendNull()
95+
elif isinstance(value, (int, float)):
96+
self.builder.get().Append(value)
97+
else:
98+
raise TypeError('DoubleBuilder only accepts floats and ints')
99+
100+
def finish(self):
101+
cdef shared_ptr[CArray] out
102+
with nogil:
103+
self.builder.get().Finish(&out)
104+
return pyarrow_wrap_array(out)
105+
106+
cdef shared_ptr[CDoubleBuilder] unwrap(self):
107+
return self.builder
108+
109+
110+
cdef class DatetimeBuilder(_BuilderBase):
111+
cdef:
112+
shared_ptr[CTimestampBuilder] builder
113+
TimestampType dtype
114+
115+
def __cinit__(self, TimestampType dtype=timestamp('ms'),
116+
MemoryPool memory_pool=None):
117+
cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
118+
if dtype in (timestamp('us'), timestamp('ns')):
119+
raise ValueError("Microsecond resolution temporal type is not "
120+
"suitable for use with MongoDB's UTC datetime "
121+
"type which has resolution of milliseconds.")
122+
self.dtype = dtype
123+
self.builder.reset(new CTimestampBuilder(
124+
pyarrow_unwrap_data_type(self.dtype), pool))
125+
126+
def append(self, value):
127+
if value is None or value is np.nan:
128+
self.builder.get().AppendNull()
129+
elif isinstance(value, datetime.datetime):
130+
self.builder.get().Append(
131+
datetime_to_int64(value, self.dtype))
132+
else:
133+
raise TypeError('TimestampBuilder only accepts datetime objects')
134+
135+
def finish(self):
136+
cdef shared_ptr[CArray] out
137+
with nogil:
138+
self.builder.get().Finish(&out)
139+
return pyarrow_wrap_array(out)
140+
141+
@property
142+
def unit(self):
143+
return self.dtype
144+
145+
cdef shared_ptr[CTimestampBuilder] unwrap(self):
146+
return self.builder

bindings/python/pymongoarrow/lib.pyx

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright 2021-present MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Cython compiler directives
16+
# distutils: language=c++
17+
# cython: language_level=3
18+
19+
# Stdlib imports
20+
import datetime
21+
22+
# Python imports
23+
import numpy as np
24+
from pyarrow import timestamp
25+
26+
# Cython imports
27+
from pyarrow.lib cimport *
28+
29+
30+
# Utilities
31+
include "utils.pyi"
32+
33+
# Builders
34+
include "builders.pyi"
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright 2021-present MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
def datetime_to_int64(dtm, data_type):
16+
# TODO: rewrite as a cdef which directly accesses data_type as a CTimestampType instance
17+
# TODO: make this function aware of datatype.timezone()
18+
total_seconds = int((dtm - datetime.datetime(1970, 1, 1)).total_seconds())
19+
total_microseconds = int(total_seconds) * 10**6 + dtm.microsecond
20+
21+
if data_type.unit == 's':
22+
factor = 1.
23+
elif data_type.unit == 'ms':
24+
factor = 10. ** 3
25+
elif data_type.unit == 'us':
26+
factor = 10. ** 6
27+
elif data_type.unit == 'ns':
28+
factor = 10. ** 9
29+
else:
30+
raise ValueError('Unsupported timestamp unit {}'.format(
31+
data_type.unit))
32+
33+
int64_t = int(total_microseconds * factor / (10. ** 6))
34+
return int64_t

bindings/python/setup.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
import os
55

6+
import numpy as np
7+
import pyarrow as pa
8+
69

710
def get_pymongoarrow_version():
811
"""Single source the version."""
@@ -15,16 +18,29 @@ def get_pymongoarrow_version():
1518

1619

1720
def get_extension_modules():
18-
modules = cythonize(['pymongoarrow/*.pyx',
19-
'pymongoarrow/libbson/*.pyx'])
20-
for module in modules:
21+
arrow_modules = cythonize(['pymongoarrow/*.pyx'])
22+
libbson_modules = cythonize(['pymongoarrow/libbson/*.pyx'])
23+
24+
for module in libbson_modules:
2125
module.libraries.append('bson-1.0')
22-
return modules
26+
27+
for module in arrow_modules:
28+
module.include_dirs.append(np.get_include())
29+
module.include_dirs.append(pa.get_include())
30+
module.libraries.extend(pa.get_libraries())
31+
module.library_dirs.extend(pa.get_library_dirs())
32+
33+
# https://arrow.apache.org/docs/python/extending.html#example
34+
if os.name == 'posix':
35+
module.extra_compile_args.append('-std=c++11')
36+
37+
return arrow_modules + libbson_modules
2338

2439

2540
setup(
2641
name='pymongoarrow',
2742
version=get_pymongoarrow_version(),
2843
packages=find_packages(),
2944
ext_modules=get_extension_modules(),
30-
setup_requires=['cython >= 0.29'])
45+
install_requires=['pyarrow >= 3', 'pymongo >= 3.11,<4'],
46+
setup_requires=['cython >= 0.29', 'pyarrow >= 3', 'numpy >= 1.16.6'])

bindings/python/test/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2021-present MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.

bindings/python/test/test_builders.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Copyright 2021-present MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from datetime import datetime, timedelta
15+
from unittest import TestCase
16+
17+
from pyarrow import Array, timestamp, int32, int64
18+
19+
from pymongoarrow.lib import (
20+
DatetimeBuilder, DoubleBuilder, Int32Builder, Int64Builder)
21+
22+
23+
class TestIntBuildersMixin:
24+
def test_simple(self):
25+
builder = self.builder_cls()
26+
builder.append(0)
27+
builder.append_values([1, 2, 3, 4])
28+
builder.append(None)
29+
arr = builder.finish()
30+
31+
self.assertIsInstance(arr, Array)
32+
self.assertEqual(arr.null_count, 1)
33+
self.assertEqual(len(arr), 6)
34+
self.assertEqual(
35+
arr.to_pylist(), [0, 1, 2, 3, 4, None])
36+
self.assertEqual(arr.type, self.data_type)
37+
38+
39+
class TestInt32Builder(TestCase, TestIntBuildersMixin):
40+
def setUp(self):
41+
self.builder_cls = Int32Builder
42+
self.data_type = int32()
43+
44+
45+
class TestInt64Builder(TestCase, TestIntBuildersMixin):
46+
def setUp(self):
47+
self.builder_cls = Int64Builder
48+
self.data_type = int64()
49+
50+
51+
class TestDate64Builder(TestCase):
52+
def test_default_unit(self):
53+
# Check default unit
54+
builder = DatetimeBuilder()
55+
self.assertEqual(builder.unit, timestamp('ms'))
56+
57+
def _test_simple(self, tstamp_units, kwarg_name):
58+
builder = DatetimeBuilder(dtype=timestamp(tstamp_units))
59+
datetimes = [datetime(1970, 1, 1) + timedelta(**{kwarg_name: k*100})
60+
for k in range(5)]
61+
builder.append(datetimes[0])
62+
builder.append_values(datetimes[1:])
63+
builder.append(None)
64+
arr = builder.finish()
65+
66+
self.assertIsInstance(arr, Array)
67+
self.assertEqual(arr.null_count, 1)
68+
self.assertEqual(len(arr), len(datetimes) + 1)
69+
self.assertEqual(arr.to_pylist(), datetimes + [None])
70+
self.assertEqual(arr.type, timestamp(tstamp_units))
71+
72+
def test_simple(self):
73+
# milliseconds
74+
self._test_simple('ms', 'milliseconds')
75+
# seconds
76+
self._test_simple('s', 'seconds')
77+
78+
def test_unsupported_units(self):
79+
with self.assertRaises(ValueError):
80+
DatetimeBuilder(dtype=timestamp('us'))
81+
82+
with self.assertRaises(ValueError):
83+
DatetimeBuilder(dtype=timestamp('ns'))
84+
85+
86+
class TestDoubleBuilder(TestCase):
87+
def test_simple(self):
88+
builder = DoubleBuilder()
89+
builder.append(0.123)
90+
builder.append_values([1.234, 2.345, 3.456, 4.567])
91+
builder.append(None)
92+
arr = builder.finish()
93+
94+
self.assertIsInstance(arr, Array)
95+
self.assertEqual(arr.null_count, 1)
96+
self.assertEqual(len(arr), 6)
97+
self.assertEqual(
98+
arr.to_pylist(), [0.123, 1.234, 2.345, 3.456, 4.567, None])

0 commit comments

Comments
 (0)