Skip to content

Commit 8172a85

Browse files
authored
Merge pull request scala#6760 from valydia/scala/collection-strawman#235
`partitionWith` support in `IterableOps`
2 parents 531b12c + d131c1d commit 8172a85

File tree

5 files changed

+119
-5
lines changed

5 files changed

+119
-5
lines changed

src/library/scala/collection/Iterable.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,31 @@ trait IterableOps[+A, +CC[_], +C] extends Any with IterableOnce[A] with Iterable
648648
def collect[B](pf: PartialFunction[A, B]): CC[B] =
649649
fromIterable(new View.Collect(this, pf))
650650

651+
/** A pair of ${coll}s, first, consisting of the values that are produced by `f` applied to this $coll elements contained in [[scala.util.Left]] and, second,
652+
* all the values contained in [[scala.util.Right]].
653+
*
654+
* Example:
655+
* {{{
656+
* val xs = $Coll(1, "one", 2, "two", 3, "three") partitionWith {
657+
* case i: Int => Left(i)
658+
* case s: String => Right(s)
659+
* }
660+
* // xs == ($Coll(1, 2, 3),
661+
* // $Coll(one, two, three))
662+
* }}}
663+
*
664+
* @tparam A1 element type of the first resulting collection
665+
* @tparam A2 element type of the second resulting collection
666+
* @param f split function that map the element of the $coll into an [[scala.util.Either]][A1, A2]
667+
*
668+
* @return a pair of ${coll}s, first, consisting of the values that are produced by `f` applied to this $coll
669+
* elements contained in [[scala.util.Left]] and, second, all the values contained in [[scala.util.Right]].
670+
*/
671+
def partitionWith[A1, A2](f: A => Either[A1, A2]): (CC[A1], CC[A2]) = {
672+
val mp = new View.PartitionWith(this, f)
673+
(fromIterable(mp.left), fromIterable(mp.right))
674+
}
675+
651676
/** Returns a new $coll containing the elements from the left hand operand followed by the elements from the
652677
* right hand operand. The element type of the $coll is the most specific superclass encompassing
653678
* the element types of the two operands.

src/library/scala/collection/StrictOptimizedIterableOps.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,4 +224,17 @@ trait StrictOptimizedIterableOps[+A, +CC[_], +C]
224224
b.result()
225225
}
226226

227+
// Optimized, push-based version of `partitionWith`
228+
override def partitionWith[A1, A2](f: A => Either[A1, A2]): (CC[A1], CC[A2]) = {
229+
val l = iterableFactory.newBuilder[A1]
230+
val r = iterableFactory.newBuilder[A2]
231+
foreach { x =>
232+
f(x) match {
233+
case Left(x1) => l += x1
234+
case Right(x2) => r += x2
235+
}
236+
}
237+
(l.result(), r.result())
238+
}
239+
227240
}

src/library/scala/collection/View.scala

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ object View extends IterableFactory[View] {
160160
override def isEmpty: Boolean = underlying.isEmpty
161161
}
162162

163-
/** A view that partitions an underlying collection into two views */
163+
/** A class that partitions an underlying collection into two views */
164164
@SerialVersionUID(3L)
165165
class Partition[A](val underlying: SomeIterableOps[A], val p: A => Boolean) extends Serializable {
166166

@@ -183,6 +183,71 @@ object View extends IterableFactory[View] {
183183
override def isEmpty: Boolean = iterator.isEmpty
184184
}
185185

186+
/** A class that splits an underlying collection into two views */
187+
@SerialVersionUID(3L)
188+
class PartitionWith[A, A1, A2](val underlying: SomeIterableOps[A], val f: A => Either[A1, A2]) extends Serializable {
189+
190+
/** The view consisting of all elements of the underlying collection
191+
* that map to `Left`.
192+
*/
193+
val left: View[A1] = new LeftPartitionedWith(this, f)
194+
195+
196+
/** The view consisting of all elements of the underlying collection
197+
* that map to `Right`.
198+
*/
199+
val right: View[A2] = new RightPartitionedWith(this, f)
200+
201+
}
202+
203+
@SerialVersionUID(3L)
204+
class LeftPartitionedWith[A, A1, A2](partitionWith: PartitionWith[A, A1, A2], f: A => Either[A1, A2]) extends AbstractView[A1] {
205+
def iterator = new AbstractIterator[A1] {
206+
private val self = partitionWith.underlying.iterator
207+
private var hd: A1 = _
208+
private var hdDefined: Boolean = false
209+
def hasNext = hdDefined || {
210+
def findNext(): Boolean =
211+
if (self.hasNext) {
212+
f(self.next()) match {
213+
case Left(a1) => hd = a1; hdDefined = true; true
214+
case Right(_) => findNext()
215+
}
216+
} else false
217+
findNext()
218+
}
219+
def next() =
220+
if (hasNext) {
221+
hdDefined = false
222+
hd
223+
} else Iterator.empty.next()
224+
}
225+
}
226+
227+
@SerialVersionUID(3L)
228+
class RightPartitionedWith[A, A1, A2](partitionWith: PartitionWith[A, A1, A2], f: A => Either[A1, A2]) extends AbstractView[A2] {
229+
def iterator = new AbstractIterator[A2] {
230+
private val self = partitionWith.underlying.iterator
231+
private var hd: A2 = _
232+
private var hdDefined: Boolean = false
233+
def hasNext = hdDefined || {
234+
def findNext(): Boolean =
235+
if (self.hasNext) {
236+
f(self.next()) match {
237+
case Left(_) => findNext()
238+
case Right(a2) => hd = a2; hdDefined = true; true
239+
}
240+
} else false
241+
findNext()
242+
}
243+
def next() =
244+
if (hasNext) {
245+
hdDefined = false
246+
hd
247+
} else Iterator.empty.next()
248+
}
249+
}
250+
186251
/** A view that drops leading elements of the underlying collection. */
187252
@SerialVersionUID(3L)
188253
class Drop[A](underlying: SomeIterableOps[A], n: Int) extends AbstractView[A] {

test/junit/scala/collection/BuildFromTest.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ class BuildFromTest {
105105
builder.result()
106106
}
107107

108-
def mapSplit[A, B, C, ToL, ToR](coll: Iterable[A])(f: A => Either[B, C])
108+
def partitionWith[A, B, C, ToL, ToR](coll: Iterable[A])(f: A => Either[B, C])
109109
(implicit bfLeft: BuildFrom[coll.type, B, ToL], bfRight: BuildFrom[coll.type, C, ToR]): (ToL, ToR) = {
110110
val left = bfLeft.newBuilder(coll)
111111
val right = bfRight.newBuilder(coll)
@@ -135,14 +135,14 @@ class BuildFromTest {
135135
}
136136

137137
@Test
138-
def mapSplitTest: Unit = {
138+
def partitionWithTest: Unit = {
139139
val xs1 = immutable.List(1, 2, 3)
140-
val (xs2, xs3) = mapSplit(xs1)(x => if (x % 2 == 0) Left(x) else Right(x.toString))
140+
val (xs2, xs3) = partitionWith(xs1)(x => if (x % 2 == 0) Left(x) else Right(x.toString))
141141
val xs4: immutable.List[Int] = xs2
142142
val xs5: immutable.List[String] = xs3
143143

144144
val xs6 = immutable.TreeMap((1, "1"), (2, "2"))
145-
val (xs7, xs8) = mapSplit(xs6) { case (k, v) => Left[(String, Int), (Int, Boolean)]((v, k)) }
145+
val (xs7, xs8) = partitionWith(xs6) { case (k, v) => Left[(String, Int), (Int, Boolean)]((v, k)) }
146146
val xs9: immutable.TreeMap[String, Int] = xs7
147147
val xs10: immutable.TreeMap[Int, Boolean] = xs8
148148
}

test/junit/scala/collection/IterableTest.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ class IterableTest {
212212
Assert.assertEquals(baselist.reverse, checklist)
213213
}
214214

215+
@Test
215216
def unzip(): Unit = {
216217
val zipped = Seq((1, 'a'), (2, 'b'), (3, 'c'))
217218
val (s1, s2) = zipped.unzip
@@ -261,4 +262,14 @@ class IterableTest {
261262
Assert.assertEquals("Fu()", foo.toString)
262263
}
263264

265+
@Test
266+
def partitionWith: Unit = {
267+
val (left, right) = Seq(1, "1", 2, "2", 3, "3", 4, "4", 5, "5").partitionWith {
268+
case i: Int => Left(i)
269+
case s: String => Right(s)
270+
}
271+
Assert.assertEquals(left, Seq(1, 2, 3, 4 ,5))
272+
Assert.assertEquals(right, Seq("1", "2", "3", "4" ,"5"))
273+
}
274+
264275
}

0 commit comments

Comments
 (0)