|
10 | 10 |
|
11 | 11 | import numpy as np
|
12 | 12 |
|
| 13 | +from pandas._libs import lib |
13 | 14 | from pandas._libs.missing import is_matching_na
|
14 | 15 | from pandas._libs.sparse import SparseIndex
|
15 | 16 | import pandas._libs.testing as _testing
|
@@ -811,14 +812,14 @@ def assert_series_equal(
|
811 | 812 | check_index_type: bool | Literal["equiv"] = "equiv",
|
812 | 813 | check_series_type: bool = True,
|
813 | 814 | check_names: bool = True,
|
814 |
| - check_exact: bool = False, |
| 815 | + check_exact: bool | lib.NoDefault = lib.no_default, |
815 | 816 | check_datetimelike_compat: bool = False,
|
816 | 817 | check_categorical: bool = True,
|
817 | 818 | check_category_order: bool = True,
|
818 | 819 | check_freq: bool = True,
|
819 | 820 | check_flags: bool = True,
|
820 |
| - rtol: float = 1.0e-5, |
821 |
| - atol: float = 1.0e-8, |
| 821 | + rtol: float | lib.NoDefault = lib.no_default, |
| 822 | + atol: float | lib.NoDefault = lib.no_default, |
822 | 823 | obj: str = "Series",
|
823 | 824 | *,
|
824 | 825 | check_index: bool = True,
|
@@ -877,6 +878,25 @@ def assert_series_equal(
|
877 | 878 | >>> tm.assert_series_equal(a, b)
|
878 | 879 | """
|
879 | 880 | __tracebackhide__ = True
|
| 881 | + if ( |
| 882 | + check_exact is lib.no_default |
| 883 | + and rtol is lib.no_default |
| 884 | + and atol is lib.no_default |
| 885 | + ): |
| 886 | + if ( |
| 887 | + is_numeric_dtype(left.dtype) |
| 888 | + and not is_float_dtype(left.dtype) |
| 889 | + or is_numeric_dtype(right.dtype) |
| 890 | + and not is_float_dtype(right.dtype) |
| 891 | + ): |
| 892 | + check_exact = True |
| 893 | + else: |
| 894 | + check_exact = False |
| 895 | + elif check_exact is lib.no_default: |
| 896 | + check_exact = False |
| 897 | + |
| 898 | + rtol = rtol if rtol is not lib.no_default else 1.0e-5 |
| 899 | + atol = atol if atol is not lib.no_default else 1.0e-8 |
880 | 900 |
|
881 | 901 | if not check_index and check_like:
|
882 | 902 | raise ValueError("check_like must be False if check_index is False")
|
@@ -931,10 +951,7 @@ def assert_series_equal(
|
931 | 951 | pass
|
932 | 952 | else:
|
933 | 953 | assert_attr_equal("dtype", left, right, obj=f"Attributes of {obj}")
|
934 |
| - if check_exact or ( |
935 |
| - (is_numeric_dtype(left.dtype) and not is_float_dtype(left.dtype)) |
936 |
| - or (is_numeric_dtype(right.dtype) and not is_float_dtype(right.dtype)) |
937 |
| - ): |
| 954 | + if check_exact: |
938 | 955 | left_values = left._values
|
939 | 956 | right_values = right._values
|
940 | 957 | # Only check exact if dtype is numeric
|
|
0 commit comments