|
2 | 2 | from doctest import DocTestSuite
|
3 | 3 | from fractions import Fraction
|
4 | 4 | from functools import reduce
|
5 |
| -from itertools import combinations, count, permutations |
| 5 | +from itertools import combinations, count, groupby, permutations |
6 | 6 | from operator import mul
|
7 | 7 | from math import factorial
|
8 | 8 | from sys import version_info
|
9 | 9 | from unittest import TestCase, skipIf
|
| 10 | +from unittest.mock import patch |
10 | 11 |
|
11 | 12 | import more_itertools as mi
|
12 | 13 |
|
@@ -158,6 +159,22 @@ def test_key(self):
|
158 | 159 | self.assertTrue(mi.all_equal('4٤໔4৪', key=int))
|
159 | 160 | self.assertFalse(mi.all_equal('Abc', key=str.casefold))
|
160 | 161 |
|
| 162 | + @patch('more_itertools.recipes.groupby', autospec=True) |
| 163 | + def test_groupby_calls(self, mock_groupby): |
| 164 | + next_count = 0 |
| 165 | + |
| 166 | + class _groupby(groupby): |
| 167 | + def __next__(true_self): |
| 168 | + nonlocal next_count |
| 169 | + next_count += 1 |
| 170 | + return super().__next__() |
| 171 | + |
| 172 | + mock_groupby.side_effect = _groupby |
| 173 | + iterable = iter('aaaaa') |
| 174 | + self.assertTrue(mi.all_equal(iterable)) |
| 175 | + self.assertEqual(list(iterable), []) |
| 176 | + self.assertEqual(next_count, 2) |
| 177 | + |
161 | 178 |
|
162 | 179 | class QuantifyTests(TestCase):
|
163 | 180 | """Tests for ``quantify()``"""
|
|
0 commit comments