Skip to content

Commit f5a6627

Browse files
committed
optimized object transfers to ignore filled values
1 parent 57ece6e commit f5a6627

File tree

2 files changed

+137
-16
lines changed

2 files changed

+137
-16
lines changed

src/tri_map.c

Lines changed: 92 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -765,8 +765,7 @@ static inline int
765765
AK_TM_transfer_object(TriMapObject* tm,
766766
bool from_src,
767767
PyArrayObject* array_from,
768-
PyArrayObject* array_to,
769-
bool clear_target
768+
PyArrayObject* array_to
770769
) {
771770
Py_ssize_t one_count = from_src ? tm->src_one_count : tm->dst_one_count;
772771
TriMapOne* one_pairs = from_src ? tm->src_one : tm->dst_one;
@@ -789,9 +788,6 @@ AK_TM_transfer_object(TriMapObject* tm,
789788
else { // will convert any value to an object
790789
pyo = PyArray_GETITEM(array_from, f);
791790
}
792-
if (clear_target) {
793-
Py_XDECREF(array_to_data[o->to]);
794-
}
795791
array_to_data[o->to] = pyo;
796792
}
797793
PyObject** t;
@@ -814,9 +810,6 @@ AK_TM_transfer_object(TriMapObject* tm,
814810
}
815811
while (t < t_end) {
816812
Py_INCREF(pyo); // one more than we need
817-
if (clear_target) {
818-
Py_XDECREF(*t);
819-
}
820813
*t++ = pyo;
821814
}
822815
Py_DECREF(pyo); // remove the extra ref
@@ -834,9 +827,6 @@ AK_TM_transfer_object(TriMapObject* tm,
834827
else {
835828
pyo = PyArray_GETITEM(array_from, f);
836829
}
837-
if (clear_target) {
838-
Py_XDECREF(*t);
839-
}
840830
*t++ = pyo;
841831
dst_pos++;
842832
}
@@ -845,6 +835,92 @@ AK_TM_transfer_object(TriMapObject* tm,
845835
return 0;
846836
}
847837

838+
// Returns -1 on error. Specialized transfer from any type of an array to an object array. For usage with merge, Will only transfer if the destination is not NULL.
839+
static inline int
840+
AK_TM_transfer_object_if_null(TriMapObject* tm,
841+
bool from_src,
842+
PyArrayObject* array_from,
843+
PyArrayObject* array_to
844+
) {
845+
Py_ssize_t one_count = from_src ? tm->src_one_count : tm->dst_one_count;
846+
TriMapOne* one_pairs = from_src ? tm->src_one : tm->dst_one;
847+
848+
// NOTE: could use PyArray_Scalar instead of PyArray_GETITEM if we wanted to store scalars instead of Python objects; however, that is pretty uncommon for object arrays to store PyArray_Scalars
849+
bool f_is_obj = PyArray_TYPE(array_from) == NPY_OBJECT;
850+
851+
// the passed in object array is contiguous and have NULL (not None) in each position
852+
PyObject** array_to_data = (PyObject**)PyArray_DATA(array_to);
853+
PyObject* pyo;
854+
void* f;
855+
TriMapOne* o = one_pairs;
856+
TriMapOne* o_end = o + one_count;
857+
for (; o < o_end; o++) {
858+
if (array_to_data[o->to] == NULL) {
859+
f = PyArray_GETPTR1(array_from, o->from);
860+
if (f_is_obj) {
861+
pyo = *(PyObject**)f;
862+
Py_INCREF(pyo);
863+
}
864+
else { // will convert any value to an object
865+
pyo = PyArray_GETITEM(array_from, f);
866+
}
867+
array_to_data[o->to] = pyo;
868+
}
869+
}
870+
PyObject** t;
871+
PyObject** t_end;
872+
npy_intp dst_pos;
873+
npy_int64 f_pos;
874+
PyArrayObject* dst;
875+
for (Py_ssize_t i = 0; i < tm->many_count; i++) {
876+
t = array_to_data + tm->many_to[i].start;
877+
t_end = array_to_data + tm->many_to[i].stop;
878+
879+
if (from_src) {
880+
while (t < t_end) {
881+
if (*t == NULL) {
882+
f = PyArray_GETPTR1(array_from, tm->many_from[i].src);
883+
if (f_is_obj) {
884+
pyo = *(PyObject**)f;
885+
Py_INCREF(pyo);
886+
}
887+
else {
888+
pyo = PyArray_GETITEM(array_from, f); // given a new ref
889+
}
890+
*t++ = pyo;
891+
}
892+
else {
893+
t++;
894+
}
895+
}
896+
}
897+
else { // from_dst, dst is an array
898+
dst_pos = 0;
899+
dst = tm->many_from[i].dst;
900+
while (t < t_end) {
901+
if (*t == NULL) {
902+
f_pos = *(npy_int64*)PyArray_GETPTR1(dst, dst_pos);
903+
f = PyArray_GETPTR1(array_from, f_pos);
904+
if (f_is_obj) {
905+
pyo = *(PyObject**)f;
906+
Py_INCREF(pyo);
907+
}
908+
else {
909+
pyo = PyArray_GETITEM(array_from, f);
910+
}
911+
*t++ = pyo;
912+
dst_pos++;
913+
}
914+
else {
915+
t++;
916+
dst_pos++;
917+
}
918+
}
919+
}
920+
}
921+
return 0;
922+
}
923+
848924
// Returns -1 on error.
849925
static inline int
850926
AK_TM_fill_object(TriMapObject* tm,
@@ -1008,7 +1084,7 @@ AK_TM_map_no_fill(TriMapObject* tm,
10081084
}
10091085
// transfer values
10101086
if (dtype_is_obj) {
1011-
if (AK_TM_transfer_object(tm, from_src, array_from, array_to, false)) {
1087+
if (AK_TM_transfer_object(tm, from_src, array_from, array_to)) {
10121088
Py_DECREF((PyObject*)array_to);
10131089
return NULL;
10141090
}
@@ -1063,6 +1139,7 @@ TriMap_map_dst_no_fill(TriMapObject *self, PyObject *arg) {
10631139
static inline PyObject *
10641140
TriMap_map_merge(TriMapObject *tm, PyObject *args)
10651141
{
1142+
// both are "from_" arrays
10661143
PyArrayObject* array_src;
10671144
PyArrayObject* array_dst;
10681145

@@ -1085,7 +1162,6 @@ TriMap_map_merge(TriMapObject *tm, PyObject *args)
10851162
PyErr_SetString(PyExc_TypeError, "Array dst must be 1D");
10861163
return NULL;
10871164
}
1088-
10891165
// passing borrowed refs; returns a new ref
10901166
PyArray_Descr* dtype = AK_resolve_dtype(
10911167
PyArray_DESCR(array_src),
@@ -1133,12 +1209,12 @@ TriMap_map_merge(TriMapObject *tm, PyObject *args)
11331209
bool transfer_from_dst = PyArray_SIZE((PyArrayObject*)tm->final_src_fill) != 0;
11341210

11351211
if (dtype_is_obj) {
1136-
if (AK_TM_transfer_object(tm, true, array_src, array_to, false)) {
1212+
if (AK_TM_transfer_object(tm, true, array_src, array_to)) {
11371213
Py_DECREF((PyObject*)array_to);
11381214
return NULL;
11391215
}
11401216
if (transfer_from_dst) {
1141-
if (AK_TM_transfer_object(tm, false, array_dst, array_to, true)) {
1217+
if (AK_TM_transfer_object_if_null(tm, false, array_dst, array_to)) {
11421218
Py_DECREF((PyObject*)array_to);
11431219
return NULL;
11441220
}
@@ -1223,7 +1299,7 @@ AK_TM_map_fill(TriMapObject* tm,
12231299
}
12241300
// array_from, array_to inc refed and dec refed on error
12251301
if (dtype_is_obj) {
1226-
if (AK_TM_transfer_object(tm, from_src, array_from, array_to, false)) {
1302+
if (AK_TM_transfer_object(tm, from_src, array_from, array_to)) {
12271303
goto error;
12281304
}
12291305
if (AK_TM_fill_object(tm, from_src, array_to, fill_value)) {

test/test_tri_map.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,3 +1229,48 @@ def test_tri_map_merge_d(self) -> None:
12291229

12301230
post = tm.map_merge(src, dst)
12311231
self.assertEqual(post.tolist(), ['a', 'a', 'bbb', 'cc', 'cc', 'dddd', 'ee'])
1232+
1233+
1234+
1235+
def test_tri_map_merge_e(self) -> None:
1236+
src = np.array([None, False, -42, 'dddd'], dtype=object)
1237+
dst = np.array([-42, None, None, 'ee', -42], dtype=object)
1238+
1239+
tm = TriMap(len(src), len(dst))
1240+
tm.register_many(0, np.array([1, 2], dtype=np.dtype(np.int64)))
1241+
tm.register_one(1, -1)
1242+
tm.register_many(2, np.array([0, 4], dtype=np.dtype(np.int64)))
1243+
tm.register_one(3, -1)
1244+
tm.register_unmatched_dst()
1245+
tm.finalize()
1246+
1247+
post = tm.map_merge(src, dst)
1248+
self.assertEqual(post.tolist(), [None, None, False, -42, -42, 'dddd', 'ee'])
1249+
1250+
def test_tri_map_merge_f(self) -> None:
1251+
src = np.array([None, False, -42,], dtype=object)
1252+
dst = np.array([True, 'ee', 88], dtype=object)
1253+
1254+
tm = TriMap(len(src), len(dst))
1255+
tm.register_one(0, -1)
1256+
tm.register_one(1, -1)
1257+
tm.register_one(2, -1)
1258+
tm.register_unmatched_dst()
1259+
tm.finalize()
1260+
1261+
post = tm.map_merge(src, dst)
1262+
self.assertEqual(post.tolist(), [None, False, -42, True, 'ee', 88])
1263+
1264+
def test_tri_map_merge_g(self) -> None:
1265+
src = np.array([None, False, -42,], dtype=object)
1266+
dst = np.array([None, False, -42, 'ee', 'ff'], dtype=object)
1267+
1268+
tm = TriMap(len(src), len(dst))
1269+
tm.register_one(0, 0)
1270+
tm.register_one(1, 1)
1271+
tm.register_one(2, 2)
1272+
tm.register_unmatched_dst()
1273+
tm.finalize()
1274+
1275+
post = tm.map_merge(src, dst)
1276+
self.assertEqual(post.tolist(), [None, False, -42, 'ee', 'ff'])

0 commit comments

Comments
 (0)