-
Notifications
You must be signed in to change notification settings - Fork 16
Add Cython bindings for arrow::ArrayBuilder classes #2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
dd468dc
37dbf12
9204949
a90ebc8
ac9a78e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
# Copyright 2021-present MongoDB, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# Cython compiler directives | ||
# distutils: language=c++ | ||
# cython: language_level=3 | ||
|
||
cdef class _BuilderBase:
    """Python-level behaviour shared by the typed array builders.

    Concrete builders are expected to provide ``append`` and to own the
    wrapped builder object that the accessors below query.
    """

    def append_values(self, values):
        """Append every element of ``values``, in order, via ``append``."""
        for item in values:
            self.append(item)

    @property
    def null_count(self):
        """Number of nulls reported by the wrapped builder."""
        # NOTE(review): ``builder`` is not declared on this base class, only
        # on the subclasses as a C-level attribute -- confirm this lookup
        # actually resolves when called on a concrete builder.
        return self.builder.get().null_count()

    def __len__(self):
        # Length reported by the wrapped builder (same caveat as null_count).
        return self.builder.get().length()
|
||
|
||
cdef class Int32Builder(_BuilderBase):
    """Incrementally builds an Arrow array from Python integers.

    Wraps a ``CInt32Builder`` held in a ``shared_ptr``.
    """
    cdef:
        shared_ptr[CInt32Builder] builder

    def __cinit__(self, MemoryPool memory_pool=None):
        # Resolve the optional Python-level pool to a C++ pool pointer and
        # create the wrapped builder on it.
        cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
        self.builder.reset(new CInt32Builder(pool))

    def append(self, value):
        """Append one element.

        ``None`` and the ``np.nan`` singleton are stored as nulls; ``int``
        instances are stored as values.

        Raises:
            TypeError: if ``value`` is anything else.
        """
        if value is None or value is np.nan:
            self.builder.get().AppendNull()
        elif isinstance(value, int):
            self.builder.get().Append(value)
        else:
            raise TypeError('Int32Builder only accepts integer objects')

    @property
    def null_count(self):
        """Number of nulls appended so far, as reported by the builder."""
        # Defined here, where ``builder`` is a typed C attribute, instead of
        # relying on the base-class accessor which has no such declaration.
        return self.builder.get().null_count()

    def __len__(self):
        # Number of elements appended so far; typed access, see null_count.
        return self.builder.get().length()

    def finish(self):
        """Finalize the builder and return the result wrapped as a pyarrow array."""
        cdef shared_ptr[CArray] out
        # The GIL is released while the C++ builder finalizes.
        # NOTE(review): the value returned by Finish() is discarded; if it
        # is an Arrow status, a failure here would go unnoticed -- confirm.
        with nogil:
            self.builder.get().Finish(&out)
        return pyarrow_wrap_array(out)

    cdef shared_ptr[CInt32Builder] unwrap(self):
        # Hand the underlying C++ builder to other Cython code.
        return self.builder
|
||
|
||
cdef class Int64Builder(_BuilderBase):
    """Incrementally builds an Arrow array from Python integers.

    Wraps a ``CInt64Builder`` held in a ``shared_ptr``.
    """
    cdef:
        shared_ptr[CInt64Builder] builder

    def __cinit__(self, MemoryPool memory_pool=None):
        # Resolve the optional Python-level pool to a C++ pool pointer and
        # create the wrapped builder on it.
        cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
        self.builder.reset(new CInt64Builder(pool))

    def append(self, value):
        """Append one element.

        ``None`` and the ``np.nan`` singleton are stored as nulls; ``int``
        instances are stored as values.

        Raises:
            TypeError: if ``value`` is anything else.
        """
        if value is None or value is np.nan:
            self.builder.get().AppendNull()
        elif isinstance(value, int):
            self.builder.get().Append(value)
        else:
            raise TypeError('Int64Builder only accepts integer objects')

    @property
    def null_count(self):
        """Number of nulls appended so far, as reported by the builder."""
        # Defined here, where ``builder`` is a typed C attribute, instead of
        # relying on the base-class accessor which has no such declaration.
        return self.builder.get().null_count()

    def __len__(self):
        # Number of elements appended so far; typed access, see null_count.
        return self.builder.get().length()

    def finish(self):
        """Finalize the builder and return the result wrapped as a pyarrow array."""
        cdef shared_ptr[CArray] out
        # The GIL is released while the C++ builder finalizes.
        # NOTE(review): the value returned by Finish() is discarded; if it
        # is an Arrow status, a failure here would go unnoticed -- confirm.
        with nogil:
            self.builder.get().Finish(&out)
        return pyarrow_wrap_array(out)

    cdef shared_ptr[CInt64Builder] unwrap(self):
        # Hand the underlying C++ builder to other Cython code.
        return self.builder
|
||
|
||
cdef class DoubleBuilder(_BuilderBase):
    """Incrementally builds an Arrow array from Python floats and ints.

    Wraps a ``CDoubleBuilder`` held in a ``shared_ptr``.
    """
    cdef:
        shared_ptr[CDoubleBuilder] builder

    def __cinit__(self, MemoryPool memory_pool=None):
        # Resolve the optional Python-level pool to a C++ pool pointer and
        # create the wrapped builder on it.
        cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
        self.builder.reset(new CDoubleBuilder(pool))

    def append(self, value):
        """Append one element.

        ``None`` and the ``np.nan`` singleton are stored as nulls; ``int``
        and ``float`` instances are stored as values. Only the ``np.nan``
        object itself is matched (identity test), so other NaN floats are
        appended as ordinary values.

        Raises:
            TypeError: if ``value`` is anything else.
        """
        if value is None or value is np.nan:
            self.builder.get().AppendNull()
        elif isinstance(value, (int, float)):
            self.builder.get().Append(value)
        else:
            raise TypeError('DoubleBuilder only accepts floats and ints')

    @property
    def null_count(self):
        """Number of nulls appended so far, as reported by the builder."""
        # Defined here, where ``builder`` is a typed C attribute, instead of
        # relying on the base-class accessor which has no such declaration.
        return self.builder.get().null_count()

    def __len__(self):
        # Number of elements appended so far; typed access, see null_count.
        return self.builder.get().length()

    def finish(self):
        """Finalize the builder and return the result wrapped as a pyarrow array."""
        cdef shared_ptr[CArray] out
        # The GIL is released while the C++ builder finalizes.
        # NOTE(review): the value returned by Finish() is discarded; if it
        # is an Arrow status, a failure here would go unnoticed -- confirm.
        with nogil:
            self.builder.get().Finish(&out)
        return pyarrow_wrap_array(out)

    cdef shared_ptr[CDoubleBuilder] unwrap(self):
        # Hand the underlying C++ builder to other Cython code.
        return self.builder
|
||
|
||
cdef class DatetimeBuilder(_BuilderBase):
    """Incrementally builds an Arrow timestamp array from ``datetime`` objects.

    Wraps a ``CTimestampBuilder`` created for the timestamp type given at
    construction (``timestamp('ms')`` by default).
    """
    cdef:
        shared_ptr[CTimestampBuilder] builder
        TimestampType dtype

    def __cinit__(self, TimestampType dtype=timestamp('ms'),
                  MemoryPool memory_pool=None):
        cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
        # Check the unit rather than comparing whole types, so that a
        # sub-millisecond type is rejected regardless of its other
        # parameters (e.g. a timezone).
        if dtype.unit in ('us', 'ns'):
            raise ValueError("Microsecond and nanosecond resolution temporal "
                             "types are not suitable for use with MongoDB's "
                             "UTC datetime type which has resolution of "
                             "milliseconds.")
        self.dtype = dtype
        self.builder.reset(new CTimestampBuilder(
            pyarrow_unwrap_data_type(self.dtype), pool))

    def append(self, value):
        """Append one element.

        ``None`` and the ``np.nan`` singleton are stored as nulls;
        ``datetime.datetime`` instances are converted with
        ``datetime_to_int64`` using this builder's timestamp type.

        Raises:
            TypeError: if ``value`` is anything else.
        """
        if value is None or value is np.nan:
            self.builder.get().AppendNull()
        elif isinstance(value, datetime.datetime):
            self.builder.get().Append(
                datetime_to_int64(value, self.dtype))
        else:
            raise TypeError('DatetimeBuilder only accepts datetime objects')

    @property
    def null_count(self):
        """Number of nulls appended so far, as reported by the builder."""
        # Defined here, where ``builder`` is a typed C attribute, instead of
        # relying on the base-class accessor which has no such declaration.
        return self.builder.get().null_count()

    def __len__(self):
        # Number of elements appended so far; typed access, see null_count.
        return self.builder.get().length()

    def finish(self):
        """Finalize the builder and return the result wrapped as a pyarrow array."""
        cdef shared_ptr[CArray] out
        # The GIL is released while the C++ builder finalizes.
        # NOTE(review): the value returned by Finish() is discarded; if it
        # is an Arrow status, a failure here would go unnoticed -- confirm.
        with nogil:
            self.builder.get().Finish(&out)
        return pyarrow_wrap_array(out)

    @property
    def unit(self):
        """The timestamp type this builder was created with.

        Note that this returns the full type object, not only its unit
        string.
        """
        return self.dtype

    cdef shared_ptr[CTimestampBuilder] unwrap(self):
        # Hand the underlying C++ builder to other Cython code.
        return self.builder
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# Copyright 2021-present MongoDB, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# Cython compiler directives | ||
# distutils: language=c++ | ||
# cython: language_level=3 | ||
|
||
# Stdlib imports | ||
import datetime | ||
|
||
# Python imports | ||
import numpy as np | ||
from pyarrow import timestamp | ||
|
||
# Cython imports | ||
from pyarrow.lib cimport * | ||
|
||
|
||
# Utilities | ||
include "utils.pyi" | ||
|
||
# Builders | ||
include "builders.pyi" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# Copyright 2021-present MongoDB, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
def datetime_to_int64(dtm, data_type):
    """Convert a naive datetime to an integer count of ``data_type.unit``
    ticks elapsed since 1970-01-01 00:00:00.

    The conversion uses exact integer arithmetic. When the target unit is
    coarser than microseconds the result is floored, so instants before
    the epoch map to the tick at or before them.

    Args:
        dtm: naive ``datetime.datetime`` to convert.
        data_type: object whose ``unit`` attribute is one of ``'s'``,
            ``'ms'``, ``'us'`` or ``'ns'``.

    Returns:
        int: the number of ticks since the epoch.

    Raises:
        ValueError: if ``data_type.unit`` is not one of the supported units.
    """
    # TODO: rewrite as a cdef which directly accesses data_type as a CTimestampType instance
    # TODO: make this function aware of datatype.timezone()
    elapsed = dtm - datetime.datetime(1970, 1, 1)
    # timedelta normalizes to (days, seconds, microseconds) with only
    # ``days`` carrying the sign, so this sum is exact for dates on either
    # side of the epoch.
    total_microseconds = (
        (elapsed.days * 86400 + elapsed.seconds) * 10**6
        + elapsed.microseconds)

    ticks_per_second = {'s': 1, 'ms': 10**3, 'us': 10**6, 'ns': 10**9}
    unit = data_type.unit
    if unit not in ticks_per_second:
        raise ValueError('Unsupported timestamp unit {}'.format(unit))

    # Integer scaling avoids the precision loss a float intermediate
    # suffers at nanosecond magnitudes.
    return total_microseconds * ticks_per_second[unit] // 10**6
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,9 @@ | |
|
||
import os | ||
|
||
import numpy as np | ||
import pyarrow as pa | ||
|
||
|
||
def get_pymongoarrow_version(): | ||
"""Single source the version.""" | ||
|
@@ -15,16 +18,29 @@ def get_pymongoarrow_version(): | |
|
||
|
||
def get_extension_modules(): | ||
modules = cythonize(['pymongoarrow/*.pyx', | ||
'pymongoarrow/libbson/*.pyx']) | ||
for module in modules: | ||
arrow_modules = cythonize(['pymongoarrow/*.pyx']) | ||
libbson_modules = cythonize(['pymongoarrow/libbson/*.pyx']) | ||
|
||
for module in libbson_modules: | ||
module.libraries.append('bson-1.0') | ||
return modules | ||
|
||
for module in arrow_modules: | ||
module.include_dirs.append(np.get_include()) | ||
module.include_dirs.append(pa.get_include()) | ||
module.libraries.extend(pa.get_libraries()) | ||
module.library_dirs.extend(pa.get_library_dirs()) | ||
|
||
# https://arrow.apache.org/docs/python/extending.html#example | ||
if os.name == 'posix': | ||
module.extra_compile_args.append('-std=c++11') | ||
|
||
return arrow_modules + libbson_modules | ||
|
||
|
||
setup( | ||
name='pymongoarrow', | ||
version=get_pymongoarrow_version(), | ||
packages=find_packages(), | ||
ext_modules=get_extension_modules(), | ||
setup_requires=['cython >= 0.29']) | ||
install_requires=['pyarrow >= 3', 'pymongo >= 3.11,<4'], | ||
setup_requires=['cython >= 0.29', 'pyarrow >= 3', 'numpy >= 1.16.6']) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let me see if I have this right. So we need pyarrow+numpy to build+install the wheel but the wheel can be installed and used even without pyarrow+numpy? Or are we missing a necessary `install_requires`? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We are missing the `install_requires`. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we add `install_requires`? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright 2021-present MongoDB, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
# Copyright 2021-present MongoDB, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from datetime import datetime, timedelta | ||
from unittest import TestCase | ||
|
||
from pyarrow import Array, timestamp, int32, int64 | ||
|
||
from pymongoarrow.lib import ( | ||
DatetimeBuilder, DoubleBuilder, Int32Builder, Int64Builder) | ||
|
||
|
||
class TestIntBuildersMixin:
    """Checks shared by the integer builder test cases.

    Concrete test cases set ``builder_cls`` and ``data_type``.
    """

    def test_simple(self):
        """Append 0..4 plus a trailing null and verify the finished array."""
        expected_values = [0, 1, 2, 3, 4, None]

        int_builder = self.builder_cls()
        int_builder.append(0)
        int_builder.append_values([1, 2, 3, 4])
        int_builder.append(None)
        result = int_builder.finish()

        self.assertIsInstance(result, Array)
        self.assertEqual(result.null_count, 1)
        self.assertEqual(len(result), 6)
        self.assertEqual(result.to_pylist(), expected_values)
        self.assertEqual(result.type, self.data_type)
|
||
|
||
class TestInt32Builder(TestCase, TestIntBuildersMixin):
    """Runs the shared integer-builder checks against ``Int32Builder``."""

    def setUp(self):
        # Expected Arrow type first, then the builder class under test.
        self.data_type = int32()
        self.builder_cls = Int32Builder
|
||
|
||
class TestInt64Builder(TestCase, TestIntBuildersMixin):
    """Runs the shared integer-builder checks against ``Int64Builder``."""

    def setUp(self):
        # Expected Arrow type first, then the builder class under test.
        self.data_type = int64()
        self.builder_cls = Int64Builder
|
||
|
||
class TestDate64Builder(TestCase):
    """Tests for ``DatetimeBuilder``."""

    def test_default_unit(self):
        """A builder created without arguments reports ``timestamp('ms')``."""
        default_builder = DatetimeBuilder()
        self.assertEqual(default_builder.unit, timestamp('ms'))

    def _test_simple(self, tstamp_units, kwarg_name):
        """Round-trip five datetimes plus a null through the builder.

        ``tstamp_units`` selects the timestamp type; ``kwarg_name`` is the
        ``timedelta`` keyword used to space the datetimes 100 steps apart.
        """
        arrow_type = timestamp(tstamp_units)
        epoch = datetime(1970, 1, 1)
        datetimes = []
        for step in range(5):
            datetimes.append(epoch + timedelta(**{kwarg_name: step*100}))

        dt_builder = DatetimeBuilder(dtype=arrow_type)
        first, rest = datetimes[0], datetimes[1:]
        dt_builder.append(first)
        dt_builder.append_values(rest)
        dt_builder.append(None)
        result = dt_builder.finish()

        self.assertIsInstance(result, Array)
        self.assertEqual(result.null_count, 1)
        self.assertEqual(len(result), len(datetimes) + 1)
        self.assertEqual(result.to_pylist(), datetimes + [None])
        self.assertEqual(result.type, timestamp(tstamp_units))

    def test_simple(self):
        """Exercise the millisecond and then the second resolution."""
        for units, delta_kwarg in (('ms', 'milliseconds'),
                                   ('s', 'seconds')):
            self._test_simple(units, delta_kwarg)

    def test_unsupported_units(self):
        """Microsecond and nanosecond types are rejected with ValueError."""
        with self.assertRaises(ValueError):
            DatetimeBuilder(dtype=timestamp('us'))

        with self.assertRaises(ValueError):
            DatetimeBuilder(dtype=timestamp('ns'))
|
||
|
||
class TestDoubleBuilder(TestCase):
    """Tests for ``DoubleBuilder``."""

    def test_simple(self):
        """Append five floats plus a trailing null and verify the array."""
        expected_values = [0.123, 1.234, 2.345, 3.456, 4.567, None]

        double_builder = DoubleBuilder()
        double_builder.append(0.123)
        double_builder.append_values([1.234, 2.345, 3.456, 4.567])
        double_builder.append(None)
        result = double_builder.finish()

        self.assertIsInstance(result, Array)
        self.assertEqual(result.null_count, 1)
        self.assertEqual(len(result), 6)
        self.assertEqual(result.to_pylist(), expected_values)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add a comment linking to https://arrow.apache.org/docs/python/extending.html#example so we know where this pattern comes from?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.