Skip to content

Commit 144e0b5

Browse files
feat: modify pd.compare to compare with absolute and relative tolerance
Co-authored-by: Tomaz Silva <[email protected]>
1 parent b162331 commit 144e0b5

File tree

6 files changed

+230
-4
lines changed

6 files changed

+230
-4
lines changed

doc/source/whatsnew/v3.0.0.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,14 @@ Other enhancements
3939
- Users can globally disable any ``PerformanceWarning`` by setting the option ``mode.performance_warnings`` to ``False`` (:issue:`56920`)
4040
- :meth:`Styler.format_index_names` can now be used to format the index and column names (:issue:`48936` and :issue:`47489`)
4141
- :class:`.errors.DtypeWarning` improved to include column names when mixed data types are detected (:issue:`58174`)
42+
- :meth:`DataFrame.compare` now supports comparing floating point values with tolerance (:issue:`58827`)
4243
- :meth:`DataFrame.corrwith` now accepts ``min_periods`` as an optional argument, as in :meth:`DataFrame.corr` and :meth:`Series.corr` (:issue:`9490`)
4344
- :meth:`DataFrame.cummin`, :meth:`DataFrame.cummax`, :meth:`DataFrame.cumprod` and :meth:`DataFrame.cumsum` methods now have a ``numeric_only`` parameter (:issue:`53072`)
4445
- :meth:`DataFrame.fillna` and :meth:`Series.fillna` can now accept ``value=None``; for non-object dtype the corresponding NA value will be used (:issue:`57723`)
46+
- :meth:`Series.compare` now supports comparing floating point values with tolerance (:issue:`58827`)
4547
- :meth:`Series.cummin` and :meth:`Series.cummax` now support :class:`CategoricalDtype` (:issue:`52335`)
4648
- :meth:`Series.plot` now correctly handles the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`)
49+
4750
- Support reading Stata 110-format (Stata 7) dta files (:issue:`47176`)
4851
-
4952

pandas/core/frame.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@
101101
is_integer_dtype,
102102
is_iterator,
103103
is_list_like,
104+
is_number,
104105
is_scalar,
105106
is_sequence,
106107
needs_i8_conversion,
@@ -8460,6 +8461,20 @@ def rpow(
84608461
2 b b 3.0 3.0 3.0 4.0
84618462
3 b b NaN NaN 4.0 4.0
84628463
4 a a 5.0 5.0 5.0 5.0
8464+
8465+
Compare dataframes with tolerance (float)
8466+
8467+
>>> df.compare(df2, atol=1)
8468+
col1
8469+
self other
8470+
0 a c
8471+
8472+
Compare dataframes with tolerance (dict)
8473+
8474+
>>> df.compare(df2, atol={{"col3": 1}})
8475+
col1
8476+
self other
8477+
0 a c
84638478
"""
84648479
),
84658480
klass=_shared_doc_kwargs["klass"],
@@ -8471,13 +8486,31 @@ def compare(
84718486
keep_shape: bool = False,
84728487
keep_equal: bool = False,
84738488
result_names: Suffixes = ("self", "other"),
8489+
check_exact: bool | lib.NoDefault = lib.no_default,
8490+
rtol: float | ListLike | dict | lib.NoDefault = lib.no_default,
8491+
atol: float | ListLike | dict | lib.NoDefault = lib.no_default,
84748492
) -> DataFrame:
8493+
if rtol is not lib.no_default:
8494+
if not (is_number(rtol) or is_dict_like(rtol) or is_list_like(rtol)):
8495+
raise TypeError(
8496+
f"rtol must be a number, list or dict, got {type(rtol)}"
8497+
)
8498+
8499+
if atol is not lib.no_default:
8500+
if not (is_number(atol) or is_dict_like(atol) or is_list_like(atol)):
8501+
raise TypeError(
8502+
f"atol must be a number, list or dict, got {type(atol)}"
8503+
)
8504+
84758505
return super().compare(
84768506
other=other,
84778507
align_axis=align_axis,
84788508
keep_shape=keep_shape,
84798509
keep_equal=keep_equal,
84808510
result_names=result_names,
8511+
check_exact=check_exact,
8512+
rtol=rtol,
8513+
atol=atol,
84818514
)
84828515

84838516
def combine(

pandas/core/generic.py

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@
114114
is_bool_dtype,
115115
is_dict_like,
116116
is_extension_array_dtype,
117+
is_float_dtype,
117118
is_list_like,
118119
is_number,
119120
is_numeric_dtype,
@@ -9205,17 +9206,64 @@ def compare(
92059206
keep_shape: bool = False,
92069207
keep_equal: bool = False,
92079208
result_names: Suffixes = ("self", "other"),
9208-
):
9209+
check_exact: bool | lib.NoDefault = lib.no_default,
9210+
rtol: float | ListLike | dict | lib.NoDefault = lib.no_default,
9211+
atol: float | ListLike | dict | lib.NoDefault = lib.no_default,
9212+
) -> DataFrame | Series:
9213+
if (
9214+
check_exact is lib.no_default
9215+
and rtol is lib.no_default
9216+
and atol is lib.no_default
9217+
):
9218+
check_exact = True
9219+
elif check_exact is lib.no_default: # tolerance is specified
9220+
check_exact = False
9221+
9222+
rtol = rtol if rtol is not lib.no_default else 1.0e-5
9223+
atol = atol if atol is not lib.no_default else 1.0e-8
9224+
92099225
if type(self) is not type(other):
92109226
cls_self, cls_other = type(self).__name__, type(other).__name__
92119227
raise TypeError(
92129228
f"can only compare '{cls_self}' (not '{cls_other}') with '{cls_self}'"
92139229
)
92149230

9215-
# error: Unsupported left operand type for & ("Self")
9216-
mask = ~((self == other) | (self.isna() & other.isna())) # type: ignore[operator]
9217-
mask.fillna(True, inplace=True)
9231+
if not check_exact:
9232+
if isinstance(self, ABCDataFrame):
9233+
mask = np.ones(self.shape, dtype=bool)
92189234

9235+
for i, col in enumerate(self.columns):
9236+
if is_dict_like(rtol):
9237+
r_tol = rtol.get(col, 1.0e-5)
9238+
elif is_list_like(rtol):
9239+
r_tol = rtol[i]
9240+
else:
9241+
r_tol = rtol
9242+
9243+
if is_dict_like(atol):
9244+
a_tol = atol.get(col, 1.0e-8)
9245+
elif is_list_like(atol):
9246+
a_tol = atol[i]
9247+
else:
9248+
a_tol = atol
9249+
9250+
if is_float_dtype(self[col]) and is_float_dtype(other[col]):
9251+
mask[:, self.columns.get_loc(col)] = np.isclose(
9252+
self[col], other[col], rtol=r_tol, atol=a_tol
9253+
)
9254+
else:
9255+
mask[:, self.columns.get_loc(col)] = self[col] == other[col]
9256+
# is series
9257+
else:
9258+
if is_float_dtype(self):
9259+
mask = np.isclose(self, other, rtol=rtol, atol=atol)
9260+
else:
9261+
mask = self == other
9262+
else:
9263+
mask = self == other
9264+
9265+
mask = ~(mask | (self.isna() & other.isna()))
9266+
mask.fillna(True, inplace=True)
92199267
if not keep_equal:
92209268
self = self.where(mask)
92219269
other = other.where(mask)
@@ -9229,6 +9277,7 @@ def compare(
92299277
else:
92309278
self = self[mask]
92319279
other = other[mask]
9280+
92329281
if not isinstance(result_names, tuple):
92339282
raise TypeError(
92349283
f"Passing 'result_names' as a {type(result_names)} is not "

pandas/core/series.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
is_integer,
7272
is_iterator,
7373
is_list_like,
74+
is_number,
7475
is_object_dtype,
7576
is_scalar,
7677
pandas_dtype,
@@ -2986,6 +2987,17 @@ def _append(
29862987
2 c c
29872988
3 d b
29882989
4 e e
2990+
2991+
Compare Series with tolerance
2992+
2993+
2994+
>>> s1 = pd.Series([1.0, 2.0])
2995+
>>> s2 = pd.Series([1.1, 2.2])
2996+
>>> s1.compare(s2, atol=0.1)
2997+
self other
2999+
1 2.0 2.2
3000+
29893001
"""
29903002
),
29913003
klass=_shared_doc_kwargs["klass"],
@@ -2997,13 +3009,27 @@ def compare(
29973009
keep_shape: bool = False,
29983010
keep_equal: bool = False,
29993011
result_names: Suffixes = ("self", "other"),
3012+
check_exact: bool | lib.NoDefault = lib.no_default,
3013+
rtol: int | float | lib.NoDefault = lib.no_default,
3014+
atol: int | float | lib.NoDefault = lib.no_default,
30003015
) -> DataFrame | Series:
3016+
if rtol is not lib.no_default:
3017+
if not is_number(rtol):
3018+
raise TypeError(f"rtol must be a number, got {type(rtol)}")
3019+
3020+
if atol is not lib.no_default:
3021+
if not is_number(atol):
3022+
raise TypeError(f"atol must be a number, got {type(atol)}")
3023+
30013024
return super().compare(
30023025
other=other,
30033026
align_axis=align_axis,
30043027
keep_shape=keep_shape,
30053028
keep_equal=keep_equal,
30063029
result_names=result_names,
3030+
check_exact=check_exact,
3031+
rtol=rtol,
3032+
atol=atol,
30073033
)
30083034

30093035
def combine(

pandas/tests/frame/methods/test_compare.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import pytest
33

4+
from pandas._libs import lib
45
from pandas.compat.numpy import np_version_gte1p25
56

67
import pandas as pd
@@ -214,6 +215,83 @@ def test_compare_result_names():
214215
tm.assert_frame_equal(result, expected)
215216

216217

218+
@pytest.mark.parametrize(
219+
"atol, rtol, check_exact, expected_self, expected_other",
220+
[
221+
(lib.no_default, lib.no_default, True, [1.0, 2.0, 4], [0.4, 1.6, 3.5]),
222+
(0, 0, False, [1.0, 2.0, 4], [0.4, 1.6, 3.5]),
223+
(0.5, 0, False, [1.0], [0.4]),
224+
(0, 0.5, False, [1.0], [0.4]),
225+
(0.5, 0.00000001, False, [1.0], [0.4]),
226+
(0.00000001, 0.5, False, [1.0], [0.4]),
227+
(lib.no_default, lib.no_default, False, [1.0, 2.0, 4], [0.4, 1.6, 3.5]),
228+
(0.5, lib.no_default, False, [1.0], [0.4]),
229+
(lib.no_default, 0.5, False, [1.0], [0.4]),
230+
("a", lib.no_default, False, None, None),
231+
],
232+
)
233+
def test_compare_tolerance_float(
234+
atol, rtol, check_exact, expected_self, expected_other
235+
):
236+
df1 = pd.DataFrame(
237+
{"col1": ["a", "b", "c"], "col2": [1.0, 2.0, np.nan], "col3": [1.0, 2.0, 4]}
238+
)
239+
240+
df2 = pd.DataFrame(
241+
{"col1": ["a", "b", "c"], "col2": [1.0, 2.0, np.nan], "col3": [0.4, 1.6, 3.5]}
242+
)
243+
244+
if expected_self is None:
245+
with pytest.raises(TypeError):
246+
df1.compare(df2, atol=atol, rtol=rtol, check_exact=check_exact)
247+
return
248+
249+
result = df1.compare(df2, atol=atol, rtol=rtol, check_exact=check_exact)
250+
251+
expected_data = {
252+
("col3", "self"): pd.Series(expected_self),
253+
("col3", "other"): pd.Series(expected_other),
254+
}
255+
256+
expected = pd.DataFrame(expected_data)
257+
258+
tm.assert_frame_equal(result, expected)
259+
260+
261+
@pytest.mark.parametrize(
262+
"atol, expected_self, expected_other",
263+
[
264+
([0.1, 0.2], [1.0, 2.0], [1.2, 2.2]),
265+
((0.1, 0.2), [1.0, 2.0], [1.2, 2.2]),
266+
({"col1": 0.1, "col2": 0.2}, [1.0, 2.0], [1.2, 2.2]),
267+
({"col2": 0.2}, [1.0, 2.0], [1.2, 2.2]),
268+
({"col1": "a"}, None, None),
269+
((0.1, "a"), None, None),
270+
([0.1, "a"], None, None),
271+
],
272+
)
273+
def test_compare_tolerance_dict_or_list(atol, expected_self, expected_other):
274+
df1 = pd.DataFrame({"col1": [1.0, 2.0], "col2": [3.0, 4.0]})
275+
276+
df2 = pd.DataFrame({"col1": [1.2, 2.2], "col2": [3.2, 4.2]})
277+
278+
if expected_self is None:
279+
with pytest.raises(TypeError):
280+
df1.compare(df2, atol=atol)
281+
return
282+
283+
result = df1.compare(df2, atol=atol)
284+
285+
expected_data = {
286+
("col1", "self"): pd.Series(expected_self),
287+
("col1", "other"): pd.Series(expected_other),
288+
}
289+
290+
expected = pd.DataFrame(expected_data)
291+
292+
tm.assert_frame_equal(result, expected)
293+
294+
217295
@pytest.mark.parametrize(
218296
"result_names",
219297
[

pandas/tests/series/methods/test_compare.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import numpy as np
22
import pytest
33

4+
from pandas._libs import lib
5+
46
import pandas as pd
57
import pandas._testing as tm
68

@@ -115,6 +117,41 @@ def test_compare_different_lengths():
115117
ser1.compare(ser2)
116118

117119

120+
@pytest.mark.parametrize(
121+
"atol, rtol, check_exact, expected_self, expected_other",
122+
[
123+
(lib.no_default, lib.no_default, True, [1.0, 2.0, 4], [0.4, 1.6, 3.5]),
124+
(0, 0, False, [1.0, 2.0, 4], [0.4, 1.6, 3.5]),
125+
(0.5, 0, False, [1.0], [0.4]),
126+
(0, 0.5, False, [1.0], [0.4]),
127+
(0.5, 0.00000001, False, [1.0], [0.4]),
128+
(0.00000001, 0.5, False, [1.0], [0.4]),
129+
(lib.no_default, lib.no_default, False, [1.0, 2.0, 4], [0.4, 1.6, 3.5]),
130+
(0.5, lib.no_default, False, [1.0], [0.4]),
131+
(lib.no_default, 0.5, False, [1.0], [0.4]),
132+
("a", lib.no_default, False, None, None),
133+
],
134+
)
135+
def test_compare_tolerance_float(
136+
atol, rtol, check_exact, expected_self, expected_other
137+
):
138+
df1 = pd.Series([1.0, 2.0, 4])
139+
140+
df2 = pd.Series([0.4, 1.6, 3.5])
141+
142+
if expected_self is None:
143+
with pytest.raises(TypeError):
144+
df1.compare(df2, atol=atol, rtol=rtol, check_exact=check_exact)
145+
return
146+
147+
result = df1.compare(df2, atol=atol, rtol=rtol, check_exact=check_exact)
148+
149+
expected_data = {"self": expected_self, "other": expected_other}
150+
expected = pd.DataFrame(expected_data)
151+
152+
tm.assert_frame_equal(result, expected)
153+
154+
118155
def test_compare_datetime64_and_string():
119156
# Issue https://github.com/pandas-dev/pandas/issues/45506
120157
# Catch OverflowError when comparing datetime64 and string

0 commit comments

Comments
 (0)