Skip to content

REF: update decimal tests to TestExtension #54455

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 8, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 71 additions & 131 deletions pandas/tests/extension/decimal/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,35 +65,71 @@ def data_for_grouping():
return DecimalArray([b, b, na, na, a, a, b, c])


class TestDtype(base.BaseDtypeTests):
pass
class TestDecimalArray(base.ExtensionTests):
def _get_expected_exception(
self, op_name: str, obj, other
) -> type[Exception] | None:
return None

def _supports_reduction(self, obj, op_name: str) -> bool:
return True

class TestInterface(base.BaseInterfaceTests):
pass
def check_reduce(self, s, op_name, skipna):
if op_name == "count":
return super().check_reduce(s, op_name, skipna)
else:
result = getattr(s, op_name)(skipna=skipna)
expected = getattr(np.asarray(s), op_name)()
tm.assert_almost_equal(result, expected)

def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna, request):
if all_numeric_reductions in ["kurt", "skew", "sem", "median"]:
mark = pytest.mark.xfail(raises=NotImplementedError)
request.node.add_marker(mark)
super().test_reduce_series_numeric(data, all_numeric_reductions, skipna)

class TestConstructors(base.BaseConstructorsTests):
pass
def test_reduce_frame(self, data, all_numeric_reductions, skipna, request):
op_name = all_numeric_reductions
if op_name in ["skew", "median"]:
mark = pytest.mark.xfail(raises=NotImplementedError)
request.node.add_marker(mark)

return super().test_reduce_frame(data, all_numeric_reductions, skipna)

class TestReshaping(base.BaseReshapingTests):
pass
def test_compare_scalar(self, data, comparison_op):
ser = pd.Series(data)
self._compare_other(ser, data, comparison_op, 0.5)

def test_compare_array(self, data, comparison_op):
ser = pd.Series(data)

class TestGetitem(base.BaseGetitemTests):
def test_take_na_value_other_decimal(self):
arr = DecimalArray([decimal.Decimal("1.0"), decimal.Decimal("2.0")])
result = arr.take([0, -1], allow_fill=True, fill_value=decimal.Decimal("-1.0"))
expected = DecimalArray([decimal.Decimal("1.0"), decimal.Decimal("-1.0")])
tm.assert_extension_array_equal(result, expected)
alter = np.random.default_rng(2).choice([-1, 0, 1], len(data))
# Randomly double, halve or keep same value
other = pd.Series(data) * [decimal.Decimal(pow(2.0, i)) for i in alter]
self._compare_other(ser, data, comparison_op, other)

def test_arith_series_with_array(self, data, all_arithmetic_operators):
op_name = all_arithmetic_operators
ser = pd.Series(data)

context = decimal.getcontext()
divbyzerotrap = context.traps[decimal.DivisionByZero]
invalidoptrap = context.traps[decimal.InvalidOperation]
context.traps[decimal.DivisionByZero] = 0
context.traps[decimal.InvalidOperation] = 0

class TestIndex(base.BaseIndexTests):
pass
# Decimal supports ops with int, but not float
other = pd.Series([int(d * 100) for d in data])
self.check_opname(ser, op_name, other)

if "mod" not in op_name:
self.check_opname(ser, op_name, ser * 2)

self.check_opname(ser, op_name, 0)
self.check_opname(ser, op_name, 5)
context.traps[decimal.DivisionByZero] = divbyzerotrap
context.traps[decimal.InvalidOperation] = invalidoptrap

class TestMissing(base.BaseMissingTests):
def test_fillna_frame(self, data_missing):
msg = "ExtensionArray.fillna added a 'copy' keyword"
with tm.assert_produces_warning(
Expand Down Expand Up @@ -141,59 +177,6 @@ def test_fillna_series_method(self, data_missing, fillna_method):
):
super().test_fillna_series_method(data_missing, fillna_method)


class Reduce:
def _supports_reduction(self, obj, op_name: str) -> bool:
return True

def check_reduce(self, s, op_name, skipna):
if op_name == "count":
return super().check_reduce(s, op_name, skipna)
else:
result = getattr(s, op_name)(skipna=skipna)
expected = getattr(np.asarray(s), op_name)()
tm.assert_almost_equal(result, expected)

def test_reduction_without_keepdims(self):
# GH52788
# test _reduce without keepdims

class DecimalArray2(DecimalArray):
def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
# no keepdims in signature
return super()._reduce(name, skipna=skipna)

arr = DecimalArray2([decimal.Decimal(2) for _ in range(100)])

ser = pd.Series(arr)
result = ser.agg("sum")
expected = decimal.Decimal(200)
assert result == expected

df = pd.DataFrame({"a": arr, "b": arr})
with tm.assert_produces_warning(FutureWarning):
result = df.agg("sum")
expected = pd.Series({"a": 200, "b": 200}, dtype=object)
tm.assert_series_equal(result, expected)


class TestReduce(Reduce, base.BaseReduceTests):
def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna, request):
if all_numeric_reductions in ["kurt", "skew", "sem", "median"]:
mark = pytest.mark.xfail(raises=NotImplementedError)
request.node.add_marker(mark)
super().test_reduce_series_numeric(data, all_numeric_reductions, skipna)

def test_reduce_frame(self, data, all_numeric_reductions, skipna, request):
op_name = all_numeric_reductions
if op_name in ["skew", "median"]:
mark = pytest.mark.xfail(raises=NotImplementedError)
request.node.add_marker(mark)

return super().test_reduce_frame(data, all_numeric_reductions, skipna)


class TestMethods(base.BaseMethodsTests):
def test_fillna_copy_frame(self, data_missing, using_copy_on_write):
warn = FutureWarning if not using_copy_on_write else None
msg = "ExtensionArray.fillna added a 'copy' keyword"
Expand Down Expand Up @@ -226,27 +209,31 @@ def test_value_counts(self, all_data, dropna, request):

tm.assert_series_equal(result, expected)


class TestCasting(base.BaseCastingTests):
pass


class TestGroupby(base.BaseGroupbyTests):
pass


class TestSetitem(base.BaseSetitemTests):
pass


class TestPrinting(base.BasePrintingTests):
def test_series_repr(self, data):
# Overriding this base test to explicitly test that
# the custom _formatter is used
ser = pd.Series(data)
assert data.dtype.name in repr(ser)
assert "Decimal: " in repr(ser)

@pytest.mark.xfail(
reason="Looks like the test (incorrectly) implicitly assumes int/bool dtype"
)
def test_invert(self, data):
super().test_invert(data)

@pytest.mark.xfail(reason="Inconsistent array-vs-scalar behavior")
@pytest.mark.parametrize("ufunc", [np.positive, np.negative, np.abs])
def test_unary_ufunc_dunder_equivalence(self, data, ufunc):
super().test_unary_ufunc_dunder_equivalence(data, ufunc)


def test_take_na_value_other_decimal():
arr = DecimalArray([decimal.Decimal("1.0"), decimal.Decimal("2.0")])
result = arr.take([0, -1], allow_fill=True, fill_value=decimal.Decimal("-1.0"))
expected = DecimalArray([decimal.Decimal("1.0"), decimal.Decimal("-1.0")])
tm.assert_extension_array_equal(result, expected)


def test_series_constructor_coerce_data_to_extension_dtype():
dtype = DecimalDtype()
Expand Down Expand Up @@ -305,53 +292,6 @@ def test_astype_dispatches(frame):
assert result.dtype.context.prec == ctx.prec


class TestArithmeticOps(base.BaseArithmeticOpsTests):
series_scalar_exc = None
frame_scalar_exc = None
series_array_exc = None

def _get_expected_exception(
self, op_name: str, obj, other
) -> type[Exception] | None:
return None

def test_arith_series_with_array(self, data, all_arithmetic_operators):
op_name = all_arithmetic_operators
s = pd.Series(data)

context = decimal.getcontext()
divbyzerotrap = context.traps[decimal.DivisionByZero]
invalidoptrap = context.traps[decimal.InvalidOperation]
context.traps[decimal.DivisionByZero] = 0
context.traps[decimal.InvalidOperation] = 0

# Decimal supports ops with int, but not float
other = pd.Series([int(d * 100) for d in data])
self.check_opname(s, op_name, other)

if "mod" not in op_name:
self.check_opname(s, op_name, s * 2)

self.check_opname(s, op_name, 0)
self.check_opname(s, op_name, 5)
context.traps[decimal.DivisionByZero] = divbyzerotrap
context.traps[decimal.InvalidOperation] = invalidoptrap


class TestComparisonOps(base.BaseComparisonOpsTests):
def test_compare_scalar(self, data, comparison_op):
s = pd.Series(data)
self._compare_other(s, data, comparison_op, 0.5)

def test_compare_array(self, data, comparison_op):
s = pd.Series(data)

alter = np.random.default_rng(2).choice([-1, 0, 1], len(data))
# Randomly double, halve or keep same value
other = pd.Series(data) * [decimal.Decimal(pow(2.0, i)) for i in alter]
self._compare_other(s, data, comparison_op, other)


class DecimalArrayWithoutFromSequence(DecimalArray):
"""Helper class for testing error handling in _from_sequence."""

Expand Down