Skip to content

Commit 73d3f03

Browse files
polmauriamueller
authored andcommitted
[MRG + 1] FIX raise an error message when n_groups > number of groups (scikit-learn#7681) (scikit-learn#7683)
* FIX raise an error message when n_groups > actual number of groups (scikit-learn#7681) This change addresses issue scikit-learn#7681: - Raise ValueError when n_groups > actual number of unique groups in LeaveOneGroupOut and LeavePGroupsOut. - Add unit test. * Make requested changes - Check error message with `assert_raise_message` - Pass parameters to `assert_raise_message` instead of defining functions * Update condition and exception message
1 parent ff5c36e commit 73d3f03

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

sklearn/model_selection/_split.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -773,6 +773,10 @@ def _iter_test_masks(self, X, y, groups):
773773
# We make a copy of groups to avoid side-effects during iteration
774774
groups = np.array(groups, copy=True)
775775
unique_groups = np.unique(groups)
776+
if len(unique_groups) <= 1:
777+
raise ValueError(
778+
"The groups parameter contains fewer than 2 unique groups "
779+
"(%s). LeaveOneGroupOut expects at least 2." % unique_groups)
776780
for i in unique_groups:
777781
yield groups == i
778782

@@ -862,6 +866,12 @@ def _iter_test_masks(self, X, y, groups):
862866
raise ValueError("The groups parameter should not be None")
863867
groups = np.array(groups, copy=True)
864868
unique_groups = np.unique(groups)
869+
if self.n_groups >= len(unique_groups):
870+
raise ValueError(
871+
"The groups parameter contains fewer than (or equal to) "
872+
"n_groups (%d) numbers of unique groups (%s). LeavePGroupsOut "
873+
"expects that at least n_groups + 1 (%d) unique groups be "
874+
"present" % (self.n_groups, unique_groups, self.n_groups + 1))
865875
combi = combinations(range(len(unique_groups)), self.n_groups)
866876
for indices in combi:
867877
test_index = np.zeros(_num_samples(X), dtype=np.bool)

sklearn/model_selection/tests/test_split.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -724,6 +724,31 @@ def test_leave_group_out_changing_groups():
724724
assert_equal(3, LeaveOneGroupOut().get_n_splits(X, y, groups))
725725

726726

727+
def test_leave_one_p_group_out_error_on_fewer_number_of_groups():
728+
X = y = groups = np.ones(0)
729+
msg = ("The groups parameter contains fewer than 2 unique groups ([]). "
730+
"LeaveOneGroupOut expects at least 2.")
731+
assert_raise_message(ValueError, msg, next,
732+
LeaveOneGroupOut().split(X, y, groups))
733+
X = y = groups = np.ones(1)
734+
msg = ("The groups parameter contains fewer than 2 unique groups ([ 1.]). "
735+
"LeaveOneGroupOut expects at least 2.")
736+
assert_raise_message(ValueError, msg, next,
737+
LeaveOneGroupOut().split(X, y, groups))
738+
X = y = groups = np.ones(1)
739+
msg = ("The groups parameter contains fewer than (or equal to) n_groups "
740+
"(3) numbers of unique groups ([ 1.]). LeavePGroupsOut expects "
741+
"that at least n_groups + 1 (4) unique groups be present")
742+
assert_raise_message(ValueError, msg, next,
743+
LeavePGroupsOut(n_groups=3).split(X, y, groups))
744+
X = y = groups = np.arange(3)
745+
msg = ("The groups parameter contains fewer than (or equal to) n_groups "
746+
"(3) numbers of unique groups ([0 1 2]). LeavePGroupsOut expects "
747+
"that at least n_groups + 1 (4) unique groups be present")
748+
assert_raise_message(ValueError, msg, next,
749+
LeavePGroupsOut(n_groups=3).split(X, y, groups))
750+
751+
727752
def test_train_test_split_errors():
728753
assert_raises(ValueError, train_test_split)
729754
assert_raises(ValueError, train_test_split, range(3), train_size=1.1)

0 commit comments

Comments
 (0)