Skip to content

Commit d600dfe

Browse files
committed
Allow index.map() to accept series and dictionary inputs in addition to functional inputs
1 parent 19fc8da commit d600dfe

File tree

3 files changed

+62
-3
lines changed

3 files changed

+62
-3
lines changed

pandas/core/indexes/base.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from pandas.compat.numpy import function as nv
1313
from pandas import compat
1414

15-
1615
from pandas.core.dtypes.generic import ABCSeries, ABCMultiIndex, ABCPeriodIndex
1716
from pandas.core.dtypes.missing import isnull, array_equivalent
1817
from pandas.core.dtypes.common import (
@@ -2711,7 +2710,7 @@ def map(self, mapper):
27112710
27122711
Parameters
27132712
----------
2714-
mapper : callable
2713+
mapper : function, dict, or Series
27152714
Function to be applied.
27162715
27172716
Returns
@@ -2723,7 +2722,15 @@ def map(self, mapper):
27232722
27242723
"""
27252724
from .multi import MultiIndex
2726-
mapped_values = self._arrmap(self.values, mapper)
2725+
2726+
if isinstance(mapper, ABCSeries):
2727+
indexer = mapper.index.get_indexer(self._values)
2728+
mapped_values = algos.take_1d(mapper.values, indexer)
2729+
else:
2730+
if isinstance(mapper, dict):
2731+
mapper = mapper.get
2732+
mapped_values = self._arrmap(self._values, mapper)
2733+
27272734
attributes = self._get_attributes_dict()
27282735
if mapped_values.size and isinstance(mapped_values[0], tuple):
27292736
return MultiIndex.from_tuples(mapped_values,

pandas/tests/indexes/test_base.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -822,6 +822,52 @@ def test_map_tseries_indices_return_index(self):
822822
exp = Index(range(24), name='hourly')
823823
tm.assert_index_equal(exp, date_index.map(lambda x: x.hour))
824824

825+
def test_map_with_series_all_indices(self):
826+
expected = Index(['foo', 'bar', 'baz'])
827+
mapper = Series(expected.values, index=[0, 1, 2])
828+
self.assert_index_equal(tm.makeIntIndex(3).map(mapper), expected)
829+
830+
# GH 12766
831+
# special = []
832+
special = ['catIndex']
833+
834+
for name in special:
835+
orig_values = ['a', 'B', 1, 'a']
836+
new_values = ['one', 2, 3.0, 'one']
837+
cur_index = CategoricalIndex(orig_values, name='XXX')
838+
mapper = pd.Series(new_values[:-1], index=orig_values[:-1])
839+
expected = CategoricalIndex(new_values, name='XXX')
840+
output = cur_index.map(mapper)
841+
self.assert_numpy_array_equal(expected.values.get_values(), output.values.get_values())
842+
self.assert_equal(expected.name, output.name)
843+
844+
845+
for name in list(set(self.indices.keys()) - set(special)):
846+
cur_index = self.indices[name]
847+
expected = Index(np.arange(len(cur_index), 0, -1))
848+
mapper = pd.Series(expected.values, index=cur_index)
849+
print(name)
850+
output = cur_index.map(mapper)
851+
self.assert_index_equal(expected, cur_index.map(mapper))
852+
853+
def test_map_with_categorical_series(self):
854+
# GH 12756
855+
a = Index([1, 2, 3, 4])
856+
b = Series(["even", "odd", "even", "odd"], dtype="category")
857+
c = Series(["even", "odd", "even", "odd"])
858+
859+
exp = CategoricalIndex(["odd", "even", "odd", np.nan])
860+
self.assert_index_equal(a.map(b), exp)
861+
exp = Index(["odd", "even", "odd", np.nan])
862+
self.assert_index_equal(a.map(c), exp)
863+
864+
def test_map_with_series_missing_values(self):
865+
# GH 12756
866+
expected = Index([2., np.nan, 'foo'])
867+
mapper = Series(['foo', 2., 'baz'], index=[0, 2, -1])
868+
output = Index([2, 1, 0]).map(mapper)
869+
self.assert_index_equal(output, expected)
870+
825871
def test_append_multiple(self):
826872
index = Index(['a', 'b', 'c', 'd', 'e', 'f'])
827873

pandas/tests/indexes/test_category.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,12 @@ def f(x):
234234
ordered=False)
235235
tm.assert_index_equal(result, exp)
236236

237+
result = ci.map(pd.Series([10, 20, 30], index=['A', 'B', 'C']))
238+
tm.assert_index_equal(result, exp)
239+
240+
result = ci.map({'A': 10, 'B': 20, 'C': 30})
241+
tm.assert_index_equal(result, exp)
242+
237243
def test_where(self):
238244
i = self.create_index()
239245
result = i.where(notnull(i))

0 commit comments

Comments
 (0)