Skip to content

Commit d23a5f8

Browse files
authored
TST: Parameterize test_rank.py (#45114)
1 parent 61e0eef commit d23a5f8

File tree

2 files changed

+188
-164
lines changed

2 files changed

+188
-164
lines changed

pandas/tests/frame/test_arithmetic.py

Lines changed: 82 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -94,64 +94,92 @@ def test_frame_in_list(self):
9494
with pytest.raises(ValueError, match=msg):
9595
df in [None]
9696

97-
def test_comparison_invalid(self):
98-
def check(df, df2):
99-
100-
for (x, y) in [(df, df2), (df2, df)]:
101-
# we expect the result to match Series comparisons for
102-
# == and !=, inequalities should raise
103-
result = x == y
104-
expected = DataFrame(
105-
{col: x[col] == y[col] for col in x.columns},
106-
index=x.index,
107-
columns=x.columns,
108-
)
109-
tm.assert_frame_equal(result, expected)
110-
111-
result = x != y
112-
expected = DataFrame(
113-
{col: x[col] != y[col] for col in x.columns},
114-
index=x.index,
115-
columns=x.columns,
116-
)
117-
tm.assert_frame_equal(result, expected)
118-
119-
msgs = [
120-
r"Invalid comparison between dtype=datetime64\[ns\] and ndarray",
121-
"invalid type promotion",
122-
(
123-
# npdev 1.20.0
124-
r"The DTypes <class 'numpy.dtype\[.*\]'> and "
125-
r"<class 'numpy.dtype\[.*\]'> do not have a common DType."
126-
),
127-
]
128-
msg = "|".join(msgs)
129-
with pytest.raises(TypeError, match=msg):
130-
x >= y
131-
with pytest.raises(TypeError, match=msg):
132-
x > y
133-
with pytest.raises(TypeError, match=msg):
134-
x < y
135-
with pytest.raises(TypeError, match=msg):
136-
x <= y
137-
97+
@pytest.mark.parametrize(
98+
"arg, arg2",
99+
[
100+
[
101+
{
102+
"a": np.random.randint(10, size=10),
103+
"b": pd.date_range("20010101", periods=10),
104+
},
105+
{
106+
"a": np.random.randint(10, size=10),
107+
"b": np.random.randint(10, size=10),
108+
},
109+
],
110+
[
111+
{
112+
"a": np.random.randint(10, size=10),
113+
"b": np.random.randint(10, size=10),
114+
},
115+
{
116+
"a": np.random.randint(10, size=10),
117+
"b": pd.date_range("20010101", periods=10),
118+
},
119+
],
120+
[
121+
{
122+
"a": pd.date_range("20010101", periods=10),
123+
"b": pd.date_range("20010101", periods=10),
124+
},
125+
{
126+
"a": np.random.randint(10, size=10),
127+
"b": np.random.randint(10, size=10),
128+
},
129+
],
130+
[
131+
{
132+
"a": np.random.randint(10, size=10),
133+
"b": pd.date_range("20010101", periods=10),
134+
},
135+
{
136+
"a": pd.date_range("20010101", periods=10),
137+
"b": pd.date_range("20010101", periods=10),
138+
},
139+
],
140+
],
141+
)
142+
def test_comparison_invalid(self, arg, arg2):
138143
# GH4968
139144
# invalid date/int comparisons
140-
df = DataFrame(np.random.randint(10, size=(10, 1)), columns=["a"])
141-
df["dates"] = pd.date_range("20010101", periods=len(df))
142-
143-
df2 = df.copy()
144-
df2["dates"] = df["a"]
145-
check(df, df2)
145+
x = DataFrame(arg)
146+
y = DataFrame(arg2)
147+
# we expect the result to match Series comparisons for
148+
# == and !=, inequalities should raise
149+
result = x == y
150+
expected = DataFrame(
151+
{col: x[col] == y[col] for col in x.columns},
152+
index=x.index,
153+
columns=x.columns,
154+
)
155+
tm.assert_frame_equal(result, expected)
146156

147-
df = DataFrame(np.random.randint(10, size=(10, 2)), columns=["a", "b"])
148-
df2 = DataFrame(
149-
{
150-
"a": pd.date_range("20010101", periods=len(df)),
151-
"b": pd.date_range("20100101", periods=len(df)),
152-
}
157+
result = x != y
158+
expected = DataFrame(
159+
{col: x[col] != y[col] for col in x.columns},
160+
index=x.index,
161+
columns=x.columns,
153162
)
154-
check(df, df2)
163+
tm.assert_frame_equal(result, expected)
164+
165+
msgs = [
166+
r"Invalid comparison between dtype=datetime64\[ns\] and ndarray",
167+
"invalid type promotion",
168+
(
169+
# npdev 1.20.0
170+
r"The DTypes <class 'numpy.dtype\[.*\]'> and "
171+
r"<class 'numpy.dtype\[.*\]'> do not have a common DType."
172+
),
173+
]
174+
msg = "|".join(msgs)
175+
with pytest.raises(TypeError, match=msg):
176+
x >= y
177+
with pytest.raises(TypeError, match=msg):
178+
x > y
179+
with pytest.raises(TypeError, match=msg):
180+
x < y
181+
with pytest.raises(TypeError, match=msg):
182+
x <= y
155183

156184
def test_timestamp_compare(self):
157185
# make sure we can compare Timestamps on the right AND left hand side

0 commit comments

Comments
 (0)