Skip to content

Commit 558f2b8

Browse files
authored
Merge pull request #3517 from natecook1000/nc-SE-0120
[stdlib] Implement partition API change (SE-0120)
2 parents e4e9cf5 + 4ee37b7 commit 558f2b8

File tree

11 files changed

+364
-275
lines changed

11 files changed

+364
-275
lines changed

docs/proposals/InoutCOWOptimization.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ could be written as follows:
5555
let (start, end) = (startIndex, endIndex)
5656
if start != end && start.succ() != end {
5757
let pivot = self[start]
58-
let mid = partition({compare($0, pivot)})
58+
let mid = partition(by: {!compare($0, pivot)})
5959
**self[start...mid].quickSort(compare)**
6060
**self[mid...end].quickSort(compare)**
6161
}

stdlib/private/StdlibCollectionUnittest/CheckMutableCollectionType.swift.gyb

Lines changed: 124 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import StdlibUnittest
1414

15-
// These tests are shared between partition() and sort().
15+
// These tests are shared between partition(by:) and sort().
1616
public struct PartitionExhaustiveTest {
1717
public let sequence: [Int]
1818
public let loc: SourceLoc
@@ -67,6 +67,18 @@ internal func _mapInPlace<C : MutableCollection>(
6767
}
6868
}
6969

70+
internal func makeBufferAccessLoggingMutableCollection<
71+
C : MutableCollection & BidirectionalCollection
72+
>(wrapping c: C) -> BufferAccessLoggingMutableBidirectionalCollection<C> {
73+
return BufferAccessLoggingMutableBidirectionalCollection(wrapping: c)
74+
}
75+
76+
internal func makeBufferAccessLoggingMutableCollection<
77+
C : MutableCollection & RandomAccessCollection
78+
>(wrapping c: C) -> BufferAccessLoggingMutableBidirectionalCollection<C> {
79+
return BufferAccessLoggingMutableBidirectionalCollection(wrapping: c)
80+
}
81+
7082
extension TestSuite {
7183
public func addMutableCollectionTests<
7284
C : MutableCollection,
@@ -574,6 +586,86 @@ self.test("\(testNamePrefix).sorted/${'Predicate' if predicate else 'WhereElemen
574586

575587
% end
576588

589+
//===----------------------------------------------------------------------===//
590+
// partition(by:)
591+
//===----------------------------------------------------------------------===//
592+
593+
func checkPartition(
594+
sequence: [Int],
595+
pivotValue: Int,
596+
lessImpl: ((Int, Int) -> Bool),
597+
verifyOrder: Bool
598+
) {
599+
let elements: [OpaqueValue<Int>] =
600+
zip(sequence, 0..<sequence.count).map {
601+
OpaqueValue($0, identity: $1)
602+
}
603+
604+
var c = makeWrappedCollection(elements)
605+
let closureLifetimeTracker = LifetimeTracked(0)
606+
let pivot = c.partition(by: { val in
607+
_blackHole(closureLifetimeTracker)
608+
return !lessImpl(extractValue(val).value, pivotValue)
609+
})
610+
611+
// Check that we didn't lose any elements.
612+
let identities = c.map { extractValue($0).identity }
613+
expectEqualsUnordered(0..<sequence.count, identities)
614+
615+
if verifyOrder {
616+
// All the elements in the first partition are less than the pivot
617+
// value.
618+
for i in c[c.startIndex..<pivot].indices {
619+
expectLT(extractValue(c[i]).value, pivotValue)
620+
}
621+
// All the elements in the second partition are greater or equal to
622+
// the pivot value.
623+
for i in c[pivot..<c.endIndex].indices {
624+
expectGE(extractValue(c[i]).value, pivotValue)
625+
}
626+
}
627+
}
628+
629+
self.test("\(testNamePrefix).partition") {
630+
for test in partitionExhaustiveTests {
631+
forAllPermutations(test.sequence) { (sequence) in
632+
checkPartition(
633+
sequence: sequence,
634+
pivotValue: sequence.first ?? 0,
635+
lessImpl: { $0 < $1 },
636+
verifyOrder: true)
637+
638+
// Pivot value where all elements will pass the partitioning predicate
639+
checkPartition(
640+
sequence: sequence,
641+
pivotValue: Int.min,
642+
lessImpl: { $0 < $1 },
643+
verifyOrder: true)
644+
645+
// Pivot value where no element will pass the partitioning predicate
646+
checkPartition(
647+
sequence: sequence,
648+
pivotValue: Int.max,
649+
lessImpl: { $0 < $1 },
650+
verifyOrder: true)
651+
}
652+
}
653+
}
654+
655+
self.test("\(testNamePrefix).partition/InvalidOrderings") {
656+
withInvalidOrderings { (comparisonPredicate) in
657+
for i in 0..<7 {
658+
forAllPermutations(i) { (sequence) in
659+
checkPartition(
660+
sequence: sequence,
661+
pivotValue: sequence.first ?? 0,
662+
lessImpl: comparisonPredicate,
663+
verifyOrder: false)
664+
}
665+
}
666+
}
667+
}
668+
577669
//===----------------------------------------------------------------------===//
578670

579671
} // addMutableCollectionTests
@@ -695,6 +787,37 @@ self.test("\(testNamePrefix).reverse()") {
695787
}
696788
}
697789

790+
//===----------------------------------------------------------------------===//
791+
// partition(by:)
792+
//===----------------------------------------------------------------------===//
793+
794+
self.test("\(testNamePrefix).partition/DispatchesThrough_withUnsafeMutableBufferPointerIfSupported") {
795+
let sequence = [ 5, 4, 3, 2, 1 ]
796+
let elements: [OpaqueValue<Int>] =
797+
zip(sequence, 0..<sequence.count).map {
798+
OpaqueValue($0, identity: $1)
799+
}
800+
let c = makeWrappedCollection(elements)
801+
var lc = makeBufferAccessLoggingMutableCollection(wrapping: c)
802+
803+
let closureLifetimeTracker = LifetimeTracked(0)
804+
let first = c.first
805+
let pivot = lc.partition(by: { val in
806+
_blackHole(closureLifetimeTracker)
807+
return !(extractValue(val).value < extractValue(first!).value)
808+
})
809+
810+
expectEqual(
811+
1, lc.log._withUnsafeMutableBufferPointerIfSupported[lc.dynamicType])
812+
expectEqual(
813+
withUnsafeMutableBufferPointerIsSupported ? 1 : 0,
814+
lc.log._withUnsafeMutableBufferPointerIfSupportedNonNilReturns[lc.dynamicType])
815+
816+
expectEqual(4, lc.distance(from: lc.startIndex, to: pivot))
817+
expectEqualsUnordered([1, 2, 3, 4], lc.prefix(upTo: pivot).map { extractValue($0).value })
818+
expectEqualsUnordered([5], lc.suffix(from: pivot).map { extractValue($0).value })
819+
}
820+
698821
//===----------------------------------------------------------------------===//
699822

700823
} // addMutableBidirectionalCollectionTests
@@ -875,140 +998,6 @@ self.test("\(testNamePrefix).sort/${'Predicate' if predicate else 'WhereElementI
875998

876999
% end
8771000

878-
//===----------------------------------------------------------------------===//
879-
// partition()
880-
//===----------------------------------------------------------------------===//
881-
882-
% for predicate in [False, True]:
883-
884-
func checkPartition_${'Predicate' if predicate else 'WhereElementIsComparable'}(
885-
sequence: [Int],
886-
equalImpl: ((Int, Int) -> Bool),
887-
lessImpl: ((Int, Int) -> Bool),
888-
verifyOrder: Bool
889-
) {
890-
% if predicate:
891-
let extract = extractValue
892-
let elements: [OpaqueValue<Int>] =
893-
zip(sequence, 0..<sequence.count).map {
894-
OpaqueValue($0, identity: $1)
895-
}
896-
897-
var c = makeWrappedCollection(elements)
898-
% else:
899-
MinimalComparableValue.equalImpl.value = equalImpl
900-
MinimalComparableValue.lessImpl.value = lessImpl
901-
902-
let extract = extractValueFromComparable
903-
let elements: [MinimalComparableValue] =
904-
zip(sequence, 0..<sequence.count).map {
905-
MinimalComparableValue($0, identity: $1)
906-
}
907-
908-
var c = makeWrappedCollectionWithComparableElement(elements)
909-
% end
910-
911-
% if predicate:
912-
let closureLifetimeTracker = LifetimeTracked(0)
913-
let pivot = c.partition() {
914-
(lhs, rhs) in
915-
_blackHole(closureLifetimeTracker)
916-
return extract(lhs).value < extract(rhs).value
917-
}
918-
% else:
919-
let pivot = c.partition()
920-
% end
921-
922-
// Check that we didn't lose any elements.
923-
let identities = c.map { extract($0).identity }
924-
expectEqualsUnordered(0..<sequence.count, identities)
925-
926-
if verifyOrder {
927-
// All the elements in the first partition are less than the pivot
928-
// value.
929-
for i in c[c.startIndex..<pivot].indices {
930-
expectLT(extract(c[i]).value, extract(c[pivot]).value)
931-
}
932-
// All the elements in the second partition are greater or equal to
933-
// the pivot value.
934-
for i in c[pivot..<c.endIndex].indices {
935-
expectLE(extract(c[pivot]).value, extract(c[i]).value)
936-
}
937-
}
938-
}
939-
940-
self.test("\(testNamePrefix).partition/${'Predicate' if predicate else 'WhereElementIsComparable'}") {
941-
for test in partitionExhaustiveTests {
942-
forAllPermutations(test.sequence) { (sequence) in
943-
checkPartition_${'Predicate' if predicate else 'WhereElementIsComparable'}(
944-
sequence: sequence,
945-
equalImpl: { $0 == $1 },
946-
lessImpl: { $0 < $1 },
947-
verifyOrder: true)
948-
}
949-
}
950-
}
951-
952-
self.test("\(testNamePrefix).partition/${'Predicate' if predicate else 'WhereElementIsComparable'}/InvalidOrderings") {
953-
withInvalidOrderings { (comparisonPredicate) in
954-
for i in 0..<7 {
955-
forAllPermutations(i) { (sequence) in
956-
checkPartition_${'Predicate' if predicate else 'WhereElementIsComparable'}(
957-
sequence: sequence,
958-
equalImpl: {
959-
!comparisonPredicate($0, $1) &&
960-
!comparisonPredicate($1, $0)
961-
},
962-
lessImpl: comparisonPredicate,
963-
verifyOrder: false)
964-
}
965-
}
966-
}
967-
}
968-
969-
self.test("\(testNamePrefix).partition/DispatchesThrough_withUnsafeMutableBufferPointerIfSupported/${'Predicate' if predicate else 'WhereElementIsComparable'}") {
970-
let sequence = [ 5, 4, 3, 2, 1 ]
971-
% if predicate:
972-
let extract = extractValue
973-
let elements: [OpaqueValue<Int>] =
974-
zip(sequence, 0..<sequence.count).map {
975-
OpaqueValue($0, identity: $1)
976-
}
977-
let c = makeWrappedCollection(elements)
978-
% else:
979-
let extract = extractValueFromComparable
980-
let elements: [MinimalComparableValue] =
981-
zip(sequence, 0..<sequence.count).map {
982-
MinimalComparableValue($0, identity: $1)
983-
}
984-
let c = makeWrappedCollectionWithComparableElement(elements)
985-
% end
986-
987-
var lc = LoggingMutableRandomAccessCollection(wrapping: c)
988-
989-
% if predicate:
990-
let closureLifetimeTracker = LifetimeTracked(0)
991-
let pivot = lc.partition() {
992-
(lhs, rhs) in
993-
_blackHole(closureLifetimeTracker)
994-
return extract(lhs).value < extract(rhs).value
995-
}
996-
% else:
997-
let pivot = lc.partition()
998-
% end
999-
1000-
expectEqual(
1001-
1, lc.log._withUnsafeMutableBufferPointerIfSupported[lc.dynamicType])
1002-
expectEqual(
1003-
withUnsafeMutableBufferPointerIsSupported ? 1 : 0,
1004-
lc.log._withUnsafeMutableBufferPointerIfSupportedNonNilReturns[lc.dynamicType])
1005-
1006-
expectEqual(4, lc.distance(from: lc.startIndex, to: pivot))
1007-
expectEqualSequence([ 1, 4, 3, 2, 5 ], lc.map { extract($0).value })
1008-
}
1009-
1010-
% end
1011-
10121001
//===----------------------------------------------------------------------===//
10131002

10141003
} // addMutableRandomAccessCollectionTests

0 commit comments

Comments
 (0)