Skip to content

Commit f607881

Browse files
authored
Merge pull request #179 from static-frame/171/trimap-merge
2 parents 5dd5bcb + 7e3fe82 commit f607881

File tree

3 files changed

+381
-15
lines changed

3 files changed

+381
-15
lines changed

src/__init__.pyi

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,27 +45,31 @@ class ArrayGO:
4545
def extend(self, __values: tp.Iterable[object]) -> None: ...
4646

4747
# Type stub for the C-implemented TriMap. In the signatures below, the `/` after
# `self` marks `self` as positional-only.
class TriMap:
48-
def __init__(self, src_len: int, dst_len: int) -> None: ...
48+
def __init__(self, /, src_len: int, dst_len: int) -> None: ...
4949
def __repr__(self) -> str: ...
50-
def register_one(self, src_from: int, dst_from: int) -> None: ...
50+
def register_one(self, /, src_from: int, dst_from: int) -> None: ...
5151
def register_unmatched_dst(self) -> None: ...
52-
def register_many(self, src_from: int, dst_from: np.ndarray) -> None: ...
52+
def register_many(self, /, src_from: int, dst_from: np.ndarray) -> None: ...
5353
def finalize(self) -> None: ...
5454
def is_many(self) -> bool: ...
5555
def src_no_fill(self) -> bool: ...
5656
def dst_no_fill(self) -> bool: ...
57-
def map_src_no_fill(self, array_from: np.ndarray) -> np.ndarray: ...
58-
def map_dst_no_fill(self, array_from: np.ndarray) -> np.ndarray: ...
59-
def map_src_fill(self,
57+
def map_src_no_fill(self, /, array_from: np.ndarray) -> np.ndarray: ...
58+
def map_dst_no_fill(self, /, array_from: np.ndarray) -> np.ndarray: ...
59+
def map_src_fill(self, /,
6060
array_from: np.ndarray,
6161
fill_value: tp.Any,
6262
fill_value_dtype: np.dtype
6363
) -> np.ndarray: ...
64-
def map_dst_fill(self,
64+
def map_dst_fill(self, /,
6565
array_from: np.ndarray,
6666
fill_value: tp.Any,
6767
fill_value_dtype: np.dtype
6868
) -> np.ndarray: ...
69+
# Merge values from a src array and a dst array into one new array.
# The C implementation raises RuntimeError if finalize() has not been called,
# raises TypeError unless both arrays are 1D, and creates the result with a
# dtype resolved from the dtypes of both inputs.
# NOTE(review): the C method is registered with METH_VARARGS and parsed with
# PyArg_ParseTuple, so the two arrays are presumably accepted positionally
# only at runtime -- confirm the stub signature reflects that.
def map_merge(self, /,
70+
array_from_src: np.ndarray,
71+
array_from_dst: np.ndarray,
72+
) -> np.ndarray: ...
6973

7074
class BlockIndex:
7175
shape: tp.Tuple[int, int]

src/tri_map.c

Lines changed: 194 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,7 @@ TriMap_finalize(TriMapObject *self, PyObject *Py_UNUSED(unused)) {
401401

402402
npy_intp dims[] = {tm->len};
403403

404+
// initialize all to False
404405
final_src_match = PyArray_ZEROS(1, dims, NPY_BOOL, 0);
405406
if (final_src_match == NULL) {
406407
goto error;
@@ -772,7 +773,7 @@ AK_TM_transfer_object(TriMapObject* tm,
772773
// NOTE: could use PyArray_Scalar instead of PyArray_GETITEM if we wanted to store scalars instead of Python objects; however, that is pretty uncommon for object arrays to store PyArray_Scalars
773774
bool f_is_obj = PyArray_TYPE(array_from) == NPY_OBJECT;
774775

775-
// the passed in object array is assumed to be contiguous and have NULL (not None) in each position
776+
// the passed-in object array is contiguous and has NULL (not None) in each position
776777
PyObject** array_to_data = (PyObject**)PyArray_DATA(array_to);
777778
PyObject* pyo;
778779
void* f;
@@ -811,7 +812,7 @@ AK_TM_transfer_object(TriMapObject* tm,
811812
Py_INCREF(pyo); // one more than we need
812813
*t++ = pyo;
813814
}
814-
Py_DECREF(pyo); // remove the extra one
815+
Py_DECREF(pyo); // remove the extra ref
815816
}
816817
else { // from_dst, dst is an array
817818
dst_pos = 0;
@@ -834,6 +835,92 @@ AK_TM_transfer_object(TriMapObject* tm,
834835
return 0;
835836
}
836837

838+
// Returns -1 on error. Specialized transfer from any type of an array to an object array. For usage with merge, Will only transfer if the destination is not NULL.
839+
static inline int
840+
AK_TM_transfer_object_if_null(TriMapObject* tm,
841+
bool from_src,
842+
PyArrayObject* array_from,
843+
PyArrayObject* array_to
844+
) {
845+
Py_ssize_t one_count = from_src ? tm->src_one_count : tm->dst_one_count;
846+
TriMapOne* one_pairs = from_src ? tm->src_one : tm->dst_one;
847+
848+
// NOTE: could use PyArray_Scalar instead of PyArray_GETITEM if we wanted to store scalars instead of Python objects; however, that is pretty uncommon for object arrays to store PyArray_Scalars
849+
bool f_is_obj = PyArray_TYPE(array_from) == NPY_OBJECT;
850+
851+
// the passed in object array is contiguous and have NULL (not None) in each position
852+
PyObject** array_to_data = (PyObject**)PyArray_DATA(array_to);
853+
PyObject* pyo;
854+
void* f;
855+
TriMapOne* o = one_pairs;
856+
TriMapOne* o_end = o + one_count;
857+
for (; o < o_end; o++) {
858+
if (array_to_data[o->to] == NULL) {
859+
f = PyArray_GETPTR1(array_from, o->from);
860+
if (f_is_obj) {
861+
pyo = *(PyObject**)f;
862+
Py_INCREF(pyo);
863+
}
864+
else { // will convert any value to an object
865+
pyo = PyArray_GETITEM(array_from, f);
866+
}
867+
array_to_data[o->to] = pyo;
868+
}
869+
}
870+
PyObject** t;
871+
PyObject** t_end;
872+
npy_intp dst_pos;
873+
npy_int64 f_pos;
874+
PyArrayObject* dst;
875+
for (Py_ssize_t i = 0; i < tm->many_count; i++) {
876+
t = array_to_data + tm->many_to[i].start;
877+
t_end = array_to_data + tm->many_to[i].stop;
878+
879+
if (from_src) {
880+
while (t < t_end) {
881+
if (*t == NULL) {
882+
f = PyArray_GETPTR1(array_from, tm->many_from[i].src);
883+
if (f_is_obj) {
884+
pyo = *(PyObject**)f;
885+
Py_INCREF(pyo);
886+
}
887+
else {
888+
pyo = PyArray_GETITEM(array_from, f); // given a new ref
889+
}
890+
*t++ = pyo;
891+
}
892+
else {
893+
t++;
894+
}
895+
}
896+
}
897+
else { // from_dst, dst is an array
898+
dst_pos = 0;
899+
dst = tm->many_from[i].dst;
900+
while (t < t_end) {
901+
if (*t == NULL) {
902+
f_pos = *(npy_int64*)PyArray_GETPTR1(dst, dst_pos);
903+
f = PyArray_GETPTR1(array_from, f_pos);
904+
if (f_is_obj) {
905+
pyo = *(PyObject**)f;
906+
Py_INCREF(pyo);
907+
}
908+
else {
909+
pyo = PyArray_GETITEM(array_from, f);
910+
}
911+
*t++ = pyo;
912+
dst_pos++;
913+
}
914+
else {
915+
t++;
916+
dst_pos++;
917+
}
918+
}
919+
}
920+
}
921+
return 0;
922+
}
923+
837924
// Returns -1 on error.
838925
static inline int
839926
AK_TM_fill_object(TriMapObject* tm,
@@ -855,7 +942,7 @@ AK_TM_fill_object(TriMapObject* tm,
855942
return 0;
856943
}
857944

858-
#define AK_TM_TRANSFER_FLEXIBLE(c_type) do { \
945+
#define AK_TM_TRANSFER_FLEXIBLE(c_type, from_src, array_from, array_to) do {\
859946
Py_ssize_t one_count = from_src ? tm->src_one_count : tm->dst_one_count;\
860947
TriMapOne* one_pairs = from_src ? tm->src_one : tm->dst_one; \
861948
npy_intp t_element_size = PyArray_ITEMSIZE(array_to); \
@@ -1003,10 +1090,10 @@ AK_TM_map_no_fill(TriMapObject* tm,
10031090
}
10041091
}
10051092
else if (dtype_is_unicode) {
1006-
AK_TM_TRANSFER_FLEXIBLE(Py_UCS4);
1093+
AK_TM_TRANSFER_FLEXIBLE(Py_UCS4, from_src, array_from, array_to);
10071094
}
10081095
else if (dtype_is_string) {
1009-
AK_TM_TRANSFER_FLEXIBLE(char);
1096+
AK_TM_TRANSFER_FLEXIBLE(char, from_src, array_from, array_to);
10101097
}
10111098
else {
10121099
if (AK_TM_transfer_scalar(tm, from_src, array_from, array_to)) {
@@ -1048,6 +1135,102 @@ TriMap_map_dst_no_fill(TriMapObject *self, PyObject *arg) {
10481135
return AK_TM_map_no_fill(self, from_src, array_from);
10491136
}
10501137

1138+
static inline PyObject *
1139+
TriMap_map_merge(TriMapObject *tm, PyObject *args)
1140+
{
1141+
// both are "from_" arrays
1142+
PyArrayObject* array_src;
1143+
PyArrayObject* array_dst;
1144+
1145+
if (!PyArg_ParseTuple(args,
1146+
"O!O!:map_merge",
1147+
&PyArray_Type, &array_src,
1148+
&PyArray_Type, &array_dst
1149+
)) {
1150+
return NULL;
1151+
}
1152+
if (!tm->finalized) {
1153+
PyErr_SetString(PyExc_RuntimeError, "Finalization is required");
1154+
return NULL;
1155+
}
1156+
if (!(PyArray_NDIM(array_src) == 1)) {
1157+
PyErr_SetString(PyExc_TypeError, "Array src must be 1D");
1158+
return NULL;
1159+
}
1160+
if (!(PyArray_NDIM(array_dst) == 1)) {
1161+
PyErr_SetString(PyExc_TypeError, "Array dst must be 1D");
1162+
return NULL;
1163+
}
1164+
// passing a borrowed refs; returns a new ref
1165+
PyArray_Descr* dtype = AK_resolve_dtype(
1166+
PyArray_DESCR(array_src),
1167+
PyArray_DESCR(array_dst));
1168+
bool dtype_is_obj = dtype->type_num == NPY_OBJECT;
1169+
bool dtype_is_unicode = dtype->type_num == NPY_UNICODE;
1170+
bool dtype_is_string = dtype->type_num == NPY_STRING;
1171+
1172+
npy_intp dims[] = {tm->len};
1173+
1174+
// create to array_to
1175+
PyArrayObject* array_to;
1176+
if (dtype_is_obj) {
1177+
Py_DECREF(dtype); // not needed
1178+
// will initialize to NULL, not None
1179+
array_to = (PyArrayObject*)PyArray_SimpleNew(1, dims, NPY_OBJECT);
1180+
}
1181+
else if (dtype_is_unicode || dtype_is_string) {
1182+
array_to = (PyArrayObject*)PyArray_Zeros(1, dims, dtype, 0); // steals dtype ref
1183+
}
1184+
else {
1185+
array_to = (PyArrayObject*)PyArray_Empty(1, dims, dtype, 0); // steals dtype ref
1186+
}
1187+
if (array_to == NULL) {
1188+
PyErr_SetNone(PyExc_MemoryError);
1189+
return NULL;
1190+
}
1191+
1192+
// if we have fill values in src, we need to transfer from dst
1193+
bool transfer_from_dst = PyArray_SIZE((PyArrayObject*)tm->final_src_fill) != 0;
1194+
1195+
if (dtype_is_obj) {
1196+
if (AK_TM_transfer_object(tm, true, array_src, array_to)) {
1197+
Py_DECREF((PyObject*)array_to);
1198+
return NULL;
1199+
}
1200+
if (transfer_from_dst) {
1201+
if (AK_TM_transfer_object_if_null(tm, false, array_dst, array_to)) {
1202+
Py_DECREF((PyObject*)array_to);
1203+
return NULL;
1204+
}
1205+
}
1206+
}
1207+
else if (dtype_is_unicode) {
1208+
AK_TM_TRANSFER_FLEXIBLE(Py_UCS4, true, array_src, array_to);
1209+
if (transfer_from_dst) {
1210+
AK_TM_TRANSFER_FLEXIBLE(Py_UCS4, false, array_dst, array_to);
1211+
}
1212+
}
1213+
else if (dtype_is_string) {
1214+
AK_TM_TRANSFER_FLEXIBLE(char, true, array_src, array_to);
1215+
if (transfer_from_dst) {
1216+
AK_TM_TRANSFER_FLEXIBLE(char, false, array_dst, array_to);
1217+
}
1218+
}
1219+
else {
1220+
if (AK_TM_transfer_scalar(tm, true, array_src, array_to)) {
1221+
Py_DECREF((PyObject*)array_to);
1222+
return NULL;
1223+
}
1224+
if (transfer_from_dst) {
1225+
if (AK_TM_transfer_scalar(tm, false, array_dst, array_to)) {
1226+
Py_DECREF((PyObject*)array_to);
1227+
return NULL;
1228+
}
1229+
}
1230+
}
1231+
return (PyObject*)array_to;
1232+
}
1233+
10511234
// Returns NULL on error.
10521235
static inline PyObject *
10531236
AK_TM_map_fill(TriMapObject* tm,
@@ -1108,19 +1291,19 @@ AK_TM_map_fill(TriMapObject* tm,
11081291
}
11091292
}
11101293
else if (dtype_is_unicode) {
1111-
AK_TM_TRANSFER_FLEXIBLE(Py_UCS4);
1294+
AK_TM_TRANSFER_FLEXIBLE(Py_UCS4, from_src, array_from, array_to);
11121295
if (AK_TM_fill_unicode(tm, from_src, array_to, fill_value)) {
11131296
goto error;
11141297
}
11151298
}
11161299
else if (dtype_is_string) {
1117-
AK_TM_TRANSFER_FLEXIBLE(char);
1300+
AK_TM_TRANSFER_FLEXIBLE(char, from_src, array_from, array_to);
11181301
if (AK_TM_fill_string(tm, from_src, array_to, fill_value)) {
11191302
goto error;
11201303
}
11211304
}
11221305
else {
1123-
// Most simple is to fill with scalar, then overwrite values as needed; for object and flexible dtypes this is not efficient; for object dtypes, this obbligates us to decref the filled value when assigning
1306+
// Most simple is to fill with scalar, then overwrite values as needed; for object and flexible dtypes this is not efficient; for object dtypes, this obligates us to decref the filled value when assigning
11241307
if (PyArray_FillWithScalar(array_to, fill_value)) { // -1 on error
11251308
goto error;
11261309
}
@@ -1180,6 +1363,8 @@ TriMap_map_dst_fill(TriMapObject *self, PyObject *args) {
11801363
return AK_TM_map_fill(self, from_src, array_from, fill_value, fill_value_dtype);
11811364
}
11821365

1366+
1367+
11831368
static PyMethodDef TriMap_methods[] = {
11841369
{"register_one", (PyCFunction)TriMap_register_one, METH_VARARGS, NULL},
11851370
{"register_unmatched_dst", (PyCFunction)TriMap_register_unmatched_dst, METH_NOARGS, NULL},
@@ -1192,6 +1377,7 @@ static PyMethodDef TriMap_methods[] = {
11921377
{"map_dst_no_fill", (PyCFunction)TriMap_map_dst_no_fill, METH_O, NULL},
11931378
{"map_src_fill", (PyCFunction)TriMap_map_src_fill, METH_VARARGS, NULL},
11941379
{"map_dst_fill", (PyCFunction)TriMap_map_dst_fill, METH_VARARGS, NULL},
1380+
{"map_merge", (PyCFunction)TriMap_map_merge, METH_VARARGS, NULL},
11951381
{NULL},
11961382
};
11971383

0 commit comments

Comments
 (0)