
Commit 6d33683

Merge pull request #40012 from lorentey/set-on-fire2
[stdlib] Optimize high-level Set operations
2 parents bd67413 + 172b1b8 commit 6d33683

6 files changed: +923, -50 lines


stdlib/public/core/Bitset.swift

Lines changed: 48 additions & 0 deletions
@@ -337,3 +337,51 @@ extension _UnsafeBitset.Word: Sequence, IteratorProtocol {
     return bit
   }
 }
+
+extension _UnsafeBitset {
+  @_alwaysEmitIntoClient
+  @inline(__always)
+  internal static func _withTemporaryUninitializedBitset<R>(
+    wordCount: Int,
+    body: (_UnsafeBitset) throws -> R
+  ) rethrows -> R {
+    try withUnsafeTemporaryAllocation(
+      of: _UnsafeBitset.Word.self, capacity: wordCount
+    ) { buffer in
+      let bitset = _UnsafeBitset(
+        words: buffer.baseAddress!, wordCount: buffer.count)
+      return try body(bitset)
+    }
+  }
+
+  @_alwaysEmitIntoClient
+  @inline(__always)
+  internal static func withTemporaryBitset<R>(
+    capacity: Int,
+    body: (_UnsafeBitset) throws -> R
+  ) rethrows -> R {
+    let wordCount = Swift.max(1, Self.wordCount(forCapacity: capacity))
+    return try _withTemporaryUninitializedBitset(
+      wordCount: wordCount
+    ) { bitset in
+      bitset.clear()
+      return try body(bitset)
+    }
+  }
+}
+
+extension _UnsafeBitset {
+  @_alwaysEmitIntoClient
+  @inline(__always)
+  internal static func withTemporaryCopy<R>(
+    of original: _UnsafeBitset,
+    body: (_UnsafeBitset) throws -> R
+  ) rethrows -> R {
+    try _withTemporaryUninitializedBitset(
+      wordCount: original.wordCount
+    ) { bitset in
+      bitset.words.initialize(from: original.words, count: original.wordCount)
+      return try body(bitset)
+    }
+  }
+}
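
For context: `withTemporaryBitset` and `withTemporaryCopy` above are thin wrappers over the public `withUnsafeTemporaryAllocation(of:capacity:_:)` function, so callers get a scoped, typically stack-allocated scratch bitset, usually avoiding a heap allocation. The following is a minimal standalone sketch of the same pattern using only public API and plain `UInt` words; `countEvens` is a hypothetical example, not stdlib code.

func countEvens(upTo limit: Int) -> Int {
  let bitsPerWord = UInt.bitWidth
  let wordCount = Swift.max(1, (limit + bitsPerWord - 1) / bitsPerWord)
  return withUnsafeTemporaryAllocation(of: UInt.self, capacity: wordCount) { words -> Int in
    // The buffer starts out uninitialized; zero it before use, just as
    // `withTemporaryBitset` calls `bitset.clear()` before running `body`.
    words.initialize(repeating: 0)
    for value in stride(from: 0, to: limit, by: 2) {
      words[value / bitsPerWord] |= 1 << (value % bitsPerWord)
    }
    return words.reduce(0) { $0 + $1.nonzeroBitCount }
  }
}

print(countEvens(upTo: 100))  // 50

The split between the uninitialized and cleared variants mirrors the two entry points above: `withTemporaryBitset` zeroes the words before calling `body`, while `withTemporaryCopy` skips the clear and initializes the words from an existing bitset instead.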

stdlib/public/core/HashTable.swift

Lines changed: 12 additions & 1 deletion
@@ -42,7 +42,7 @@ internal struct _HashTable {
   @inlinable
   internal var bucketCount: Int {
     @inline(__always) get {
-      return bucketMask &+ 1
+      return _assumeNonNegative(bucketMask &+ 1)
     }
   }
 
@@ -52,6 +52,17 @@ internal struct _HashTable {
       return _UnsafeBitset.wordCount(forCapacity: bucketCount)
     }
   }
+
+  /// Return a bitset representation of the occupied buckets in this table.
+  ///
+  /// Note that if we have only a single partial word in the hash table's
+  /// bitset, then its out-of-bounds bits are guaranteed to be all set. These
+  /// filler bits are there to speed up finding holes -- they don't correspond
+  /// to occupied buckets in the table.
+  @_alwaysEmitIntoClient
+  internal var bitset: _UnsafeBitset {
+    _UnsafeBitset(words: words, wordCount: wordCount)
+  }
 }
 
 extension _HashTable {
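
To illustrate the filler-bit remark in the doc comment above (a standalone sketch, not stdlib code): when the whole occupancy bitmap fits in one word, keeping the out-of-bounds bits set lets a search for a free bucket reduce to "find the lowest zero bit", with no separate bounds check.

// Illustrative sketch only: a 16-bucket table whose occupancy bitmap fits in
// a single 64-bit word.
let bucketCount = 16
var occupied: UInt64 = 0

// Mark every bucket occupied except bucket 5.
for bucket in 0..<bucketCount where bucket != 5 {
  occupied |= 1 << bucket
}

// Keep the out-of-bounds bits (16..<64) permanently set: the "filler bits"
// described above.
occupied |= ~0 << bucketCount

// A hole search is now just "lowest zero bit"; the filler bits guarantee the
// scan never lands past the end of the table.
let firstHole = (~occupied).trailingZeroBitCount
print(firstHole)  // 5

This is also why `extractSubset` in NativeSet.swift below counts down and stops after `count` elements rather than trusting every set bit it sees in a copied bitset.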

stdlib/public/core/NativeSet.swift

Lines changed: 200 additions & 0 deletions
@@ -77,6 +77,12 @@ extension _NativeSet { // Primitive fields
     }
   }
 
+  @_alwaysEmitIntoClient
+  @inline(__always)
+  internal var bucketCount: Int {
+    _assumeNonNegative(_storage._bucketCount)
+  }
+
   @inlinable
   internal var hashTable: _HashTable {
     @inline(__always) get {
@@ -580,3 +586,197 @@ extension _NativeSet.Iterator: IteratorProtocol {
     return base.uncheckedElement(at: index)
   }
 }
+
+extension _NativeSet {
+  @_alwaysEmitIntoClient
+  internal func isSubset<S: Sequence>(of possibleSuperset: S) -> Bool
+  where S.Element == Element {
+    _UnsafeBitset.withTemporaryBitset(capacity: self.bucketCount) { seen in
+      // Mark elements in self that we've seen in `possibleSuperset`.
+      var seenCount = 0
+      for element in possibleSuperset {
+        let (bucket, found) = find(element)
+        guard found else { continue }
+        let inserted = seen.uncheckedInsert(bucket.offset)
+        if inserted {
+          seenCount += 1
+          if seenCount == self.count {
+            return true
+          }
+        }
+      }
+      return false
+    }
+  }
+
+  @_alwaysEmitIntoClient
+  internal func isStrictSubset<S: Sequence>(of possibleSuperset: S) -> Bool
+  where S.Element == Element {
+    _UnsafeBitset.withTemporaryBitset(capacity: self.bucketCount) { seen in
+      // Mark elements in self that we've seen in `possibleSuperset`.
+      var seenCount = 0
+      var isStrict = false
+      for element in possibleSuperset {
+        let (bucket, found) = find(element)
+        guard found else {
+          if !isStrict {
+            isStrict = true
+            if seenCount == self.count { return true }
+          }
+          continue
+        }
+        let inserted = seen.uncheckedInsert(bucket.offset)
+        if inserted {
+          seenCount += 1
+          if seenCount == self.count, isStrict {
+            return true
+          }
+        }
+      }
+      return false
+    }
+  }
+
+  @_alwaysEmitIntoClient
+  internal func isStrictSuperset<S: Sequence>(of possibleSubset: S) -> Bool
+  where S.Element == Element {
+    _UnsafeBitset.withTemporaryBitset(capacity: self.bucketCount) { seen in
+      // Mark elements in self that we've seen in `possibleStrictSubset`.
+      var seenCount = 0
+      for element in possibleSubset {
+        let (bucket, found) = find(element)
+        guard found else { return false }
+        let inserted = seen.uncheckedInsert(bucket.offset)
+        if inserted {
+          seenCount += 1
+          if seenCount == self.count {
+            return false
+          }
+        }
+      }
+      return true
+    }
+  }
+
+  @_alwaysEmitIntoClient
+  internal __consuming func extractSubset(
+    using bitset: _UnsafeBitset,
+    count: Int
+  ) -> _NativeSet {
+    var count = count
+    if count == 0 { return _NativeSet() }
+    if count == self.count { return self }
+    let result = _NativeSet(capacity: count)
+    for offset in bitset {
+      result._unsafeInsertNew(self.uncheckedElement(at: Bucket(offset: offset)))
+      // The hash table can have set bits after the end of the bitmap.
+      // Ignore them.
+      count -= 1
+      if count == 0 { break }
+    }
+    return result
+  }
+
+  @_alwaysEmitIntoClient
+  internal __consuming func subtracting<S: Sequence>(_ other: S) -> _NativeSet
+  where S.Element == Element {
+    guard count > 0 else { return _NativeSet() }
+
+    // Find one item that we need to remove before creating a result set.
+    var it = other.makeIterator()
+    var bucket: Bucket? = nil
+    while let next = it.next() {
+      let (b, found) = find(next)
+      if found {
+        bucket = b
+        break
+      }
+    }
+    guard let bucket = bucket else { return self }
+
+    // Rather than directly creating a new set, calculate the difference in a
+    // bitset first. This ensures we hash each element (in both sets) only once,
+    // and that we'll have an exact count for the result set, preventing
+    // rehashings during insertions.
+    return _UnsafeBitset.withTemporaryCopy(of: hashTable.bitset) { difference in
+      var remainingCount = self.count
+
+      let removed = difference.uncheckedRemove(bucket.offset)
+      _internalInvariant(removed)
+      remainingCount -= 1
+
+      while let element = it.next() {
+        let (bucket, found) = find(element)
+        if found {
+          if difference.uncheckedRemove(bucket.offset) {
+            remainingCount -= 1
+            if remainingCount == 0 { return _NativeSet() }
+          }
+        }
+      }
+      _internalInvariant(difference.count > 0)
+      return extractSubset(using: difference, count: remainingCount)
+    }
+  }
+
+  @_alwaysEmitIntoClient
+  internal __consuming func filter(
+    _ isIncluded: (Element) throws -> Bool
+  ) rethrows -> _NativeSet<Element> {
+    try _UnsafeBitset.withTemporaryBitset(capacity: bucketCount) { bitset in
+      var count = 0
+      for bucket in hashTable {
+        if try isIncluded(uncheckedElement(at: bucket)) {
+          bitset.uncheckedInsert(bucket.offset)
+          count += 1
+        }
+      }
+      return extractSubset(using: bitset, count: count)
+    }
+  }
+
+  @_alwaysEmitIntoClient
+  internal __consuming func intersection(
+    _ other: _NativeSet<Element>
+  ) -> _NativeSet<Element> {
+    // Prefer to iterate over the smaller set. However, we must be careful to
+    // only include elements from `self`, not `other`.
+    guard self.count <= other.count else {
+      return genericIntersection(other)
+    }
+    // Rather than directly creating a new set, mark common elements in a bitset
+    // first. This minimizes hashing, and ensures that we'll have an exact count
+    // for the result set, preventing rehashings during insertions.
+    return _UnsafeBitset.withTemporaryBitset(capacity: bucketCount) { bitset in
+      var count = 0
+      for bucket in hashTable {
+        if other.find(uncheckedElement(at: bucket)).found {
+          bitset.uncheckedInsert(bucket.offset)
+          count += 1
+        }
+      }
+      return extractSubset(using: bitset, count: count)
+    }
+  }
+
+  @_alwaysEmitIntoClient
+  internal __consuming func genericIntersection<S: Sequence>(
+    _ other: S
+  ) -> _NativeSet<Element>
+  where S.Element == Element {
+    // Rather than directly creating a new set, mark common elements in a bitset
+    // first. This minimizes hashing, and ensures that we'll have an exact count
+    // for the result set, preventing rehashings during insertions.
+    _UnsafeBitset.withTemporaryBitset(capacity: bucketCount) { bitset in
+      var count = 0
+      for element in other {
+        let (bucket, found) = find(element)
+        if found {
+          bitset.uncheckedInsert(bucket.offset)
+          count += 1
+        }
+      }
+      return extractSubset(using: bitset, count: count)
+    }
+  }
+}
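
For context, these `_NativeSet` entry points presumably back the corresponding public `Set` operations (the `Set.swift` changes from this commit are not shown in this excerpt), so observable behavior stays the same; what changes is the amount of hashing and intermediate allocation. A small usage example of the affected public API:

let primes: Set = [2, 3, 5, 7, 11]
let odds: Set = [1, 3, 5, 7, 9, 11]

// Subset and superset queries (see isSubset / isStrictSuperset above).
print(primes.isSubset(of: odds))            // false: 2 is not in `odds`
print(primes.isStrictSuperset(of: [3, 5]))  // true

// Difference, intersection, and filter; per the comments above, the results
// are built with an exact count, so no rehashing happens while they fill up.
print(primes.subtracting(odds).sorted())    // [2]
print(primes.intersection(odds).sorted())   // [3, 5, 7, 11]
print(primes.filter { $0 > 5 }.sorted())    // [7, 11]

In every case the pattern is the same: mark matching buckets in a temporary bitset, count them as you go, and only then build the result set with an exact capacity.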
