-
-
Notifications
You must be signed in to change notification settings - Fork 18.6k
REF: do extract_array earlier in series arith/comparison ops #28066
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
33487af
17f103c
edd49db
8859cf7
112f0f0
2cb1590
dec3e51
a0b4ffa
ef96214
0c86bec
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,10 +34,11 @@ | |
ABCIndexClass, | ||
ABCSeries, | ||
ABCSparseSeries, | ||
ABCTimedeltaArray, | ||
ABCTimedeltaIndex, | ||
) | ||
from pandas.core.dtypes.missing import isna, notna | ||
|
||
import pandas as pd | ||
from pandas._typing import ArrayLike | ||
from pandas.core.construction import array, extract_array | ||
from pandas.core.ops.array_ops import comp_method_OBJECT_ARRAY, define_na_arithmetic_op | ||
|
@@ -148,6 +149,8 @@ def maybe_upcast_for_op(obj, shape: Tuple[int, ...]): | |
Be careful to call this *after* determining the `name` attribute to be | ||
attached to the result of the arithmetic operation. | ||
""" | ||
from pandas.core.arrays import TimedeltaArray | ||
|
||
if type(obj) is datetime.timedelta: | ||
# GH#22390 cast up to Timedelta to rely on Timedelta | ||
# implementation; otherwise operation against numeric-dtype | ||
|
@@ -157,12 +160,10 @@ def maybe_upcast_for_op(obj, shape: Tuple[int, ...]): | |
if isna(obj): | ||
# wrapping timedelta64("NaT") in Timedelta returns NaT, | ||
# which would incorrectly be treated as a datetime-NaT, so | ||
# we broadcast and wrap in a Series | ||
# we broadcast and wrap in a TimedeltaArray | ||
obj = obj.astype("timedelta64[ns]") | ||
right = np.broadcast_to(obj, shape) | ||
|
||
# Note: we use Series instead of TimedeltaIndex to avoid having | ||
# to worry about catching NullFrequencyError. | ||
return pd.Series(right) | ||
return TimedeltaArray(right) | ||
|
||
# In particular non-nanosecond timedelta64 needs to be cast to | ||
# nanoseconds, or else we get undesired behavior like | ||
|
@@ -173,7 +174,7 @@ def maybe_upcast_for_op(obj, shape: Tuple[int, ...]): | |
# GH#22390 Unfortunately we need to special-case right-hand | ||
# timedelta64 dtypes because numpy casts integer dtypes to | ||
# timedelta64 when operating with timedelta64 | ||
return pd.TimedeltaIndex(obj) | ||
return TimedeltaArray._from_sequence(obj) | ||
return obj | ||
|
||
|
||
|
@@ -520,13 +521,29 @@ def column_op(a, b): | |
return result | ||
|
||
|
||
def dispatch_to_extension_op(op, left, right): | ||
def dispatch_to_extension_op(op, left, right, keep_null_freq: bool = False): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @simonjayhawkins do we have a way of typing There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you type left here (EA / np.ndarray)? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
I'm not aware of being able to exclude types. If a particular type raises (in this case TypeError?), then maybe we could use overloads with a return type of NoReturn https://mypy.readthedocs.io/en/latest/more_types.html#the-noreturn-type (new in version 3.5.4). We could maybe use the following pattern to allow checking with older Python...
|
||
""" | ||
Assume that left or right is a Series backed by an ExtensionArray, | ||
apply the operator defined by op. | ||
|
||
Parameters | ||
---------- | ||
op : binary operator | ||
left : ExtensionArray or np.ndarray | ||
right : object | ||
keep_null_freq : bool, default False | ||
Whether to re-raise a NullFrequencyError unchanged, as opposed to | ||
catching and raising TypeError. | ||
|
||
Returns | ||
------- | ||
ExtensionArray or np.ndarray | ||
2-tuple of these if op is divmod or rdivmod | ||
""" | ||
# NB: left and right should already be unboxed, so neither should be | ||
# a Series or Index. | ||
|
||
if left.dtype.kind in "mM": | ||
if left.dtype.kind in "mM" and isinstance(left, np.ndarray): | ||
# We need to cast datetime64 and timedelta64 ndarrays to | ||
# DatetimeArray/TimedeltaArray. But we avoid wrapping others in | ||
# PandasArray as that behaves poorly with e.g. IntegerArray. | ||
|
@@ -535,15 +552,13 @@ def dispatch_to_extension_op(op, left, right): | |
# The op calls will raise TypeError if the op is not defined | ||
# on the ExtensionArray | ||
|
||
# unbox Series and Index to arrays | ||
new_left = extract_array(left, extract_numpy=True) | ||
new_right = extract_array(right, extract_numpy=True) | ||
|
||
try: | ||
res_values = op(new_left, new_right) | ||
res_values = op(left, right) | ||
except NullFrequencyError: | ||
# DatetimeIndex and TimedeltaIndex with freq == None raise ValueError | ||
# on add/sub of integers (or int-like). We re-raise as a TypeError. | ||
if keep_null_freq: | ||
raise | ||
raise TypeError( | ||
"incompatible type for a datetime/timedelta " | ||
"operation [{name}]".format(name=op.__name__) | ||
|
@@ -615,25 +630,29 @@ def wrapper(left, right): | |
if isinstance(right, ABCDataFrame): | ||
return NotImplemented | ||
|
||
keep_null_freq = isinstance( | ||
right, | ||
(ABCDatetimeIndex, ABCDatetimeArray, ABCTimedeltaIndex, ABCTimedeltaArray), | ||
) | ||
|
||
left, right = _align_method_SERIES(left, right) | ||
res_name = get_op_result_name(left, right) | ||
right = maybe_upcast_for_op(right, left.shape) | ||
|
||
if should_extension_dispatch(left, right): | ||
result = dispatch_to_extension_op(op, left, right) | ||
lvalues = extract_array(left, extract_numpy=True) | ||
rvalues = extract_array(right, extract_numpy=True) | ||
|
||
elif is_timedelta64_dtype(right) or isinstance( | ||
right, (ABCDatetimeArray, ABCDatetimeIndex) | ||
): | ||
# We should only get here with td64 right with non-scalar values | ||
# for right upcast by maybe_upcast_for_op | ||
assert not isinstance(right, (np.timedelta64, np.ndarray)) | ||
result = op(left._values, right) | ||
rvalues = maybe_upcast_for_op(rvalues, lvalues.shape) | ||
|
||
else: | ||
lvalues = extract_array(left, extract_numpy=True) | ||
rvalues = extract_array(right, extract_numpy=True) | ||
if should_extension_dispatch(left, rvalues): | ||
TomAugspurger marked this conversation as resolved.
Show resolved
Hide resolved
|
||
result = dispatch_to_extension_op(op, lvalues, rvalues, keep_null_freq) | ||
|
||
elif is_timedelta64_dtype(rvalues) or isinstance(rvalues, ABCDatetimeArray): | ||
# We should only get here with td64 rvalues with non-scalar values | ||
# for rvalues upcast by maybe_upcast_for_op | ||
assert not isinstance(rvalues, (np.timedelta64, np.ndarray)) | ||
result = dispatch_to_extension_op(op, lvalues, rvalues, keep_null_freq) | ||
|
||
else: | ||
with np.errstate(all="ignore"): | ||
result = na_op(lvalues, rvalues) | ||
|
||
|
@@ -708,25 +727,25 @@ def wrapper(self, other, axis=None): | |
if len(self) != len(other): | ||
raise ValueError("Lengths must match to compare") | ||
|
||
if should_extension_dispatch(self, other): | ||
res_values = dispatch_to_extension_op(op, self, other) | ||
lvalues = extract_array(self, extract_numpy=True) | ||
rvalues = extract_array(other, extract_numpy=True) | ||
|
||
elif is_scalar(other) and isna(other): | ||
if should_extension_dispatch(lvalues, rvalues): | ||
res_values = dispatch_to_extension_op(op, lvalues, rvalues) | ||
|
||
elif is_scalar(rvalues) and isna(rvalues): | ||
# numpy does not like comparisons vs None | ||
if op is operator.ne: | ||
res_values = np.ones(len(self), dtype=bool) | ||
res_values = np.ones(len(lvalues), dtype=bool) | ||
else: | ||
res_values = np.zeros(len(self), dtype=bool) | ||
res_values = np.zeros(len(lvalues), dtype=bool) | ||
|
||
else: | ||
lvalues = extract_array(self, extract_numpy=True) | ||
rvalues = extract_array(other, extract_numpy=True) | ||
|
||
with np.errstate(all="ignore"): | ||
res_values = na_op(lvalues, rvalues) | ||
if is_scalar(res_values): | ||
raise TypeError( | ||
"Could not compare {typ} type with Series".format(typ=type(other)) | ||
"Could not compare {typ} type with Series".format(typ=type(rvalues)) | ||
) | ||
|
||
result = self._constructor(res_values, index=self.index) | ||
|
Uh oh!
There was an error while loading. Please reload this page.