Skip to content

Commit 9372fe2

Browse files
committed
Merge branch 'master' of github.com:InvestmentSystems/arraykit
2 parents 1c6a1fd + e585289 commit 9372fe2

File tree

3 files changed

+31
-31
lines changed

3 files changed

+31
-31
lines changed

src/auto_map.c

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -74,20 +74,6 @@ typedef enum KeysArrayType{
7474
KAT_DTas,
7575
} KeysArrayType;
7676

77-
NPY_DATETIMEUNIT
78-
dt_unit_from_array(PyArrayObject* a) {
79-
// This is based on get_datetime_metadata_from_dtype in the NumPy source, but that function is private. This does not check that the dytpe is of the appropriate type.
80-
PyArray_Descr* dt = PyArray_DESCR(a); // borrowed ref
81-
PyArray_DatetimeMetaData* dma = &(((PyArray_DatetimeDTypeMetaData *)PyDataType_C_METADATA(dt))->meta);
82-
return dma->base;
83-
}
84-
85-
NPY_DATETIMEUNIT
86-
dt_unit_from_scalar(PyDatetimeScalarObject* dts) {
87-
// Based on convert_pyobject_to_datetime and related usage in datetime.c
88-
PyArray_DatetimeMetaData* dma = &(dts->obmeta);
89-
return dma->base;
90-
}
9177

9278
KeysArrayType
9379
at_to_kat(int array_t, PyArrayObject* a) {
@@ -123,7 +109,7 @@ at_to_kat(int array_t, PyArrayObject* a) {
123109
return KAT_STRING;
124110

125111
case NPY_DATETIME: {
126-
NPY_DATETIMEUNIT dtu = dt_unit_from_array(a);
112+
NPY_DATETIMEUNIT dtu = AK_dt_unit_from_array(a);
127113
switch (dtu) {
128114
case NPY_FR_Y:
129115
return KAT_DTY;
@@ -685,9 +671,6 @@ lookup_hash_obj(FAMObject *self, PyObject *key, Py_hash_t hash)
685671
int result = -1;
686672
Py_hash_t h = 0;
687673

688-
// AK_DEBUG_MSG_OBJ("lookup_hash_obj", key);
689-
// TODO: if key is a dt64, we need to get the units and compare to units before doing PyObject_RichCompareBool
690-
691674
while (1) {
692675
for (Py_ssize_t i = 0; i < SCAN; i++) {
693676
h = table[table_pos].hash;
@@ -702,6 +685,16 @@ lookup_hash_obj(FAMObject *self, PyObject *key, Py_hash_t hash)
702685
if (guess == key) { // Hit. Object ID comparison
703686
return table_pos;
704687
}
688+
689+
// if key is a dt64, only do PyObject_RichCompareBool if units match
690+
if (PyArray_IsScalar(key, Datetime) && PyArray_IsScalar(guess, Datetime)) {
691+
if (AK_dt_unit_from_scalar((PyDatetimeScalarObject *)key)
692+
!= AK_dt_unit_from_scalar((PyDatetimeScalarObject *)guess)) {
693+
table_pos++;
694+
continue;
695+
}
696+
}
697+
705698
result = PyObject_RichCompareBool(guess, key, Py_EQ);
706699
if (result < 0) { // Error.
707700
return -1;
@@ -1030,10 +1023,9 @@ lookup_datetime(FAMObject *self, PyObject* key) {
10301023
if (PyArray_IsScalar(key, Datetime)) {
10311024
v = (npy_int64)PyArrayScalar_VAL(key, Datetime);
10321025
// if we observe a NAT, we skip unit checks
1033-
// AK_DEBUG_MSG_OBJ("dt64 value", PyLong_FromLongLong(v));
10341026

10351027
if (v != NPY_DATETIME_NAT) {
1036-
NPY_DATETIMEUNIT key_unit = dt_unit_from_scalar(
1028+
NPY_DATETIMEUNIT key_unit = AK_dt_unit_from_scalar(
10371029
(PyDatetimeScalarObject *)key);
10381030
if (!kat_is_datetime_unit(self->keys_array_type, key_unit)) {
10391031
return -1;
@@ -1872,7 +1864,7 @@ fam_get_all(FAMObject *self, PyObject *key) {
18721864
GET_ALL_FLEXIBLE(char, char_get_end_p, lookup_hash_string, string_to_hash, PyBytes_FromStringAndSize);
18731865
break;
18741866
case NPY_DATETIME: {
1875-
NPY_DATETIMEUNIT key_unit = dt_unit_from_array(key_array);
1867+
NPY_DATETIMEUNIT key_unit = AK_dt_unit_from_array(key_array);
18761868
if (!kat_is_datetime_unit(self->keys_array_type, key_unit)) {
18771869
PyErr_SetString(PyExc_KeyError, "datetime64 units do not match");
18781870
Py_DECREF(array);
@@ -2070,7 +2062,7 @@ fam_get_any(FAMObject *self, PyObject *key) {
20702062
GET_ANY_FLEXIBLE(char, char_get_end_p, lookup_hash_string, string_to_hash);
20712063
break;
20722064
case NPY_DATETIME: {
2073-
NPY_DATETIMEUNIT key_unit = dt_unit_from_array(key_array);
2065+
NPY_DATETIMEUNIT key_unit = AK_dt_unit_from_array(key_array);
20742066
if (!kat_is_datetime_unit(self->keys_array_type, key_unit)) {
20752067
return values;
20762068
}

src/tri_map.c

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,6 @@
1111
# include "tri_map.h"
1212
# include "utilities.h"
1313

14-
static inline NPY_DATETIMEUNIT
15-
AK_dt_unit_from_array(PyArrayObject* a) {
16-
// This is based on get_datetime_metadata_from_dtype in the NumPy source, but that function is private. This does not check that the dtype is of the appropriate type.
17-
PyArray_Descr* dt = PyArray_DESCR(a); // borrowed ref
18-
PyArray_DatetimeMetaData* dma = &(((PyArray_DatetimeDTypeMetaData *)PyDataType_C_METADATA(dt))->meta);
19-
// PyArray_DatetimeMetaData* dma = &(((PyArray_DatetimeDTypeMetaData *)PyArray_DESCR(a)->c_metadata)->meta);
20-
return dma->base;
21-
}
22-
2314
typedef struct TriMapOne {
2415
Py_ssize_t from; // signed
2516
Py_ssize_t to;

src/utilities.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
# define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
1010

1111
# include "numpy/arrayobject.h"
12+
# include "numpy/arrayscalars.h"
1213

1314
static const size_t UCS4_SIZE = sizeof(Py_UCS4);
1415

@@ -318,4 +319,20 @@ AK_nonzero_1d(PyArrayObject* array) {
318319
return final;
319320
}
320321

322+
static inline NPY_DATETIMEUNIT
323+
AK_dt_unit_from_array(PyArrayObject* a) {
324+
// This is based on get_datetime_metadata_from_dtype in the NumPy source, but that function is private. This does not check that the dtype is of the appropriate type.
325+
PyArray_Descr* dt = PyArray_DESCR(a); // borrowed ref
326+
PyArray_DatetimeMetaData* dma = &(((PyArray_DatetimeDTypeMetaData *)PyDataType_C_METADATA(dt))->meta);
327+
// PyArray_DatetimeMetaData* dma = &(((PyArray_DatetimeDTypeMetaData *)PyArray_DESCR(a)->c_metadata)->meta);
328+
return dma->base;
329+
}
330+
331+
static inline NPY_DATETIMEUNIT
332+
AK_dt_unit_from_scalar(PyDatetimeScalarObject* dts) {
333+
// Based on convert_pyobject_to_datetime and related usage in datetime.c
334+
PyArray_DatetimeMetaData* dma = &(dts->obmeta);
335+
return dma->base;
336+
}
337+
321338
#endif /* ARRAYKIT_SRC_UTILITIES_H_ */

0 commit comments

Comments
 (0)