Skip to content

Commit f1a7f2a

Browse files
committed
ENH: enhance apply() on Panel for arbitrary functions (rather than just ufuncs) (GH1148)
ENH: add property dtypes on Panel objects
1 parent ebd1db4 commit f1a7f2a

File tree

4 files changed

+204
-36
lines changed

4 files changed

+204
-36
lines changed

doc/source/release.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ Improvements to existing features
7373
- df.info() view now displays dtype info per column (:issue:`5682`)
7474
- perf improvements in DataFrame ``count/dropna`` for ``axis=1``
7575
- Series.str.contains now has a ``regex=False`` keyword which can be faster for plain (non-regex) string patterns. (:issue:`5879`)
76+
- support ``dtypes`` on ``Panel``
77+
- extend ``Panel.apply`` to allow arbitrary functions (rather than only ufuncs) (:issue:`1148`)
7678

7779
.. _release.bug_fixes-0.13.1:
7880

pandas/core/panel.py

Lines changed: 139 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
from pandas.core.internals import (BlockManager,
1818
create_block_manager_from_arrays,
1919
create_block_manager_from_blocks)
20+
from pandas.core.series import Series
2021
from pandas.core.frame import DataFrame
2122
from pandas.core.generic import NDFrame, _shared_docs
23+
from pandas.tools.util import cartesian_product
2224
from pandas import compat
2325
from pandas.util.decorators import deprecate, Appender, Substitution
2426
import pandas.core.common as com
@@ -333,26 +335,34 @@ def axis_pretty(a):
333335
[class_name, dims] + [axis_pretty(a) for a in self._AXIS_ORDERS])
334336
return output
335337

336-
def _get_plane_axes(self, axis):
338+
def _get_plane_axes_index(self, axis):
337339
"""
338-
Get my plane axes: these are already
340+
Get my plane axes indexes: these are already
339341
(as compared with higher level planes),
340-
as we are returning a DataFrame axes
342+
as we are returning a DataFrame axes indexes
341343
"""
342-
axis = self._get_axis_name(axis)
344+
axis_name = self._get_axis_name(axis)
343345

344-
if axis == 'major_axis':
345-
index = self.minor_axis
346-
columns = self.items
347-
if axis == 'minor_axis':
348-
index = self.major_axis
349-
columns = self.items
350-
elif axis == 'items':
351-
index = self.major_axis
352-
columns = self.minor_axis
346+
if axis_name == 'major_axis':
347+
index = 'minor_axis'
348+
columns = 'items'
349+
if axis_name == 'minor_axis':
350+
index = 'major_axis'
351+
columns = 'items'
352+
elif axis_name == 'items':
353+
index = 'major_axis'
354+
columns = 'minor_axis'
353355

354356
return index, columns
355357

358+
def _get_plane_axes(self, axis):
359+
"""
360+
Get my plane axes indexes: these are already
361+
(as compared with higher level planes),
362+
as we are returning a DataFrame axes
363+
"""
364+
return [ self._get_axis(axi) for axi in self._get_plane_axes_index(axis) ]
365+
356366
fromDict = from_dict
357367

358368
def to_sparse(self, fill_value=None, kind='block'):
@@ -431,6 +441,10 @@ def as_matrix(self):
431441
self._consolidate_inplace()
432442
return self._data.as_matrix()
433443

444+
@property
445+
def dtypes(self):
446+
return self.apply(lambda x: x.dtype, axis='items')
447+
434448
#----------------------------------------------------------------------
435449
# Getting and setting elements
436450

@@ -827,25 +841,104 @@ def to_frame(self, filter_observations=True):
827841
to_long = deprecate('to_long', to_frame)
828842
toLong = deprecate('toLong', to_frame)
829843

830-
def apply(self, func, axis='major'):
844+
def apply(self, func, axis='major', args=(), **kwargs):
831845
"""
832-
Apply
846+
Applies function along input axis of the Panel
833847
834848
Parameters
835849
----------
836-
func : numpy function
837-
Signature should match numpy.{sum, mean, var, std} etc.
850+
func : function
851+
Function to apply to each combination of 'other' axes
852+
e.g. if axis = 'items', then the combination of major_axis/minor_axis
853+
will be passed a Series
838854
axis : {'major', 'minor', 'items'}
839-
fill_value : boolean, default True
840-
Replace NaN values with specified first
855+
args : tuple
856+
Positional arguments to pass to function in addition to the
857+
array/series
858+
Additional keyword arguments will be passed as keywords to the function
859+
860+
Examples
861+
--------
862+
>>> p.apply(numpy.sqrt) # returns a Panel
863+
>>> p.apply(lambda x: x.sum(), axis=0) # equiv to p.sum(0)
864+
>>> p.apply(lambda x: x.sum(), axis=1) # equiv to p.sum(1)
865+
>>> p.apply(lambda x: x.sum(), axis=2) # equiv to p.sum(2)
841866
842867
Returns
843868
-------
844-
result : DataFrame or Panel
869+
result : Pandas Object
845870
"""
846-
i = self._get_axis_number(axis)
847-
result = np.apply_along_axis(func, i, self.values)
848-
return self._wrap_result(result, axis=axis)
871+
axis = self._get_axis_number(axis)
872+
axis_name = self._get_axis_name(axis)
873+
ax = self._get_axis(axis)
874+
values = self.values
875+
ndim = self.ndim
876+
877+
if args or kwargs and not isinstance(func, np.ufunc):
878+
f = lambda x: func(x, *args, **kwargs)
879+
else:
880+
f = func
881+
882+
# try ufunc like
883+
if isinstance(f, np.ufunc):
884+
try:
885+
result = np.apply_along_axis(func, axis, values)
886+
return self._wrap_result(result, axis=axis)
887+
except (AttributeError):
888+
pass
889+
890+
# iter thru the axes
891+
slice_axis = self._get_axis(axis)
892+
slice_indexer = [0]*(ndim-1)
893+
indexer = np.zeros(ndim, 'O')
894+
indlist = list(range(ndim))
895+
indlist.remove(axis)
896+
indexer[axis] = slice(None, None)
897+
indexer.put(indlist, slice_indexer)
898+
planes = [ self._get_axis(axi) for axi in indlist ]
899+
shape = np.array(self.shape).take(indlist)
900+
901+
# all the iteration points
902+
points = cartesian_product(planes)
903+
904+
results = []
905+
for i in xrange(np.prod(shape)):
906+
907+
# construct the object
908+
pts = tuple([ p[i] for p in points ])
909+
indexer.put(indlist, slice_indexer)
910+
911+
obj = Series(values[tuple(indexer)],index=slice_axis,name=pts)
912+
result = func(obj, *args, **kwargs)
913+
914+
results.append(result)
915+
916+
# increment the indexer
917+
slice_indexer[-1] += 1
918+
n = -1
919+
while (slice_indexer[n] >= shape[n]) and (n > (1-ndim)):
920+
slice_indexer[n-1] += 1
921+
slice_indexer[n] = 0
922+
n -= 1
923+
924+
# empty object
925+
if not len(results):
926+
return self._constructor(**self._construct_axes_dict())
927+
928+
# same ndim as current
929+
if isinstance(results[0],Series):
930+
arr = np.vstack([ r.values for r in results ])
931+
arr = arr.T.reshape(tuple([len(slice_axis)] + list(shape)))
932+
tranp = np.array([axis]+indlist).argsort()
933+
arr = arr.transpose(tuple(list(tranp)))
934+
return self._constructor(arr,**self._construct_axes_dict())
935+
936+
# ndim-1 shape
937+
results = np.array(results).reshape(shape)
938+
if results.ndim == 2 and axis_name != self._info_axis_name:
939+
results = results.T
940+
planes = planes[::-1]
941+
return self._construct_return_type(results,planes)
849942

850943
def _reduce(self, op, axis=0, skipna=True, numeric_only=None,
851944
filter_type=None, **kwds):
@@ -863,21 +956,41 @@ def _reduce(self, op, axis=0, skipna=True, numeric_only=None,
863956

864957
def _construct_return_type(self, result, axes=None, **kwargs):
865958
""" return the type for the ndim of the result """
866-
ndim = result.ndim
867-
if self.ndim == ndim:
959+
ndim = getattr(result,'ndim',None)
960+
961+
# need to assume they are the same
962+
if ndim is None:
963+
if isinstance(result,dict):
964+
ndim = getattr(result.values()[0],'ndim',None)
965+
966+
# a scalar result
967+
if ndim is None:
968+
ndim = 0
969+
970+
# have a dict, so top-level is +1 dim
971+
else:
972+
ndim += 1
973+
974+
# scalar
975+
if ndim == 0:
976+
return Series(result)
977+
978+
# same as self
979+
elif self.ndim == ndim:
868980
""" return the construction dictionary for these axes """
869981
if axes is None:
870982
return self._constructor(result)
871983
return self._constructor(result, **self._construct_axes_dict())
872984

985+
# sliced
873986
elif self.ndim == ndim + 1:
874987
if axes is None:
875988
return self._constructor_sliced(result)
876989
return self._constructor_sliced(
877990
result, **self._extract_axes_for_slice(self, axes))
878991

879992
raise PandasError('invalid _construct_return_type [self->%s] '
880-
'[result->%s]' % (self.ndim, result.ndim))
993+
'[result->%s]' % (self, result))
881994

882995
def _wrap_result(self, result, axis):
883996
axis = self._get_axis_name(axis)

pandas/core/panelnd.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,10 @@ def __init__(self, *args, **kwargs):
5656
self._init_data(*args, **kwargs)
5757
klass.__init__ = __init__
5858

59-
def _get_plane_axes(self, axis):
59+
def _get_plane_axes_index(self, axis):
60+
""" return the sliced index for this object """
6061

61-
axis = self._get_axis_name(axis)
62+
axis_name = self._get_axis_name(axis)
6263
index = self._AXIS_ORDERS.index(axis)
6364

6465
planes = []
@@ -67,8 +68,8 @@ def _get_plane_axes(self, axis):
6768
if index != self._AXIS_LEN:
6869
planes.extend(self._AXIS_ORDERS[index + 1:])
6970

70-
return [getattr(self, p) for p in planes]
71-
klass._get_plane_axes = _get_plane_axes
71+
return planes
72+
klass._get_plane_axes_index = _get_plane_axes_index
7273

7374
def _combine(self, other, func, axis=0):
7475
if isinstance(other, klass):

pandas/tests/test_panel.py

Lines changed: 58 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1061,6 +1061,64 @@ def test_convert_objects(self):
10611061
result = p.convert_objects(convert_numeric='force')
10621062
assert_panel_equal(result, expected)
10631063

1064+
def test_dtypes(self):
1065+
1066+
result = self.panel.dtypes
1067+
expected = DataFrame(np.dtype('float64'),index=self.panel.major_axis,columns=self.panel.minor_axis)
1068+
assert_frame_equal(result, expected)
1069+
1070+
def test_apply(self):
1071+
# GH1148
1072+
1073+
from pandas import Series,DataFrame
1074+
1075+
# ufunc
1076+
applied = self.panel.apply(np.sqrt)
1077+
self.assert_(assert_almost_equal(applied.values,
1078+
np.sqrt(self.panel.values)))
1079+
1080+
# ufunc same shape
1081+
result = self.panel.apply(lambda x: x*2, axis='items')
1082+
expected = self.panel*2
1083+
assert_panel_equal(result, expected)
1084+
result = self.panel.apply(lambda x: x*2, axis='major_axis')
1085+
expected = self.panel*2
1086+
assert_panel_equal(result, expected)
1087+
result = self.panel.apply(lambda x: x*2, axis='minor_axis')
1088+
expected = self.panel*2
1089+
assert_panel_equal(result, expected)
1090+
1091+
# reduction to DataFrame
1092+
result = self.panel.apply(lambda x: x.dtype, axis='items')
1093+
expected = DataFrame(np.dtype('float64'),index=self.panel.major_axis,columns=self.panel.minor_axis)
1094+
assert_frame_equal(result,expected)
1095+
result = self.panel.apply(lambda x: x.dtype, axis='major_axis')
1096+
expected = DataFrame(np.dtype('float64'),index=self.panel.minor_axis,columns=self.panel.items)
1097+
assert_frame_equal(result,expected)
1098+
result = self.panel.apply(lambda x: x.dtype, axis='minor_axis')
1099+
expected = DataFrame(np.dtype('float64'),index=self.panel.major_axis,columns=self.panel.items)
1100+
assert_frame_equal(result,expected)
1101+
1102+
# reductions via other dims
1103+
expected = self.panel.sum(0)
1104+
result = self.panel.apply(lambda x: x.sum(), axis='items')
1105+
assert_frame_equal(result,expected)
1106+
expected = self.panel.sum(1)
1107+
result = self.panel.apply(lambda x: x.sum(), axis='major_axis')
1108+
assert_frame_equal(result,expected)
1109+
expected = self.panel.sum(2)
1110+
result = self.panel.apply(lambda x: x.sum(), axis='minor_axis')
1111+
assert_frame_equal(result,expected)
1112+
1113+
# pass args
1114+
result = self.panel.apply(lambda x, y: x.sum() + y, axis='items', args=[5])
1115+
expected = self.panel.sum(0) + 5
1116+
assert_frame_equal(result,expected)
1117+
1118+
result = self.panel.apply(lambda x, y: x.sum() + y, axis='items', y=5)
1119+
expected = self.panel.sum(0) + 5
1120+
assert_frame_equal(result,expected)
1121+
10641122
def test_reindex(self):
10651123
ref = self.panel['ItemB']
10661124

@@ -1989,12 +2047,6 @@ def test_get_dummies(self):
19892047
dummies = get_dummies(self.panel['Label'])
19902048
self.assert_(np.array_equal(dummies.values, minor_dummies.values))
19912049

1992-
def test_apply(self):
1993-
# ufunc
1994-
applied = self.panel.apply(np.sqrt)
1995-
self.assert_(assert_almost_equal(applied.values,
1996-
np.sqrt(self.panel.values)))
1997-
19982050
def test_mean(self):
19992051
means = self.panel.mean(level='minor')
20002052

0 commit comments

Comments
 (0)