Skip to content

Commit c9c865c

Browse files
committed
[stdlib] Update partition(by:) tests
Adds a dispatch test for partition(by:), since it's now a protocol requirement. Also adds a new logging collection wrapper that only logs when _withUnsafeMutableBufferPointerIfSupported is called -- any calls to this method from dispatched methods are uncounted by the standard logging wrappers.
1 parent 2a6a671 commit c9c865c

File tree

3 files changed

+156
-40
lines changed

3 files changed

+156
-40
lines changed

stdlib/private/StdlibCollectionUnittest/CheckMutableCollectionType.swift.gyb

Lines changed: 49 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import StdlibUnittest
1414

15-
// These tests are shared between partition() and sort().
15+
// These tests are shared between partition(by:) and sort().
1616
public struct PartitionExhaustiveTest {
1717
public let sequence: [Int]
1818
public let loc: SourceLoc
@@ -67,6 +67,18 @@ internal func _mapInPlace<C : MutableCollection>(
6767
}
6868
}
6969

70+
internal func makeBufferAccessLoggingMutableCollection<
71+
C : MutableCollection & BidirectionalCollection
72+
>(wrapping c: C) -> BufferAccessLoggingMutableBidirectionalCollection<C> {
73+
return BufferAccessLoggingMutableBidirectionalCollection(wrapping: c)
74+
}
75+
76+
internal func makeBufferAccessLoggingMutableCollection<
77+
C : MutableCollection & RandomAccessCollection
78+
>(wrapping c: C) -> BufferAccessLoggingMutableBidirectionalCollection<C> {
79+
return BufferAccessLoggingMutableBidirectionalCollection(wrapping: c)
80+
}
81+
7082
extension TestSuite {
7183
public func addMutableCollectionTests<
7284
C : MutableCollection,
@@ -575,7 +587,7 @@ self.test("\(testNamePrefix).sorted/${'Predicate' if predicate else 'WhereElemen
575587
% end
576588

577589
//===----------------------------------------------------------------------===//
578-
// partition()
590+
// partition(by:)
579591
//===----------------------------------------------------------------------===//
580592

581593
func checkPartition(
@@ -584,7 +596,6 @@ func checkPartition(
584596
lessImpl: ((Int, Int) -> Bool),
585597
verifyOrder: Bool
586598
) {
587-
let extract = extractValue
588599
let elements: [OpaqueValue<Int>] =
589600
zip(sequence, 0..<sequence.count).map {
590601
OpaqueValue($0, identity: $1)
@@ -594,23 +605,23 @@ func checkPartition(
594605
let closureLifetimeTracker = LifetimeTracked(0)
595606
let pivot = c.partition(by: { val in
596607
_blackHole(closureLifetimeTracker)
597-
return !lessImpl(extract(val).value, pivotValue)
608+
return !lessImpl(extractValue(val).value, pivotValue)
598609
})
599610

600611
// Check that we didn't lose any elements.
601-
let identities = c.map { extract($0).identity }
612+
let identities = c.map { extractValue($0).identity }
602613
expectEqualsUnordered(0..<sequence.count, identities)
603614

604615
if verifyOrder {
605616
// All the elements in the first partition are less than the pivot
606617
// value.
607618
for i in c[c.startIndex..<pivot].indices {
608-
expectLT(extract(c[i]).value, pivotValue)
619+
expectLT(extractValue(c[i]).value, pivotValue)
609620
}
610621
// All the elements in the second partition are greater or equal to
611622
// the pivot value.
612623
for i in c[pivot..<c.endIndex].indices {
613-
expectGE(extract(c[i]).value, pivotValue)
624+
expectGE(extractValue(c[i]).value, pivotValue)
614625
}
615626
}
616627
}
@@ -776,6 +787,37 @@ self.test("\(testNamePrefix).reverse()") {
776787
}
777788
}
778789

790+
//===----------------------------------------------------------------------===//
791+
// partition(by:)
792+
//===----------------------------------------------------------------------===//
793+
794+
self.test("\(testNamePrefix).partition/DispatchesThrough_withUnsafeMutableBufferPointerIfSupported") {
795+
let sequence = [ 5, 4, 3, 2, 1 ]
796+
let elements: [OpaqueValue<Int>] =
797+
zip(sequence, 0..<sequence.count).map {
798+
OpaqueValue($0, identity: $1)
799+
}
800+
let c = makeWrappedCollection(elements)
801+
var lc = makeBufferAccessLoggingMutableCollection(wrapping: c)
802+
803+
let closureLifetimeTracker = LifetimeTracked(0)
804+
let first = c.first
805+
let pivot = lc.partition(by: { val in
806+
_blackHole(closureLifetimeTracker)
807+
return !(extractValue(val).value < extractValue(first!).value)
808+
})
809+
810+
expectEqual(
811+
1, lc.log._withUnsafeMutableBufferPointerIfSupported[lc.dynamicType])
812+
expectEqual(
813+
withUnsafeMutableBufferPointerIsSupported ? 1 : 0,
814+
lc.log._withUnsafeMutableBufferPointerIfSupportedNonNilReturns[lc.dynamicType])
815+
816+
expectEqual(4, lc.distance(from: lc.startIndex, to: pivot))
817+
expectEqualsUnordered([1, 2, 3, 4], lc.prefix(upTo: pivot).map { extractValue($0).value })
818+
expectEqualsUnordered([5], lc.suffix(from: pivot).map { extractValue($0).value })
819+
}
820+
779821
//===----------------------------------------------------------------------===//
780822

781823
} // addMutableBidirectionalCollectionTests
@@ -956,39 +998,6 @@ self.test("\(testNamePrefix).sort/${'Predicate' if predicate else 'WhereElementI
956998

957999
% end
9581000

959-
//===----------------------------------------------------------------------===//
960-
// partition()
961-
//===----------------------------------------------------------------------===//
962-
963-
// FIXME(tests): Move to addMutableCollectionTests?
964-
self.test("\(testNamePrefix).partition/DispatchesThrough_withUnsafeMutableBufferPointerIfSupported") {
965-
let sequence = [ 5, 4, 3, 2, 1 ]
966-
let extract = extractValue
967-
let elements: [OpaqueValue<Int>] =
968-
zip(sequence, 0..<sequence.count).map {
969-
OpaqueValue($0, identity: $1)
970-
}
971-
let c = makeWrappedCollection(elements)
972-
973-
var lc = LoggingMutableRandomAccessCollection(wrapping: c)
974-
975-
let closureLifetimeTracker = LifetimeTracked(0)
976-
let first = c.first
977-
let pivot = lc.partition(by: { val in
978-
_blackHole(closureLifetimeTracker)
979-
return !(extract(val).value < extract(first!).value)
980-
})
981-
982-
expectEqual(
983-
1, lc.log._withUnsafeMutableBufferPointerIfSupported[lc.dynamicType])
984-
expectEqual(
985-
withUnsafeMutableBufferPointerIsSupported ? 1 : 0,
986-
lc.log._withUnsafeMutableBufferPointerIfSupportedNonNilReturns[lc.dynamicType])
987-
988-
expectEqual(4, lc.distance(from: lc.startIndex, to: pivot))
989-
expectEqualSequence([ 1, 4, 3, 2, 5 ], lc.map { extract($0).value })
990-
}
991-
9921001
//===----------------------------------------------------------------------===//
9931002

9941003
} // addMutableRandomAccessCollectionTests

stdlib/private/StdlibCollectionUnittest/LoggingWrappers.swift.gyb

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ public class MutableCollectionLog : CollectionLog {
140140
}
141141
public static var subscriptIndexSet = TypeIndexed(0)
142142
public static var subscriptRangeSet = TypeIndexed(0)
143+
public static var partitionBy = TypeIndexed(0)
143144
public static var _withUnsafeMutableBufferPointerIfSupported = TypeIndexed(0)
144145
public static var _withUnsafeMutableBufferPointerIfSupportedNonNilReturns =
145146
TypeIndexed(0)
@@ -440,6 +441,13 @@ public struct ${Self}<
440441
% end
441442

442443
% if Kind == 'MutableCollection':
444+
public mutating func partition(
445+
by belongsInSecondPartition: @noescape (Iterator.Element) throws -> Bool
446+
) rethrows -> Index {
447+
Log.partitionBy[selfType] += 1
448+
return try base.partition(by: belongsInSecondPartition)
449+
}
450+
443451
public mutating func _withUnsafeMutableBufferPointerIfSupported<R>(
444452
_ body: @noescape (UnsafeMutablePointer<Iterator.Element>, Int) throws -> R
445453
) rethrows -> R? {
@@ -565,6 +573,95 @@ public struct ${Self}<
565573
% end
566574
% end
567575

576+
//===----------------------------------------------------------------------===//
577+
// Collections that count calls to `_withUnsafeMutableBufferPointerIfSupported`
578+
//===----------------------------------------------------------------------===//
579+
580+
% for Traversal in TRAVERSALS:
581+
% Self = 'BufferAccessLoggingMutable' + collectionForTraversal(Traversal)
582+
/// Interposes between `_withUnsafeMutableBufferPointerIfSupported` method calls
583+
/// to increment a counter. Calls to this method from within dispatched methods
584+
/// are uncounted by the standard logging collection wrapper.
585+
public struct ${Self}<
586+
Base : MutableCollection & ${collectionForTraversal(Traversal)}
587+
> : MutableCollection, ${collectionForTraversal(Traversal)}, LoggingType {
588+
589+
public var base: Base
590+
591+
public typealias Log = MutableCollectionLog
592+
593+
public typealias SubSequence = Base.SubSequence
594+
595+
public typealias Iterator = Base.Iterator
596+
597+
public init(wrapping base: Base) {
598+
self.base = base
599+
}
600+
601+
public func makeIterator() -> Iterator {
602+
return base.makeIterator()
603+
}
604+
605+
public typealias Index = Base.Index
606+
607+
public var startIndex: Index {
608+
return base.startIndex
609+
}
610+
611+
public var endIndex: Index {
612+
return base.endIndex
613+
}
614+
615+
public subscript(position: Index) -> Base.Iterator.Element {
616+
get {
617+
return base[position]
618+
}
619+
set {
620+
base[position] = newValue
621+
}
622+
}
623+
624+
public subscript(bounds: Range<Index>) -> SubSequence {
625+
get {
626+
return base[bounds]
627+
}
628+
set {
629+
base[bounds] = newValue
630+
}
631+
}
632+
633+
public func index(after i: Index) -> Index {
634+
return base.index(after: i)
635+
}
636+
637+
% if Traversal in ['Bidirectional', 'RandomAccess']:
638+
public func index(before i: Index) -> Index {
639+
return base.index(before: i)
640+
}
641+
% end
642+
643+
public func index(_ i: Index, offsetBy n: Base.IndexDistance) -> Index {
644+
return base.index(i, offsetBy: n)
645+
}
646+
647+
public func distance(from start: Index, to end: Index) -> Base.IndexDistance {
648+
return base.distance(from: start, to: end)
649+
}
650+
651+
public mutating func _withUnsafeMutableBufferPointerIfSupported<R>(
652+
_ body: @noescape (UnsafeMutablePointer<Iterator.Element>, Int) throws -> R
653+
) rethrows -> R? {
654+
print("Log._withUnsafeMutableBufferPointerIfSupported[selfType] += 1")
655+
Log._withUnsafeMutableBufferPointerIfSupported[selfType] += 1
656+
let result = try base._withUnsafeMutableBufferPointerIfSupported(body)
657+
if result != nil {
658+
Log._withUnsafeMutableBufferPointerIfSupportedNonNilReturns[selfType] += 1
659+
}
660+
return result
661+
}
662+
}
663+
% end
664+
568665
//===----------------------------------------------------------------------===//
569666
// Custom assertions
570667
//===----------------------------------------------------------------------===//

validation-test/stdlib/CollectionType.swift.gyb

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -778,6 +778,16 @@ CollectionTypeTests.test("Collection/suffix(from:)/dispatch") {
778778
// MutableCollection
779779
//===----------------------------------------------------------------------===//
780780

781+
//===----------------------------------------------------------------------===//
782+
// partition(by:)
783+
//===----------------------------------------------------------------------===//
784+
785+
CollectionTypeTests.test("partition(by:)/dispatch") {
786+
var tester = MutableCollectionLog.dispatchTester([OpaqueValue(1)])
787+
_ = tester.partition(by: { _ in true })
788+
expectCustomizable(tester, tester.log.partitionBy)
789+
}
790+
781791
//===----------------------------------------------------------------------===//
782792
// _withUnsafeMutableBufferPointerIfSupported()
783793
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)