Skip to content

Commit 8a7fbbe

Browse files
authored
TST: parametrize generic/internals tests (#31900)
1 parent 48cb5a9 commit 8a7fbbe

File tree

4 files changed

+203
-268
lines changed

4 files changed

+203
-268
lines changed

pandas/tests/generic/test_frame.py

Lines changed: 40 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -32,35 +32,35 @@ def test_rename_mi(self):
3232
)
3333
df.rename(str.lower)
3434

35-
def test_set_axis_name(self):
35+
@pytest.mark.parametrize("func", ["_set_axis_name", "rename_axis"])
36+
def test_set_axis_name(self, func):
3637
df = pd.DataFrame([[1, 2], [3, 4]])
37-
funcs = ["_set_axis_name", "rename_axis"]
38-
for func in funcs:
39-
result = methodcaller(func, "foo")(df)
40-
assert df.index.name is None
41-
assert result.index.name == "foo"
4238

43-
result = methodcaller(func, "cols", axis=1)(df)
44-
assert df.columns.name is None
45-
assert result.columns.name == "cols"
39+
result = methodcaller(func, "foo")(df)
40+
assert df.index.name is None
41+
assert result.index.name == "foo"
4642

47-
def test_set_axis_name_mi(self):
43+
result = methodcaller(func, "cols", axis=1)(df)
44+
assert df.columns.name is None
45+
assert result.columns.name == "cols"
46+
47+
@pytest.mark.parametrize("func", ["_set_axis_name", "rename_axis"])
48+
def test_set_axis_name_mi(self, func):
4849
df = DataFrame(
4950
np.empty((3, 3)),
5051
index=MultiIndex.from_tuples([("A", x) for x in list("aBc")]),
5152
columns=MultiIndex.from_tuples([("C", x) for x in list("xyz")]),
5253
)
5354

5455
level_names = ["L1", "L2"]
55-
funcs = ["_set_axis_name", "rename_axis"]
56-
for func in funcs:
57-
result = methodcaller(func, level_names)(df)
58-
assert result.index.names == level_names
59-
assert result.columns.names == [None, None]
6056

61-
result = methodcaller(func, level_names, axis=1)(df)
62-
assert result.columns.names == ["L1", "L2"]
63-
assert result.index.names == [None, None]
57+
result = methodcaller(func, level_names)(df)
58+
assert result.index.names == level_names
59+
assert result.columns.names == [None, None]
60+
61+
result = methodcaller(func, level_names, axis=1)(df)
62+
assert result.columns.names == ["L1", "L2"]
63+
assert result.index.names == [None, None]
6464

6565
def test_nonzero_single_element(self):
6666

@@ -185,36 +185,35 @@ def test_deepcopy_empty(self):
185185

186186
# formerly in Generic but only test DataFrame
187187
class TestDataFrame2:
188-
def test_validate_bool_args(self):
188+
@pytest.mark.parametrize("value", [1, "True", [1, 2, 3], 5.0])
189+
def test_validate_bool_args(self, value):
189190
df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
190-
invalid_values = [1, "True", [1, 2, 3], 5.0]
191191

192-
for value in invalid_values:
193-
with pytest.raises(ValueError):
194-
super(DataFrame, df).rename_axis(
195-
mapper={"a": "x", "b": "y"}, axis=1, inplace=value
196-
)
192+
with pytest.raises(ValueError):
193+
super(DataFrame, df).rename_axis(
194+
mapper={"a": "x", "b": "y"}, axis=1, inplace=value
195+
)
197196

198-
with pytest.raises(ValueError):
199-
super(DataFrame, df).drop("a", axis=1, inplace=value)
197+
with pytest.raises(ValueError):
198+
super(DataFrame, df).drop("a", axis=1, inplace=value)
200199

201-
with pytest.raises(ValueError):
202-
super(DataFrame, df)._consolidate(inplace=value)
200+
with pytest.raises(ValueError):
201+
super(DataFrame, df)._consolidate(inplace=value)
203202

204-
with pytest.raises(ValueError):
205-
super(DataFrame, df).fillna(value=0, inplace=value)
203+
with pytest.raises(ValueError):
204+
super(DataFrame, df).fillna(value=0, inplace=value)
206205

207-
with pytest.raises(ValueError):
208-
super(DataFrame, df).replace(to_replace=1, value=7, inplace=value)
206+
with pytest.raises(ValueError):
207+
super(DataFrame, df).replace(to_replace=1, value=7, inplace=value)
209208

210-
with pytest.raises(ValueError):
211-
super(DataFrame, df).interpolate(inplace=value)
209+
with pytest.raises(ValueError):
210+
super(DataFrame, df).interpolate(inplace=value)
212211

213-
with pytest.raises(ValueError):
214-
super(DataFrame, df)._where(cond=df.a > 2, inplace=value)
212+
with pytest.raises(ValueError):
213+
super(DataFrame, df)._where(cond=df.a > 2, inplace=value)
215214

216-
with pytest.raises(ValueError):
217-
super(DataFrame, df).mask(cond=df.a > 2, inplace=value)
215+
with pytest.raises(ValueError):
216+
super(DataFrame, df).mask(cond=df.a > 2, inplace=value)
218217

219218
def test_unexpected_keyword(self):
220219
# GH8597
@@ -243,23 +242,10 @@ class TestToXArray:
243242
and LooseVersion(xarray.__version__) < LooseVersion("0.10.0"),
244243
reason="xarray >= 0.10.0 required",
245244
)
246-
@pytest.mark.parametrize(
247-
"index",
248-
[
249-
"FloatIndex",
250-
"IntIndex",
251-
"StringIndex",
252-
"UnicodeIndex",
253-
"DateIndex",
254-
"PeriodIndex",
255-
"CategoricalIndex",
256-
"TimedeltaIndex",
257-
],
258-
)
245+
@pytest.mark.parametrize("index", tm.all_index_generator(3))
259246
def test_to_xarray_index_types(self, index):
260247
from xarray import Dataset
261248

262-
index = getattr(tm, f"make{index}")
263249
df = DataFrame(
264250
{
265251
"a": list("abc"),
@@ -273,7 +259,7 @@ def test_to_xarray_index_types(self, index):
273259
}
274260
)
275261

276-
df.index = index(3)
262+
df.index = index
277263
df.index.name = "foo"
278264
df.columns.name = "bar"
279265
result = df.to_xarray()

pandas/tests/generic/test_generic.py

Lines changed: 31 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -257,39 +257,31 @@ def test_metadata_propagation(self):
257257
self.check_metadata(v1 & v2)
258258
self.check_metadata(v1 | v2)
259259

260-
def test_head_tail(self):
260+
@pytest.mark.parametrize("index", tm.all_index_generator(10))
261+
def test_head_tail(self, index):
261262
# GH5370
262263

263264
o = self._construct(shape=10)
264265

265-
# check all index types
266-
for index in [
267-
tm.makeFloatIndex,
268-
tm.makeIntIndex,
269-
tm.makeStringIndex,
270-
tm.makeUnicodeIndex,
271-
tm.makeDateIndex,
272-
tm.makePeriodIndex,
273-
]:
274-
axis = o._get_axis_name(0)
275-
setattr(o, axis, index(len(getattr(o, axis))))
266+
axis = o._get_axis_name(0)
267+
setattr(o, axis, index)
276268

277-
o.head()
269+
o.head()
278270

279-
self._compare(o.head(), o.iloc[:5])
280-
self._compare(o.tail(), o.iloc[-5:])
271+
self._compare(o.head(), o.iloc[:5])
272+
self._compare(o.tail(), o.iloc[-5:])
281273

282-
# 0-len
283-
self._compare(o.head(0), o.iloc[0:0])
284-
self._compare(o.tail(0), o.iloc[0:0])
274+
# 0-len
275+
self._compare(o.head(0), o.iloc[0:0])
276+
self._compare(o.tail(0), o.iloc[0:0])
285277

286-
# bounded
287-
self._compare(o.head(len(o) + 1), o)
288-
self._compare(o.tail(len(o) + 1), o)
278+
# bounded
279+
self._compare(o.head(len(o) + 1), o)
280+
self._compare(o.tail(len(o) + 1), o)
289281

290-
# neg index
291-
self._compare(o.head(-3), o.head(7))
292-
self._compare(o.tail(-3), o.tail(7))
282+
# neg index
283+
self._compare(o.head(-3), o.head(7))
284+
self._compare(o.tail(-3), o.tail(7))
293285

294286
def test_sample(self):
295287
# Fixes issue: 2419
@@ -468,16 +460,16 @@ def test_stat_unexpected_keyword(self):
468460
with pytest.raises(TypeError, match=errmsg):
469461
obj.any(epic=starwars) # logical_function
470462

471-
def test_api_compat(self):
463+
@pytest.mark.parametrize("func", ["sum", "cumsum", "any", "var"])
464+
def test_api_compat(self, func):
472465

473466
# GH 12021
474467
# compat for __name__, __qualname__
475468

476469
obj = self._construct(5)
477-
for func in ["sum", "cumsum", "any", "var"]:
478-
f = getattr(obj, func)
479-
assert f.__name__ == func
480-
assert f.__qualname__.endswith(func)
470+
f = getattr(obj, func)
471+
assert f.__name__ == func
472+
assert f.__qualname__.endswith(func)
481473

482474
def test_stat_non_defaults_args(self):
483475
obj = self._construct(5)
@@ -510,19 +502,17 @@ def test_truncate_out_of_bounds(self):
510502
self._compare(big.truncate(before=0, after=3e6), big)
511503
self._compare(big.truncate(before=-1, after=2e6), big)
512504

513-
def test_copy_and_deepcopy(self):
505+
@pytest.mark.parametrize(
506+
"func",
507+
[copy, deepcopy, lambda x: x.copy(deep=False), lambda x: x.copy(deep=True)],
508+
)
509+
@pytest.mark.parametrize("shape", [0, 1, 2])
510+
def test_copy_and_deepcopy(self, shape, func):
514511
# GH 15444
515-
for shape in [0, 1, 2]:
516-
obj = self._construct(shape)
517-
for func in [
518-
copy,
519-
deepcopy,
520-
lambda x: x.copy(deep=False),
521-
lambda x: x.copy(deep=True),
522-
]:
523-
obj_copy = func(obj)
524-
assert obj_copy is not obj
525-
self._compare(obj_copy, obj)
512+
obj = self._construct(shape)
513+
obj_copy = func(obj)
514+
assert obj_copy is not obj
515+
self._compare(obj_copy, obj)
526516

527517
@pytest.mark.parametrize(
528518
"periods,fill_method,limit,exp",

pandas/tests/generic/test_series.py

Lines changed: 16 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -38,29 +38,29 @@ def test_rename_mi(self):
3838
)
3939
s.rename(str.lower)
4040

41-
def test_set_axis_name(self):
41+
@pytest.mark.parametrize("func", ["rename_axis", "_set_axis_name"])
42+
def test_set_axis_name(self, func):
4243
s = Series([1, 2, 3], index=["a", "b", "c"])
43-
funcs = ["rename_axis", "_set_axis_name"]
4444
name = "foo"
45-
for func in funcs:
46-
result = methodcaller(func, name)(s)
47-
assert s.index.name is None
48-
assert result.index.name == name
4945

50-
def test_set_axis_name_mi(self):
46+
result = methodcaller(func, name)(s)
47+
assert s.index.name is None
48+
assert result.index.name == name
49+
50+
@pytest.mark.parametrize("func", ["rename_axis", "_set_axis_name"])
51+
def test_set_axis_name_mi(self, func):
5152
s = Series(
5253
[11, 21, 31],
5354
index=MultiIndex.from_tuples(
5455
[("A", x) for x in ["a", "B", "c"]], names=["l1", "l2"]
5556
),
5657
)
57-
funcs = ["rename_axis", "_set_axis_name"]
58-
for func in funcs:
59-
result = methodcaller(func, ["L1", "L2"])(s)
60-
assert s.index.name is None
61-
assert s.index.names == ["l1", "l2"]
62-
assert result.index.name is None
63-
assert result.index.names, ["L1", "L2"]
58+
59+
result = methodcaller(func, ["L1", "L2"])(s)
60+
assert s.index.name is None
61+
assert s.index.names == ["l1", "l2"]
62+
assert result.index.name is None
63+
assert result.index.names, ["L1", "L2"]
6464

6565
def test_set_axis_name_raises(self):
6666
s = pd.Series([1])
@@ -230,24 +230,11 @@ class TestToXArray:
230230
and LooseVersion(xarray.__version__) < LooseVersion("0.10.0"),
231231
reason="xarray >= 0.10.0 required",
232232
)
233-
@pytest.mark.parametrize(
234-
"index",
235-
[
236-
"FloatIndex",
237-
"IntIndex",
238-
"StringIndex",
239-
"UnicodeIndex",
240-
"DateIndex",
241-
"PeriodIndex",
242-
"TimedeltaIndex",
243-
"CategoricalIndex",
244-
],
245-
)
233+
@pytest.mark.parametrize("index", tm.all_index_generator(6))
246234
def test_to_xarray_index_types(self, index):
247235
from xarray import DataArray
248236

249-
index = getattr(tm, f"make{index}")
250-
s = Series(range(6), index=index(6))
237+
s = Series(range(6), index=index)
251238
s.index.name = "foo"
252239
result = s.to_xarray()
253240
repr(result)

0 commit comments

Comments
 (0)