Skip to content

[stdlib] Optimize high-level Set operations #40012

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into the base branch from the contributor's branch
Nov 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions stdlib/public/core/Bitset.swift
Original file line number Diff line number Diff line change
Expand Up @@ -337,3 +337,51 @@ extension _UnsafeBitset.Word: Sequence, IteratorProtocol {
return bit
}
}

extension _UnsafeBitset {
  /// Runs `body` with a bitset of `wordCount` words whose storage lives in a
  /// temporary (typically stack-based) allocation. The words are left
  /// uninitialized; callers must set them up before reading any bits.
  @_alwaysEmitIntoClient
  @inline(__always)
  internal static func _withTemporaryUninitializedBitset<R>(
    wordCount: Int,
    body: (_UnsafeBitset) throws -> R
  ) rethrows -> R {
    return try withUnsafeTemporaryAllocation(
      of: _UnsafeBitset.Word.self, capacity: wordCount
    ) { storage in
      // NOTE(review): assumes `wordCount >= 1` so `baseAddress` is non-nil;
      // both callers below guarantee this for non-empty hash tables.
      let temporary = _UnsafeBitset(
        words: storage.baseAddress!, wordCount: storage.count)
      return try body(temporary)
    }
  }

  /// Runs `body` with an all-clear temporary bitset that can hold at least
  /// `capacity` bits. At least one word is allocated even for a zero capacity.
  @_alwaysEmitIntoClient
  @inline(__always)
  internal static func withTemporaryBitset<R>(
    capacity: Int,
    body: (_UnsafeBitset) throws -> R
  ) rethrows -> R {
    let neededWords = Swift.max(1, Self.wordCount(forCapacity: capacity))
    return try _withTemporaryUninitializedBitset(
      wordCount: neededWords
    ) { temporary in
      // Zero the uninitialized words before handing the bitset to the caller.
      temporary.clear()
      return try body(temporary)
    }
  }
}

extension _UnsafeBitset {
  /// Runs `body` with a temporary bitset containing a word-for-word copy of
  /// `original`. Mutations of the copy do not affect `original`.
  @_alwaysEmitIntoClient
  @inline(__always)
  internal static func withTemporaryCopy<R>(
    of original: _UnsafeBitset,
    body: (_UnsafeBitset) throws -> R
  ) rethrows -> R {
    return try _withTemporaryUninitializedBitset(
      wordCount: original.wordCount
    ) { copy in
      // Initialize the fresh storage directly from the original's words.
      copy.words.initialize(from: original.words, count: original.wordCount)
      return try body(copy)
    }
  }
}
13 changes: 12 additions & 1 deletion stdlib/public/core/HashTable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ internal struct _HashTable {
/// The number of buckets in this hash table.
///
/// Fix: the page scrape retained both the pre-diff statement
/// (`return bucketMask &+ 1`) and the post-diff statement; only the
/// post-change line belongs in the file.
@inlinable
internal var bucketCount: Int {
  @inline(__always) get {
    // Tell the optimizer the result is non-negative so downstream index
    // arithmetic can skip sign checks.
    // NOTE(review): assumes `bucketMask &+ 1` cannot wrap to a negative
    // value -- presumably guaranteed by the table's sizing; confirm.
    return _assumeNonNegative(bucketMask &+ 1)
  }
}

Expand All @@ -52,6 +52,17 @@ internal struct _HashTable {
return _UnsafeBitset.wordCount(forCapacity: bucketCount)
}
}

/// A bitset view over the occupied buckets in this table, covering
/// `wordCount` words starting at `words`.
///
/// Note that if the table's bitmap occupies only a single partial word, then
/// the out-of-bounds bits of that word are guaranteed to be all set. These
/// filler bits exist to speed up finding holes -- they do not correspond to
/// occupied buckets in the table.
@_alwaysEmitIntoClient
internal var bitset: _UnsafeBitset {
  get { return _UnsafeBitset(words: words, wordCount: wordCount) }
}
}

extension _HashTable {
Expand Down
200 changes: 200 additions & 0 deletions stdlib/public/core/NativeSet.swift
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ extension _NativeSet { // Primitive fields
}
}

/// The number of buckets in the underlying storage, presented to the
/// optimizer as a known-non-negative value.
@_alwaysEmitIntoClient
@inline(__always)
internal var bucketCount: Int {
  get { return _assumeNonNegative(_storage._bucketCount) }
}

@inlinable
internal var hashTable: _HashTable {
@inline(__always) get {
Expand Down Expand Up @@ -580,3 +586,197 @@ extension _NativeSet.Iterator: IteratorProtocol {
return base.uncheckedElement(at: index)
}
}

extension _NativeSet {
/// Returns true if every member of `self` also occurs in `possibleSuperset`.
/// A temporary bitset (one bit per bucket) tracks which of our members have
/// already been matched, so each element is hashed at most once.
@_alwaysEmitIntoClient
internal func isSubset<S: Sequence>(of possibleSuperset: S) -> Bool
where S.Element == Element {
  _UnsafeBitset.withTemporaryBitset(capacity: self.bucketCount) { seen in
    // Tick off each of our members as it turns up in `possibleSuperset`;
    // succeed as soon as all of them have been seen.
    var remaining = self.count
    for candidate in possibleSuperset {
      let (bucket, isMember) = find(candidate)
      // `uncheckedInsert` returns false for an already-marked bucket, so
      // duplicates in the sequence are only counted once.
      guard isMember, seen.uncheckedInsert(bucket.offset) else { continue }
      remaining -= 1
      if remaining == 0 { return true }
    }
    return false
  }
}

/// Returns true iff every member of `self` occurs in `possibleSuperset` AND
/// `possibleSuperset` contains at least one element that is not in `self`.
///
/// A temporary bitset (one bit per bucket) records which of our members have
/// already been matched, so each element is hashed at most once.
@_alwaysEmitIntoClient
internal func isStrictSubset<S: Sequence>(of possibleSuperset: S) -> Bool
where S.Element == Element {
_UnsafeBitset.withTemporaryBitset(capacity: self.bucketCount) { seen in
// Mark elements in self that we've seen in `possibleSuperset`.
var seenCount = 0
// `isStrict` records that the sequence produced an element outside `self`.
var isStrict = false
for element in possibleSuperset {
let (bucket, found) = find(element)
guard found else {
// An element not in `self` makes the relationship strict. If all of our
// members have already been matched too, both conditions hold: stop early.
if !isStrict {
isStrict = true
if seenCount == self.count { return true }
}
continue
}
// `uncheckedInsert` returns false for an already-marked bucket, so
// `seenCount` counts distinct matched members only.
let inserted = seen.uncheckedInsert(bucket.offset)
if inserted {
seenCount += 1
if seenCount == self.count, isStrict {
return true
}
}
}
// Either some member of `self` never appeared, or the sequence was exhausted
// without an element outside `self` (i.e., the two sets are equal).
return false
}
}

/// Returns true iff every element of `possibleSubset` is a member of `self`
/// and the distinct members it covers do not add up to all of `self`.
@_alwaysEmitIntoClient
internal func isStrictSuperset<S: Sequence>(of possibleSubset: S) -> Bool
where S.Element == Element {
  _UnsafeBitset.withTemporaryBitset(capacity: self.bucketCount) { seen in
    // Mark each of our buckets that `possibleSubset` covers; fail fast on
    // any element we don't contain.
    var covered = 0
    for candidate in possibleSubset {
      let (bucket, isMember) = find(candidate)
      guard isMember else { return false }
      // A false return from `uncheckedInsert` means a duplicate; skip it so
      // `covered` counts distinct members only.
      guard seen.uncheckedInsert(bucket.offset) else { continue }
      covered += 1
      // Covering every member of `self` means the sets are equal -- not strict.
      if covered == self.count { return false }
    }
    return true
  }
}

/// Builds a new set holding the elements of `self` whose bucket offsets are
/// set in `bitset`. `count` must equal the number of such elements; it is
/// used to size the result exactly and to cut iteration short.
@_alwaysEmitIntoClient
internal __consuming func extractSubset(
using bitset: _UnsafeBitset,
count: Int
) -> _NativeSet {
  guard count > 0 else { return _NativeSet() }
  // Extracting every element is just `self`; no copying needed.
  if count == self.count { return self }
  var remaining = count
  let result = _NativeSet(capacity: count)
  for offset in bitset {
    result._unsafeInsertNew(self.uncheckedElement(at: Bucket(offset: offset)))
    remaining -= 1
    // Stop after `count` elements: the hash table's bitmap can have filler
    // bits set past its end, and those must not be treated as occupied
    // buckets.
    if remaining == 0 { break }
  }
  return result
}

/// Returns a new set containing the members of `self` that do not occur in
/// the sequence `other`. Hashes each element (of either side) at most once.
@_alwaysEmitIntoClient
internal __consuming func subtracting<S: Sequence>(_ other: S) -> _NativeSet
where S.Element == Element {
// Subtracting anything from an empty set leaves it empty.
guard count > 0 else { return _NativeSet() }

// Find one item that we need to remove before creating a result set.
// If `other` shares no members with `self`, we return `self` unchanged
// without allocating anything.
var it = other.makeIterator()
var bucket: Bucket? = nil
while let next = it.next() {
let (b, found) = find(next)
if found {
bucket = b
break
}
}
guard let bucket = bucket else { return self }

// Rather than directly creating a new set, calculate the difference in a
// bitset first. This ensures we hash each element (in both sets) only once,
// and that we'll have an exact count for the result set, preventing
// rehashings during insertions.
// Start from a copy of the table's occupancy bitset (every current member
// marked) and clear the bit of each member that `other` produces.
return _UnsafeBitset.withTemporaryCopy(of: hashTable.bitset) { difference in
var remainingCount = self.count

// Remove the first match discovered above. `it` resumes right after that
// element, so no element of `other` is hashed twice.
let removed = difference.uncheckedRemove(bucket.offset)
_internalInvariant(removed)
remainingCount -= 1

while let element = it.next() {
let (bucket, found) = find(element)
if found {
// `uncheckedRemove` reports whether the bit was still set, so duplicates
// in `other` only decrement the count once.
if difference.uncheckedRemove(bucket.offset) {
remainingCount -= 1
// Every member removed: the difference is empty; stop early.
if remainingCount == 0 { return _NativeSet() }
}
}
}
_internalInvariant(difference.count > 0)
return extractSubset(using: difference, count: remainingCount)
}
}

/// Returns a new set with the members of `self` that satisfy `isIncluded`,
/// rethrowing any error the predicate throws. Matching buckets are recorded
/// in a temporary bitset first, so the result can be allocated with an exact
/// capacity and filled without rehashing.
@_alwaysEmitIntoClient
internal __consuming func filter(
_ isIncluded: (Element) throws -> Bool
) rethrows -> _NativeSet<Element> {
  return try _UnsafeBitset.withTemporaryBitset(capacity: bucketCount) { kept in
    var keptCount = 0
    for bucket in hashTable {
      guard try isIncluded(uncheckedElement(at: bucket)) else { continue }
      kept.uncheckedInsert(bucket.offset)
      keptCount += 1
    }
    return extractSubset(using: kept, count: keptCount)
  }
}

/// Returns the set of elements present in both `self` and `other`.
///
/// Iterates over whichever set is smaller, but the result always consists of
/// the element instances stored in `self`, never the equal ones in `other`.
@_alwaysEmitIntoClient
internal __consuming func intersection(
_ other: _NativeSet<Element>
) -> _NativeSet<Element> {
  // If `other` is the smaller set, use the generic path, which walks `other`
  // while still keeping only elements from `self`.
  guard self.count <= other.count else {
    return genericIntersection(other)
  }
  // Mark the common elements in a bitset before building the result: each
  // element is hashed only once, and the exact count lets us size the result
  // so insertions never rehash.
  return _UnsafeBitset.withTemporaryBitset(capacity: bucketCount) { common in
    var commonCount = 0
    for bucket in hashTable {
      guard other.find(uncheckedElement(at: bucket)).found else { continue }
      common.uncheckedInsert(bucket.offset)
      commonCount += 1
    }
    return extractSubset(using: common, count: commonCount)
  }
}

/// Returns the set of elements of `self` that also occur in the sequence
/// `other`.
@_alwaysEmitIntoClient
internal __consuming func genericIntersection<S: Sequence>(
_ other: S
) -> _NativeSet<Element>
where S.Element == Element {
  // Rather than directly creating a new set, mark common elements in a bitset
  // first. This minimizes hashing, and ensures that we'll have an exact count
  // for the result set, preventing rehashings during insertions.
  _UnsafeBitset.withTemporaryBitset(capacity: bucketCount) { bitset in
    var count = 0
    for element in other {
      let (bucket, found) = find(element)
      // Fix: count a bucket only when `uncheckedInsert` reports a fresh
      // insertion. `other` is an arbitrary sequence and may contain
      // duplicates; counting repeats would inflate `count`, over-allocating
      // the result and -- when the inflated count happens to equal
      // `self.count` -- making `extractSubset`'s fast path wrongly return
      // `self` whole. This matches how the sibling methods (`isSubset`,
      // `subtracting`) gate their counters on the bitset operation's result.
      if found, bitset.uncheckedInsert(bucket.offset) {
        count += 1
      }
    }
    return extractSubset(using: bitset, count: count)
  }
}
}
Loading