Skip to content

Commit 2a37ffb

Browse files
committed
Allow index.map() to accept series and dictionary inputs in addition to functional inputs
1 parent c65a0f5 commit 2a37ffb

File tree

3 files changed

+62
-3
lines changed

3 files changed

+62
-3
lines changed

pandas/core/indexes/base.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from pandas.compat.numpy import function as nv
1414
from pandas import compat
1515

16-
1716
from pandas.core.dtypes.generic import (
1817
ABCSeries,
1918
ABCMultiIndex,
@@ -2864,7 +2863,7 @@ def map(self, mapper):
28642863
28652864
Parameters
28662865
----------
2867-
mapper : callable
2866+
mapper : function, dict, or Series
28682867
Function to be applied.
28692868
28702869
Returns
@@ -2876,7 +2875,15 @@ def map(self, mapper):
28762875
28772876
"""
28782877
from .multi import MultiIndex
2879-
mapped_values = self._arrmap(self.values, mapper)
2878+
2879+
if isinstance(mapper, ABCSeries):
2880+
indexer = mapper.index.get_indexer(self._values)
2881+
mapped_values = algos.take_1d(mapper.values, indexer)
2882+
else:
2883+
if isinstance(mapper, dict):
2884+
mapper = mapper.get
2885+
mapped_values = self._arrmap(self._values, mapper)
2886+
28802887
attributes = self._get_attributes_dict()
28812888
if mapped_values.size and isinstance(mapped_values[0], tuple):
28822889
return MultiIndex.from_tuples(mapped_values,

pandas/tests/indexes/test_base.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,52 @@ def test_map_tseries_indices_return_index(self):
829829
exp = Index(range(24), name='hourly')
830830
tm.assert_index_equal(exp, date_index.map(lambda x: x.hour))
831831

832+
def test_map_with_series_all_indices(self):
833+
expected = Index(['foo', 'bar', 'baz'])
834+
mapper = Series(expected.values, index=[0, 1, 2])
835+
self.assert_index_equal(tm.makeIntIndex(3).map(mapper), expected)
836+
837+
# GH 12766
838+
# special = []
839+
special = ['catIndex']
840+
841+
for name in special:
842+
orig_values = ['a', 'B', 1, 'a']
843+
new_values = ['one', 2, 3.0, 'one']
844+
cur_index = CategoricalIndex(orig_values, name='XXX')
845+
mapper = pd.Series(new_values[:-1], index=orig_values[:-1])
846+
expected = CategoricalIndex(new_values, name='XXX')
847+
output = cur_index.map(mapper)
848+
self.assert_numpy_array_equal(expected.values.get_values(), output.values.get_values())
849+
self.assert_equal(expected.name, output.name)
850+
851+
852+
for name in list(set(self.indices.keys()) - set(special)):
853+
cur_index = self.indices[name]
854+
expected = Index(np.arange(len(cur_index), 0, -1))
855+
mapper = pd.Series(expected.values, index=cur_index)
856+
print(name)
857+
output = cur_index.map(mapper)
858+
self.assert_index_equal(expected, cur_index.map(mapper))
859+
860+
def test_map_with_categorical_series(self):
861+
# GH 12756
862+
a = Index([1, 2, 3, 4])
863+
b = Series(["even", "odd", "even", "odd"], dtype="category")
864+
c = Series(["even", "odd", "even", "odd"])
865+
866+
exp = CategoricalIndex(["odd", "even", "odd", np.nan])
867+
self.assert_index_equal(a.map(b), exp)
868+
exp = Index(["odd", "even", "odd", np.nan])
869+
self.assert_index_equal(a.map(c), exp)
870+
871+
def test_map_with_series_missing_values(self):
872+
# GH 12756
873+
expected = Index([2., np.nan, 'foo'])
874+
mapper = Series(['foo', 2., 'baz'], index=[0, 2, -1])
875+
output = Index([2, 1, 0]).map(mapper)
876+
self.assert_index_equal(output, expected)
877+
832878
def test_append_multiple(self):
833879
index = Index(['a', 'b', 'c', 'd', 'e', 'f'])
834880

pandas/tests/indexes/test_category.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,12 @@ def f(x):
244244
ordered=False)
245245
tm.assert_index_equal(result, exp)
246246

247+
result = ci.map(pd.Series([10, 20, 30], index=['A', 'B', 'C']))
248+
tm.assert_index_equal(result, exp)
249+
250+
result = ci.map({'A': 10, 'B': 20, 'C': 30})
251+
tm.assert_index_equal(result, exp)
252+
247253
def test_where(self):
248254
i = self.create_index()
249255
result = i.where(notna(i))

0 commit comments

Comments
 (0)