Skip to content

Commit b028827

Browse files
committed
Refactor with maybe_convert_for_categorical
1 parent b80cff8 commit b028827

File tree

5 files changed

+96
-35
lines changed

5 files changed

+96
-35
lines changed

doc/source/io.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,8 @@ Individual columns can be parsed as a ``Categorical`` using a dict specification
469469
470470
pd.read_csv(StringIO(data), dtype={'col1': 'category'}).dtypes
471471
472+
.. versionadded:: 0.21.0
473+
472474
Specifying ``dtype='cateogry'`` will result in an unordered ``Categorical``
473475
whose ``categories`` are the unique values observed in the data. For more
474476
control on the categories and order, create a

pandas/_libs/parsers.pyx

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ from pandas.core.dtypes.common import (
4848
from pandas.core.categorical import Categorical, _recode_for_categories
4949
from pandas.core.algorithms import take_1d
5050
from pandas.core.dtypes.concat import union_categoricals
51-
from pandas import Index, to_numeric, to_datetime, to_timedelta
51+
from pandas.core.dtypes.cast import maybe_convert_for_categorical
52+
from pandas import Index
5253

5354
import pandas.io.common as com
5455

@@ -1274,19 +1275,7 @@ cdef class TextReader:
12741275
na_hashset, self.c_encoding)
12751276
cats = Index(cats)
12761277

1277-
# Determine if we should convert inferred string
1278-
# categories to a specialized type
1279-
if (isinstance(dtype, CategoricalDtype) and
1280-
dtype.categories is not None):
1281-
if dtype.categories.is_numeric():
1282-
# is ignore correct?
1283-
cats = to_numeric(cats, errors='ignore')
1284-
elif dtype.categories.is_all_dates:
1285-
# is ignore correct?
1286-
if is_datetime64_dtype(dtype.categories):
1287-
cats = to_datetime(cats, errors='ignore')
1288-
else:
1289-
cats = to_timedelta(cats, errors='ignore')
1278+
cats = maybe_convert_for_categorical(cats, dtype)
12901279

12911280
if (isinstance(dtype, CategoricalDtype) and
12921281
dtype.categories is not None):
@@ -1298,8 +1287,7 @@ cdef class TextReader:
12981287
# sort categories and recode if necessary
12991288
unsorted = cats.copy()
13001289
categories = cats.sort_values()
1301-
indexer = categories.get_indexer(unsorted)
1302-
codes = take_1d(indexer, codes, fill_value=-1)
1290+
codes = _recode_for_categories(codes, unsorted, categories)
13031291
ordered = False
13041292
else:
13051293
categories = cats

pandas/core/dtypes/cast.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
_ensure_int32, _ensure_int64,
2525
_NS_DTYPE, _TD_DTYPE, _INT64_DTYPE,
2626
_POSSIBLY_CAST_DTYPES)
27-
from .dtypes import ExtensionDtype, DatetimeTZDtype, PeriodDtype
27+
from .dtypes import (ExtensionDtype, DatetimeTZDtype, PeriodDtype,
28+
CategoricalDtype)
2829
from .generic import (ABCDatetimeIndex, ABCPeriodIndex,
2930
ABCSeries)
3031
from .missing import isna, notna
@@ -604,6 +605,41 @@ def conv(r, dtype):
604605
return [conv(r, dtype) for r, dtype in zip(result, dtypes)]
605606

606607

608+
def maybe_convert_for_categorical(categories, dtype):
609+
"""Convert ``categories`` depending on ``dtype``.
610+
611+
Converts to numeric, datetime, or timedelta types, when ``dtype`` is
612+
a CategoricalDtype with known, non-object categories.
613+
614+
Parameters
615+
----------
616+
categories : array-like
617+
type : CategoricalDtype
618+
619+
Returns
620+
-------
621+
new_categories : array or Index
622+
623+
Examples
624+
--------
625+
>>> maybe_convert_for_categorical(['1', '2'], CategoricalDtype([1, 2]))
626+
array([ 1, 2])
627+
>>> maybe_convert_for_categorical([1, 'a'], CategoricalDtype([1, 2]))
628+
array([ 1., nan])
629+
"""
630+
if isinstance(dtype, CategoricalDtype) and dtype.categories is not None:
631+
from pandas import to_numeric, to_datetime, to_timedelta
632+
633+
if dtype.categories.is_numeric():
634+
categories = to_numeric(categories, errors='coerce')
635+
elif is_datetime64_dtype(dtype.categories):
636+
categories = to_datetime(categories, errors='coerce')
637+
elif is_timedelta64_dtype(dtype.categories):
638+
categories = to_timedelta(categories, errors='coerce')
639+
640+
return categories
641+
642+
607643
def astype_nansafe(arr, dtype, copy=True):
608644
""" return a view if copy is False, but
609645
need to be very careful as the result shape could change! """

pandas/io/parsers.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,19 @@
1212

1313
import numpy as np
1414

15-
from pandas import compat, to_numeric, to_timedelta
15+
from pandas import compat
1616
from pandas.compat import (range, lrange, PY3, StringIO, lzip,
1717
zip, string_types, map, u)
1818
from pandas.core.dtypes.common import (
1919
is_integer, _ensure_object,
2020
is_list_like, is_integer_dtype,
2121
is_float, is_dtype_equal,
2222
is_object_dtype, is_string_dtype,
23-
is_scalar, is_categorical_dtype,
24-
is_datetime64_dtype, is_timedelta64_dtype)
23+
is_scalar, is_categorical_dtype)
2524
from pandas.core.dtypes.dtypes import CategoricalDtype
2625
from pandas.core.dtypes.missing import isna
27-
from pandas.core.dtypes.cast import astype_nansafe
26+
from pandas.core.dtypes.cast import (astype_nansafe,
27+
maybe_convert_for_categorical)
2828
from pandas.core.index import (Index, MultiIndex, RangeIndex,
2929
_ensure_index_from_sequences)
3030
from pandas.core.series import Series
@@ -1609,21 +1609,16 @@ def _cast_types(self, values, cast_type, column):
16091609
# as strings
16101610
known_cats = (isinstance(cast_type, CategoricalDtype) and
16111611
cast_type.categories is not None)
1612-
str_values = is_object_dtype(values)
1613-
1614-
if known_cats and str_values:
1615-
if cast_type.categories.is_numeric():
1616-
values = to_numeric(values, errors='ignore')
1617-
elif is_datetime64_dtype(cast_type.categories):
1618-
values = tools.to_datetime(values, errors='ignore')
1619-
elif is_timedelta64_dtype(cast_type.categories):
1620-
values = to_timedelta(values, errors='ignore')
1621-
values = Categorical(values, categories=cast_type.categories,
1622-
ordered=cast_type.ordered)
1612+
1613+
categories = ordered = None
1614+
if known_cats:
1615+
values = maybe_convert_for_categorical(values, cast_type)
1616+
categories = cast_type.categories
1617+
ordered = cast_type.ordered
16231618
elif not is_object_dtype(values):
16241619
values = astype_nansafe(values, str)
1625-
else:
1626-
values = Categorical(values)
1620+
values = Categorical(values, categories=categories,
1621+
ordered=ordered)
16271622
else:
16281623
try:
16291624
values = astype_nansafe(values, cast_type, copy=True)

pandas/tests/dtypes/test_cast.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from pandas.core.dtypes.cast import (
1717
maybe_downcast_to_dtype,
1818
maybe_convert_objects,
19+
maybe_convert_for_categorical,
1920
cast_scalar_to_array,
2021
infer_dtype_from_scalar,
2122
infer_dtype_from_array,
@@ -25,7 +26,8 @@
2526
from pandas.core.dtypes.dtypes import (
2627
CategoricalDtype,
2728
DatetimeTZDtype,
28-
PeriodDtype)
29+
PeriodDtype,
30+
CategoricalDtype)
2931
from pandas.core.dtypes.common import (
3032
is_dtype_equal)
3133
from pandas.util import testing as tm
@@ -299,6 +301,44 @@ def test_maybe_infer_to_datetimelike(self):
299301
[NaT, 'b', 1]]))
300302
assert result.size == 6
301303

304+
def test_maybe_convert_for_categorical_noop(self):
305+
expected = ['1', '2']
306+
result = maybe_convert_for_categorical(expected, None)
307+
assert result == expected
308+
309+
result = maybe_convert_for_categorical(expected, CategoricalDtype())
310+
assert result == expected
311+
312+
result = maybe_convert_for_categorical(expected, 'category')
313+
assert result == expected
314+
315+
@pytest.mark.parametrize('categories, dtype, expected', [
316+
(['1', '2'], [1, 2, 3], np.array([1, 2])),
317+
(['1', '2', 'a'], [1, 2, 3], np.array([1, 2, np.nan])),
318+
])
319+
def test_maybe_convert_for_categorical(self, categories, dtype, expected):
320+
dtype = CategoricalDtype(dtype)
321+
result = maybe_convert_for_categorical(categories, dtype)
322+
tm.assert_numpy_array_equal(result, expected)
323+
324+
@pytest.mark.parametrize('categories, dtype, expected', [
325+
(['2016', '2017'], pd.to_datetime(['2016', '2017']),
326+
pd.to_datetime(['2016', '2017'])),
327+
(['2016', '2017', 'bad'], pd.to_datetime(['2016', '2017']),
328+
pd.to_datetime(['2016', '2017', 'NaT'])),
329+
330+
(['1H', '2H'], pd.to_timedelta(['1H', '2H']),
331+
pd.to_timedelta(['1H', '2H'])),
332+
(['1H', '2H', 'bad'], pd.to_timedelta(['1H', '2H']),
333+
pd.to_timedelta(['1H', '2H', 'NaT'])),
334+
335+
])
336+
def test_maybe_convert_for_categorical_dates(self, categories, dtype,
337+
expected):
338+
dtype = CategoricalDtype(dtype)
339+
result = maybe_convert_for_categorical(categories, dtype)
340+
tm.assert_index_equal(result, expected)
341+
302342

303343
class TestConvert(object):
304344

0 commit comments

Comments
 (0)