Skip to content

Commit 33f3acb

Browse files
committed
progress on map_merge
1 parent 3e75a99 commit 33f3acb

File tree

3 files changed

+117
-45
lines changed

3 files changed

+117
-45
lines changed

src/__init__.pyi

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,6 @@ class TriMap:
5656
def dst_no_fill(self) -> bool: ...
5757
def map_src_no_fill(self, /, array_from: np.ndarray) -> np.ndarray: ...
5858
def map_dst_no_fill(self, /, array_from: np.ndarray) -> np.ndarray: ...
59-
def map_merge_no_fill(self, /,
60-
array_from_src: np.ndarray,
61-
array_from_dst: np.ndarray,
62-
) -> np.ndarray: ...
6359
def map_src_fill(self, /,
6460
array_from: np.ndarray,
6561
fill_value: tp.Any,
@@ -70,6 +66,10 @@ class TriMap:
7066
fill_value: tp.Any,
7167
fill_value_dtype: np.dtype
7268
) -> np.ndarray: ...
69+
def map_merge(self, /,
70+
array_from_src: np.ndarray,
71+
array_from_dst: np.ndarray,
72+
) -> np.ndarray: ...
7373

7474
class BlockIndex:
7575
shape: tp.Tuple[int, int]

src/tri_map.c

Lines changed: 47 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,7 @@ TriMap_finalize(TriMapObject *self, PyObject *Py_UNUSED(unused)) {
401401

402402
npy_intp dims[] = {tm->len};
403403

404+
// initialize all to False
404405
final_src_match = PyArray_ZEROS(1, dims, NPY_BOOL, 0);
405406
if (final_src_match == NULL) {
406407
goto error;
@@ -855,7 +856,7 @@ AK_TM_fill_object(TriMapObject* tm,
855856
return 0;
856857
}
857858

858-
#define AK_TM_TRANSFER_FLEXIBLE(c_type) do { \
859+
#define AK_TM_TRANSFER_FLEXIBLE(c_type, from_src, array_from, array_to) do {\
859860
Py_ssize_t one_count = from_src ? tm->src_one_count : tm->dst_one_count;\
860861
TriMapOne* one_pairs = from_src ? tm->src_one : tm->dst_one; \
861862
npy_intp t_element_size = PyArray_ITEMSIZE(array_to); \
@@ -1003,10 +1004,10 @@ AK_TM_map_no_fill(TriMapObject* tm,
10031004
}
10041005
}
10051006
else if (dtype_is_unicode) {
1006-
AK_TM_TRANSFER_FLEXIBLE(Py_UCS4);
1007+
AK_TM_TRANSFER_FLEXIBLE(Py_UCS4, from_src, array_from, array_to);
10071008
}
10081009
else if (dtype_is_string) {
1009-
AK_TM_TRANSFER_FLEXIBLE(char);
1010+
AK_TM_TRANSFER_FLEXIBLE(char, from_src, array_from, array_to);
10101011
}
10111012
else {
10121013
if (AK_TM_transfer_scalar(tm, from_src, array_from, array_to)) {
@@ -1050,19 +1051,19 @@ TriMap_map_dst_no_fill(TriMapObject *self, PyObject *arg) {
10501051

10511052

10521053
static inline PyObject *
1053-
TriMap_map_merge_no_fill(TriMapObject *self, PyObject *args)
1054+
TriMap_map_merge(TriMapObject *tm, PyObject *args)
10541055
{
10551056
PyArrayObject* array_src;
10561057
PyArrayObject* array_dst;
10571058

10581059
if (!PyArg_ParseTuple(args,
1059-
"O!O!:map_merge_no_fill",
1060+
"O!O!:map_merge",
10601061
&PyArray_Type, &array_src,
10611062
&PyArray_Type, &array_dst
10621063
)) {
10631064
return NULL;
10641065
}
1065-
if (!self->finalized) {
1066+
if (!tm->finalized) {
10661067
PyErr_SetString(PyExc_RuntimeError, "Finalization is required");
10671068
return NULL;
10681069
}
@@ -1083,7 +1084,7 @@ TriMap_map_merge_no_fill(TriMapObject *self, PyObject *args)
10831084
bool dtype_is_unicode = dtype->type_num == NPY_UNICODE;
10841085
bool dtype_is_string = dtype->type_num == NPY_STRING;
10851086

1086-
npy_intp dims[] = {self->len};
1087+
npy_intp dims[] = {tm->len};
10871088

10881089
// create to array_to
10891090
PyArrayObject* array_to;
@@ -1118,7 +1119,41 @@ TriMap_map_merge_no_fill(TriMapObject *self, PyObject *args)
11181119
return NULL;
11191120
}
11201121

1121-
Py_RETURN_NONE;
1122+
// if we have fill values in src, we need to transfer from dst
1123+
bool transfer_from_dst = PyArray_SIZE((PyArrayObject*)tm->final_src_fill) != 0;
1124+
1125+
if (dtype_is_obj) {
1126+
if (AK_TM_transfer_object(tm, true, array_src, array_to)) {
1127+
Py_DECREF((PyObject*)array_to);
1128+
return NULL;
1129+
}
1130+
}
1131+
else if (dtype_is_unicode) {
1132+
AK_TM_TRANSFER_FLEXIBLE(Py_UCS4, true, array_src, array_to);
1133+
if (transfer_from_dst) {
1134+
AK_TM_TRANSFER_FLEXIBLE(Py_UCS4, false, array_dst, array_to);
1135+
}
1136+
}
1137+
else if (dtype_is_string) {
1138+
AK_TM_TRANSFER_FLEXIBLE(char, true, array_src, array_to);
1139+
if (transfer_from_dst) {
1140+
AK_TM_TRANSFER_FLEXIBLE(char, false, array_dst, array_to);
1141+
}
1142+
}
1143+
else {
1144+
if (AK_TM_transfer_scalar(tm, true, array_src, array_to)) {
1145+
Py_DECREF((PyObject*)array_to);
1146+
return NULL;
1147+
}
1148+
if (transfer_from_dst) {
1149+
if (AK_TM_transfer_scalar(tm, false, array_dst, array_to)) {
1150+
Py_DECREF((PyObject*)array_to);
1151+
return NULL;
1152+
}
1153+
}
1154+
}
1155+
1156+
return (PyObject*)array_to;
11221157

11231158
// // transfer values
11241159
// if (dtype_is_obj) {
@@ -1205,19 +1240,19 @@ AK_TM_map_fill(TriMapObject* tm,
12051240
}
12061241
}
12071242
else if (dtype_is_unicode) {
1208-
AK_TM_TRANSFER_FLEXIBLE(Py_UCS4);
1243+
AK_TM_TRANSFER_FLEXIBLE(Py_UCS4, from_src, array_from, array_to);
12091244
if (AK_TM_fill_unicode(tm, from_src, array_to, fill_value)) {
12101245
goto error;
12111246
}
12121247
}
12131248
else if (dtype_is_string) {
1214-
AK_TM_TRANSFER_FLEXIBLE(char);
1249+
AK_TM_TRANSFER_FLEXIBLE(char, from_src, array_from, array_to);
12151250
if (AK_TM_fill_string(tm, from_src, array_to, fill_value)) {
12161251
goto error;
12171252
}
12181253
}
12191254
else {
1220-
// Most simple is to fill with scalar, then overwrite values as needed; for object and flexible dtypes this is not efficient; for object dtypes, this obbligates us to decref the filled value when assigning
1255+
// Most simple is to fill with scalar, then overwrite values as needed; for object and flexible dtypes this is not efficient; for object dtypes, this obligates us to decref the filled value when assigning
12211256
if (PyArray_FillWithScalar(array_to, fill_value)) { // -1 on error
12221257
goto error;
12231258
}
@@ -1279,29 +1314,6 @@ TriMap_map_dst_fill(TriMapObject *self, PyObject *args) {
12791314

12801315

12811316

1282-
// PyObject *
1283-
// TriMap_map_merge_fill(TriMapObject *self, PyObject *args) {
1284-
// PyArrayObject* array_from;
1285-
// PyObject* fill_value;
1286-
// PyArray_Descr* fill_value_dtype;
1287-
// if (!PyArg_ParseTuple(args,
1288-
// "O!OO!:map_dst_fill",
1289-
// &PyArray_Type, &array_from,
1290-
// &fill_value,
1291-
// &PyArrayDescr_Type, &fill_value_dtype
1292-
// )) {
1293-
// return NULL;
1294-
// }
1295-
// if (!self->finalized) {
1296-
// PyErr_SetString(PyExc_RuntimeError, "Finalization is required");
1297-
// return NULL;
1298-
// }
1299-
// bool from_src = false;
1300-
// return AK_TM_map_fill(self, from_src, array_from, fill_value, fill_value_dtype);
1301-
// }
1302-
1303-
1304-
13051317
static PyMethodDef TriMap_methods[] = {
13061318
{"register_one", (PyCFunction)TriMap_register_one, METH_VARARGS, NULL},
13071319
{"register_unmatched_dst", (PyCFunction)TriMap_register_unmatched_dst, METH_NOARGS, NULL},
@@ -1312,9 +1324,9 @@ static PyMethodDef TriMap_methods[] = {
13121324
{"dst_no_fill", (PyCFunction)TriMap_dst_no_fill, METH_NOARGS, NULL},
13131325
{"map_src_no_fill", (PyCFunction)TriMap_map_src_no_fill, METH_O, NULL},
13141326
{"map_dst_no_fill", (PyCFunction)TriMap_map_dst_no_fill, METH_O, NULL},
1315-
{"map_merge_no_fill", (PyCFunction)TriMap_map_merge_no_fill, METH_VARARGS, NULL},
13161327
{"map_src_fill", (PyCFunction)TriMap_map_src_fill, METH_VARARGS, NULL},
13171328
{"map_dst_fill", (PyCFunction)TriMap_map_dst_fill, METH_VARARGS, NULL},
1329+
{"map_merge", (PyCFunction)TriMap_map_merge, METH_VARARGS, NULL},
13181330
{NULL},
13191331
};
13201332

test/test_tri_map.py

Lines changed: 66 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,7 +1140,7 @@ def test_tri_map_map_dt64_d(self) -> None:
11401140

11411141
#---------------------------------------------------------------------------
11421142

1143-
def test_tri_map_map_c(self) -> None:
1143+
def test_tri_map_merge_a(self) -> None:
11441144
src = np.array([0, 200, 300, 400, 0], dtype=np.int64)
11451145
dst = np.array([300, 400, 0, 200, 300], dtype=np.int64)
11461146

@@ -1152,20 +1152,80 @@ def test_tri_map_map_c(self) -> None:
11521152
tm.register_one(4, 2)
11531153

11541154
with self.assertRaises(RuntimeError):
1155-
_ = tm.map_merge_no_fill(src, dst)
1155+
_ = tm.map_merge(src, dst)
11561156

11571157
tm.finalize()
11581158

11591159
with self.assertRaises(TypeError):
1160-
_ = tm.map_merge_no_fill(3, dst)
1160+
_ = tm.map_merge(3, dst)
11611161

11621162
with self.assertRaises(TypeError):
1163-
_ = tm.map_merge_no_fill(src, 3)
1163+
_ = tm.map_merge(src, 3)
11641164

11651165
with self.assertRaises(TypeError):
1166-
_ = tm.map_merge_no_fill(src.reshape(5, 1), dst)
1166+
_ = tm.map_merge(src.reshape(5, 1), dst)
11671167

11681168
with self.assertRaises(TypeError):
1169-
_ = tm.map_merge_no_fill(src, dst.reshape(5, 1))
1169+
_ = tm.map_merge(src, dst.reshape(5, 1))
1170+
1171+
1172+
def test_tri_map_merge_a2(self) -> None:
    """Check the values returned by ``map_merge`` for int64 inputs.

    Registers one-to-one and one-to-many pairs, calls
    ``register_unmatched_dst()``, finalizes, and compares the merged
    array against the expected list.
    """
    # NOTE: this method was originally also named ``test_tri_map_merge_a``,
    # the same as the argument-validation test defined earlier in this class.
    # A second ``def`` with the same name rebinds it in the class body, so the
    # earlier test was silently dropped; a unique name keeps both tests live.
    src = np.array([0, 200, 300, 400, 0], dtype=np.int64)
    dst = np.array([300, 400, 0, 200, 300, 50, 50], dtype=np.int64)

    tm = TriMap(len(src), len(dst))
    tm.register_one(0, 2)
    tm.register_one(1, 3)
    tm.register_many(2, np.array([0, 4], dtype=np.dtype(np.int64)))
    tm.register_one(3, 1)
    tm.register_one(4, 2)
    tm.register_unmatched_dst()
    tm.finalize()

    post = tm.map_merge(src, dst)
    # presumably the trailing 50, 50 come from dst positions 5 and 6, which
    # are never registered explicitly — confirm against TriMap semantics
    self.assertEqual(post.tolist(), [0, 200, 300, 300, 400, 0, 50, 50])
1187+
1188+
def test_tri_map_merge_b(self) -> None:
    """Check ``map_merge`` output when two src positions each map to two dst positions.

    Registers one ``register_one`` pair, two ``register_many`` groups, a
    ``register_one`` call with ``-1`` as its second argument, then
    ``register_unmatched_dst()`` and ``finalize()`` before merging.
    """
    index_dtype = np.dtype(np.int64)
    # NOTE(review): src is int64 while dst is int32; presumably intentional
    # to exercise mixed dtypes — confirm.
    src = np.array([0, 200, 300, 400], dtype=np.int64)
    dst = np.array([50, 80, 200, 300, 0, 200, 300, 70, 80], dtype=np.int32)

    tri_map = TriMap(len(src), len(dst))
    tri_map.register_one(0, 4)
    for src_pos, dst_positions in ((1, [2, 5]), (2, [3, 6])):
        tri_map.register_many(src_pos, np.array(dst_positions, dtype=index_dtype))
    tri_map.register_one(3, -1)
    tri_map.register_unmatched_dst()
    tri_map.finalize()

    merged = tri_map.map_merge(src, dst)
    self.assertEqual(
        merged.tolist(),
        [0, 200, 200, 300, 300, 400, 50, 80, 70, 80],
    )
1202+
1203+
def test_tri_map_merge_c(self) -> None:
    """Check ``map_merge`` output when only ``register_one`` pairs are used.

    Four ``register_one`` calls are made (the first with ``-1`` as its
    second argument), followed by ``register_unmatched_dst()`` and
    ``finalize()``; the merged values are then compared to the expected list.
    """
    src = np.array([0, 200, 300, 400], dtype=np.int64)
    dst = np.array([400, 200, 300], dtype=np.int64)

    tri_map = TriMap(len(src), len(dst))
    # registration order is kept exactly as listed
    for src_pos, dst_pos in ((0, -1), (1, 1), (2, 2), (3, 0)):
        tri_map.register_one(src_pos, dst_pos)
    tri_map.register_unmatched_dst()
    tri_map.finalize()

    merged = tri_map.map_merge(src, dst)
    self.assertEqual(merged.tolist(), [0, 200, 300, 400])
1217+
1218+
def test_tri_map_merge_d(self) -> None:
    """Check ``map_merge`` output for unicode string arrays.

    Alternates ``register_many`` groups with ``register_one`` calls whose
    second argument is ``-1``, then calls ``register_unmatched_dst()`` and
    ``finalize()`` and compares the merged strings to the expected list.
    """
    index_dtype = np.dtype(np.int64)
    src = np.array(['a', 'bbb', 'cc', 'dddd'])
    dst = np.array(['cc', 'a', 'a', 'ee', 'cc'])

    tri_map = TriMap(len(src), len(dst))
    tri_map.register_many(0, np.array([1, 2], dtype=index_dtype))
    tri_map.register_one(1, -1)
    tri_map.register_many(2, np.array([0, 4], dtype=index_dtype))
    tri_map.register_one(3, -1)
    tri_map.register_unmatched_dst()
    tri_map.finalize()

    merged = tri_map.map_merge(src, dst)
    self.assertEqual(
        merged.tolist(),
        ['a', 'a', 'bbb', 'cc', 'cc', 'dddd', 'ee'],
    )

0 commit comments

Comments
 (0)