Skip to content

Commit 3e75a99

Browse files
committed
updated
1 parent 02ee0e1 commit 3e75a99

File tree

2 files changed

+47
-35
lines changed

2 files changed

+47
-35
lines changed

src/tri_map.c

Lines changed: 43 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,21 +1049,6 @@ TriMap_map_dst_no_fill(TriMapObject *self, PyObject *arg) {
10491049
}
10501050

10511051

1052-
// PyObject *
1053-
// TriMap_map_merge_no_fill(TriMapObject *self, PyObject *arg) {
1054-
// if (!PyArray_Check(arg)) {
1055-
// PyErr_SetString(PyExc_TypeError, "Must provide an array");
1056-
// return NULL;
1057-
// }
1058-
// if (!self->finalized) {
1059-
// PyErr_SetString(PyExc_RuntimeError, "Finalization is required");
1060-
// return NULL;
1061-
// }
1062-
// PyArrayObject* array_from = (PyArrayObject*)arg;
1063-
// bool from_src = false;
1064-
// return AK_TM_map_no_fill(self, from_src, array_from);
1065-
// }
1066-
10671052
static inline PyObject *
10681053
TriMap_map_merge_no_fill(TriMapObject *self, PyObject *args)
10691054
{
@@ -1089,29 +1074,52 @@ TriMap_map_merge_no_fill(TriMapObject *self, PyObject *args)
10891074
PyErr_SetString(PyExc_TypeError, "Array dst must be 1D");
10901075
return NULL;
10911076
}
1092-
Py_RETURN_NONE;
10931077

1094-
// // TODO: resolve dtype from src, dst
1078+
// passing a borrowed refs; returns a new ref
1079+
PyArray_Descr* dtype = AK_resolve_dtype(
1080+
PyArray_DESCR(array_src),
1081+
PyArray_DESCR(array_dst));
1082+
bool dtype_is_obj = dtype->type_num == NPY_OBJECT;
1083+
bool dtype_is_unicode = dtype->type_num == NPY_UNICODE;
1084+
bool dtype_is_string = dtype->type_num == NPY_STRING;
10951085

1096-
// npy_intp dims[] = {tm->len};
1097-
// PyArrayObject* array_to;
1098-
// bool dtype_is_obj = PyArray_TYPE(array_from) == NPY_OBJECT;
1099-
// bool dtype_is_unicode = PyArray_TYPE(array_from) == NPY_UNICODE;
1100-
// bool dtype_is_string = PyArray_TYPE(array_from) == NPY_STRING;
1086+
npy_intp dims[] = {self->len};
1087+
1088+
// create to array_to
1089+
PyArrayObject* array_to;
1090+
if (dtype_is_obj) {
1091+
Py_DECREF(dtype); // not needed
1092+
// will initialize to NULL, not None
1093+
array_to = (PyArrayObject*)PyArray_SimpleNew(1, dims, NPY_OBJECT);
1094+
// Py_INCREF(array_from); // normalize refs when casting
1095+
}
1096+
else if (dtype_is_unicode || dtype_is_string) {
1097+
array_to = (PyArrayObject*)PyArray_Zeros(1, dims, dtype, 0); // steals dtype ref
1098+
// Py_INCREF(array_from); // normalize refs when casting
1099+
}
1100+
else {
1101+
array_to = (PyArrayObject*)PyArray_Empty(1, dims, dtype, 0); // steals dtype ref
1102+
// if (PyArray_TYPE(array_from) == NPY_DATETIME &&
1103+
// PyArray_TYPE(array_to) == NPY_DATETIME &&
1104+
// AK_dt_unit_from_array(array_from) != AK_dt_unit_from_array(array_to)
1105+
// ) {
1106+
// // if trying to cast into a dt64 array, need to pre-convert; array_from is originally borrowed; calling cast sets it to a new ref
1107+
// dtype = PyArray_DESCR(array_to); // borrowed ref
1108+
// Py_INCREF(dtype);
1109+
// array_from = (PyArrayObject*)PyArray_CastToType(array_from, dtype, 0);
1110+
// }
1111+
// else {
1112+
// Py_INCREF(array_from); // normalize refs when casting
1113+
// }
1114+
}
1115+
if (array_to == NULL) {
1116+
PyErr_SetNone(PyExc_MemoryError);
1117+
// Py_DECREF((PyObject*)array_from);
1118+
return NULL;
1119+
}
1120+
1121+
Py_RETURN_NONE;
11011122

1102-
// // create to array
1103-
// if (dtype_is_obj) { // initializes values to NULL
1104-
// array_to = (PyArrayObject*)PyArray_SimpleNew(1, dims, NPY_OBJECT);
1105-
// }
1106-
// else {
1107-
// PyArray_Descr* dtype = PyArray_DESCR(array_from); // borowed ref
1108-
// Py_INCREF(dtype);
1109-
// array_to = (PyArrayObject*)PyArray_Empty(1, dims, dtype, 0); // steals dtype ref
1110-
// }
1111-
// if (array_to == NULL) {
1112-
// PyErr_SetNone(PyExc_MemoryError);
1113-
// return NULL;
1114-
// }
11151123
// // transfer values
11161124
// if (dtype_is_obj) {
11171125
// if (AK_TM_transfer_object(tm, from_src, array_from, array_to)) {

test/test_tri_map.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1150,6 +1150,10 @@ def test_tri_map_map_c(self) -> None:
11501150
tm.register_many(2, np.array([0, 4], dtype=np.dtype(np.int64)))
11511151
tm.register_one(3, 1)
11521152
tm.register_one(4, 2)
1153+
1154+
with self.assertRaises(RuntimeError):
1155+
_ = tm.map_merge_no_fill(src, dst)
1156+
11531157
tm.finalize()
11541158

11551159
with self.assertRaises(TypeError):

0 commit comments

Comments
 (0)