Skip to content

Commit 055d8bf

Browse files
committed
Consolidated group_any_all Cython func
1 parent 977275b commit 055d8bf

File tree

2 files changed

+33
-60
lines changed

2 files changed

+33
-60
lines changed

pandas/_libs/groupby.pyx

Lines changed: 28 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -312,58 +312,23 @@ def group_fillna_indexer(ndarray[int64_t] out, ndarray[int64_t] labels,
312312

313313
@cython.boundscheck(False)
314314
@cython.wraparound(False)
315-
def group_any(ndarray[uint8_t] out,
316-
ndarray[int64_t] labels,
317-
ndarray[uint8_t] values,
318-
ndarray[uint8_t] mask,
319-
bint skipna):
320-
"""Aggregated boolean values to show if any group element is truthful
315+
def group_any_all(ndarray[uint8_t] out,
316+
ndarray[int64_t] labels,
317+
ndarray[uint8_t] values,
318+
ndarray[uint8_t] mask,
319+
object val_test,
320+
bint skipna):
321+
"""Aggregated boolean values to show truthfulness of group elements
321322
322323
Parameters
323324
----------
324325
out : array of values which this method will write its results to
325-
labels : array containing unique label for each group, with its ordering
326-
matching up to the corresponding record in `values`
327-
values : array containing the truth value of each element
328-
mask : array indicating whether a value is na or not
329-
skipna : boolean
330-
Flag to ignore nan values during truth testing
331-
332-
Notes
333-
-----
334-
This method modifies the `out` parameter rather than returning an object.
335-
The returned values will either be 0 or 1 (False or True, respectively).
336-
"""
337-
cdef:
338-
Py_ssize_t i, N=len(labels)
339-
int64_t lab
340-
341-
with nogil:
342-
for i in range(N):
343-
lab = labels[i]
344-
if lab < 0 or (skipna and mask[i]):
345-
continue
346-
347-
if values[i]:
348-
out[lab] = 1
349-
350-
351-
@cython.boundscheck(False)
352-
@cython.wraparound(False)
353-
def group_all(ndarray[uint8_t] out,
354-
ndarray[int64_t] labels,
355-
ndarray[uint8_t] values,
356-
ndarray[uint8_t] mask,
357-
bint skipna):
358-
"""Aggregated boolean values to show if all group elements are truthful
359-
360-
Parameters
361-
----------
362-
out : array of values which this method will write its results to
363-
labels : array containing unique label for each group, with its ordering
364-
matching up to the corresponding record in `values`
326+
labels : array containing unique label for each group, with its
327+
ordering matching up to the corresponding record in `values`
365328
values : array containing the truth value of each element
366329
mask : array indicating whether a value is na or not
330+
val_test : str {'any', 'all'}
331+
String object dictating whether to use any or all truth testing
367332
skipna : boolean
368333
Flag to ignore nan values during truth testing
369334
@@ -374,23 +339,31 @@ def group_all(ndarray[uint8_t] out,
374339
"""
375340
cdef:
376341
Py_ssize_t i, N=len(labels)
377-
int64_t lab
378-
ndarray[int64_t] bool_mask
379-
ndarray[uint8_t] isna_mask
342+
int64_t lab, flag_val
343+
344+
if val_test == 'all':
345+
# Because the 'all' value of an empty iterable in Python is True we can
346+
# start with an array full of ones and set to zero when a False value
347+
# is encountered
348+
flag_val = 0
349+
elif val_test == 'any':
350+
# Because the 'any' value of an empty iterable in Python is False we
351+
# can start with an array full of zeros and set to one only if any
352+
# value encountered is True
353+
flag_val = 1
354+
else:
355+
raise ValueError("'bool_func' must be either 'any' or 'all'!")
380356

381-
# Because the 'all' value of an empty iterable in Python is True we can
382-
# start with an array full of ones and set to zero when a False value is
383-
# encountered
384-
out.fill(1)
357+
out.fill(1 - flag_val)
385358

386359
with nogil:
387360
for i in range(N):
388361
lab = labels[i]
389362
if lab < 0 or (skipna and mask[i]):
390363
continue
391364

392-
if not values[i]:
393-
out[lab] = 0
365+
if values[i] == flag_val:
366+
out[lab] = flag_val
394367

395368

396369
# generated from template

pandas/core/groupby.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1219,7 +1219,7 @@ class GroupBy(_GroupBy):
12191219
"""
12201220
_apply_whitelist = _common_apply_whitelist
12211221

1222-
def _bool_agg(self, how, skipna):
1222+
def _bool_agg(self, val_test, skipna):
12231223
"""Shared func to call any / all Cython GroupBy implementations"""
12241224

12251225
def objs_to_bool(vals):
@@ -1233,14 +1233,14 @@ def objs_to_bool(vals):
12331233
def result_to_bool(result):
12341234
return result.astype(np.bool, copy=False)
12351235

1236-
return self._get_cythonized_result(how, self.grouper,
1236+
return self._get_cythonized_result('group_any_all', self.grouper,
12371237
aggregate=True,
12381238
cython_dtype=np.uint8,
12391239
needs_values=True,
12401240
needs_mask=True,
12411241
pre_processing=objs_to_bool,
12421242
post_processing=result_to_bool,
1243-
skipna=skipna)
1243+
val_test=val_test, skipna=skipna)
12441244

12451245
@Substitution(name='groupby')
12461246
@Appender(_doc_template)
@@ -1252,7 +1252,7 @@ def any(self, skipna=True):
12521252
skipna : bool, default True
12531253
Flag to ignore nan values during truth testing
12541254
"""
1255-
return self._bool_agg('group_any', skipna)
1255+
return self._bool_agg('any', skipna)
12561256

12571257
@Substitution(name='groupby')
12581258
@Appender(_doc_template)
@@ -1264,7 +1264,7 @@ def all(self, skipna=True):
12641264
skipna : bool, default True
12651265
Flag to ignore nan values during truth testing
12661266
"""
1267-
return self._bool_agg('group_all', skipna)
1267+
return self._bool_agg('all', skipna)
12681268

12691269
@Substitution(name='groupby')
12701270
@Appender(_doc_template)

0 commit comments

Comments
 (0)