Skip to content

fix thread safety issue in package collections #3136

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions Sources/Basics/ConcurrencyHelpers.swift
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,64 @@ public final class ThreadSafeKeyValueStore<Key, Value> where Key: Hashable {
}
}

/// Thread-safe array-like structure.
///
/// Every operation acquires the internal lock for the duration of that single call.
/// Individual calls are therefore atomic, but a sequence of calls (for example
/// `append(_:)` followed by `count`) is not atomic as a whole.
public final class ThreadSafeArrayStore<Value> {
    private var underlying: [Value]
    private let lock = Lock()

    /// Creates a store initially holding the elements of `seed`.
    ///
    /// - Parameter seed: The initial contents; defaults to an empty array.
    public init(_ seed: [Value] = []) {
        self.underlying = seed
    }

    /// Returns the element at `index`, or `nil` when `index` is outside the
    /// bounds of the stored array.
    ///
    /// The bounds check and the read happen under the same lock acquisition, so a
    /// concurrent `clear()` cannot invalidate the index between the two.
    public subscript(index: Int) -> Value? {
        self.lock.withLock { () -> Value? in
            guard self.underlying.indices.contains(index) else {
                return nil
            }
            return self.underlying[index]
        }
    }

    /// Returns a snapshot copy of the stored elements.
    public func get() -> [Value] {
        self.lock.withLock {
            self.underlying
        }
    }

    /// Removes all stored elements.
    public func clear() {
        self.lock.withLock {
            self.underlying = []
        }
    }

    /// Appends `item` to the end of the stored elements.
    public func append(_ item: Value) {
        self.lock.withLock {
            self.underlying.append(item)
        }
    }

    /// The number of stored elements at the time of the call.
    public var count: Int {
        self.lock.withLock {
            self.underlying.count
        }
    }

    /// Whether the store held no elements at the time of the call.
    public var isEmpty: Bool {
        self.lock.withLock {
            self.underlying.isEmpty
        }
    }

    /// Returns the results of applying `transform` to every stored element.
    ///
    /// `transform` is invoked while the lock is held.
    /// NOTE(review): calling back into this store from `transform` may deadlock
    /// if `Lock` is not reentrant - confirm before doing so.
    ///
    /// - Throws: Rethrows any error thrown by `transform`.
    public func map<NewValue>(_ transform: (Value) throws -> NewValue) rethrows -> [NewValue] {
        try self.lock.withLock {
            try self.underlying.map(transform)
        }
    }

    /// Returns the non-`nil` results of applying `transform` to every stored element.
    ///
    /// `transform` is invoked while the lock is held.
    /// NOTE(review): calling back into this store from `transform` may deadlock
    /// if `Lock` is not reentrant - confirm before doing so.
    ///
    /// - Throws: Rethrows any error thrown by `transform`.
    public func compactMap<NewValue>(_ transform: (Value) throws -> NewValue?) rethrows -> [NewValue] {
        try self.lock.withLock {
            try self.underlying.compactMap(transform)
        }
    }
}

/// Thread-safe value boxing structure
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// Thread-safe value boxing structure
/// Thread-safe value boxing structure

public final class ThreadSafeBox<Value> {
private var underlying: Value?
Expand Down
28 changes: 14 additions & 14 deletions Sources/PackageCollections/JSONModel/JSONCollection+v1.swift
Original file line number Diff line number Diff line change
Expand Up @@ -248,14 +248,14 @@ extension JSONPackageCollectionModel.V1 {
extension JSONPackageCollectionModel.V1 {
public struct Validator {
public let configuration: Configuration

/// Creates a validator that uses the given configuration.
///
/// - Parameter configuration: The configuration to store; defaults to `Configuration()`.
public init(configuration: Configuration = .init()) {
self.configuration = configuration
}

public func validate(collection: Collection) -> [ValidationMessage]? {
var messages = [ValidationMessage]()

let packages = collection.packages
// Stop validating if collection doesn't pass basic checks
if packages.isEmpty {
Expand All @@ -265,24 +265,24 @@ extension JSONPackageCollectionModel.V1 {
} else {
packages.forEach { self.validate(package: $0, messages: &messages) }
}

guard messages.isEmpty else {
return messages
}

return nil
}

// TODO: validate package url?
private func validate(package: Collection.Package, messages: inout [ValidationMessage]) {
let packageID = PackageIdentity(url: package.url.absoluteString).description

// Check for duplicate versions
let nonUniqueVersions = Dictionary(grouping: package.versions, by: { $0.version }).filter { $1.count > 1 }.keys
if !nonUniqueVersions.isEmpty {
messages.append(.error("Duplicate version(s) found in package \(packageID): \(nonUniqueVersions).", property: "package.versions"))
}

var nonSemanticVersions = [String]()
let semanticVersions: [TSCUtility.Version] = package.versions.compactMap {
let semver = TSCUtility.Version(string: $0.version)
Expand All @@ -291,15 +291,15 @@ extension JSONPackageCollectionModel.V1 {
}
return semver
}

guard nonSemanticVersions.isEmpty else {
messages.append(.error("Non semantic version(s) found in package \(packageID): \(nonSemanticVersions).", property: "package.versions"))
// The next part of validation requires sorting the semvers. Cannot continue if non-semver.
return
}

let sortedVersions = semanticVersions.sorted(by: >)

var currentMajor: Int?
var majorCount = 0
var minorCount = 0
Expand All @@ -322,7 +322,7 @@ extension JSONPackageCollectionModel.V1 {

minorCount += 1
}

package.versions.forEach { version in
if version.products.isEmpty {
messages.append(.error("Package \(packageID) version \(version.version) does not contain any products.", property: "version.products"))
Expand All @@ -332,13 +332,13 @@ extension JSONPackageCollectionModel.V1 {
messages.append(.error("Product \(product.name) of package \(packageID) version \(version.version) does not contain any targets.", property: "product.targets"))
}
}

if version.targets.isEmpty {
messages.append(.error("Package \(packageID) version \(version.version) does not contain any targets.", property: "version.targets"))
}
}
}

public struct Configuration {
public var maximumPackageCount: Int
public var maximumMajorVersionCount: Int
Expand Down
14 changes: 7 additions & 7 deletions Sources/PackageCollections/PackageCollections+Validation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -40,26 +40,26 @@ public struct ValidationMessage: Equatable, CustomStringConvertible {
public let message: String
public let level: Level
public let property: String?

/// Memberwise initializer, kept private so that instances are created through
/// the `error(_:property:)` and `warning(_:property:)` factories.
///
/// - Parameters:
///   - message: The validation message text.
///   - level: The level of the message.
///   - property: The name of the property the message refers to, if any.
private init(_ message: String, level: Level, property: String? = nil) {
    self.property = property
    self.level = level
    self.message = message
}

/// Creates a validation message with level `.error`.
///
/// - Parameters:
///   - message: The validation message text.
///   - property: The name of the property the message refers to, if any.
static func error(_ message: String, property: String? = nil) -> ValidationMessage {
    return ValidationMessage(message, level: .error, property: property)
}

/// Creates a validation message with level `.warning`.
///
/// - Parameters:
///   - message: The validation message text.
///   - property: The name of the property the message refers to, if any.
static func warning(_ message: String, property: String? = nil) -> ValidationMessage {
    return ValidationMessage(message, level: .warning, property: property)
}

/// The level of a validation message.
public enum Level: String, Equatable {
    case warning, error
}

/// Textual form `[level] property: message`; the `property: ` part is left out
/// when `property` is `nil`.
public var description: String {
    let propertyPrefix: String
    if let name = self.property {
        propertyPrefix = "\(name): "
    } else {
        propertyPrefix = ""
    }
    return "[\(self.level)] \(propertyPrefix)\(self.message)"
}
Expand All @@ -68,9 +68,9 @@ public struct ValidationMessage: Equatable, CustomStringConvertible {
extension Array where Element == ValidationMessage {
func errors(include levels: Set<ValidationMessage.Level> = [.error]) -> [ValidationError]? {
let errors = self.filter { levels.contains($0.level) }

guard !errors.isEmpty else { return nil }

return errors.map {
if let property = $0.property {
return ValidationError.property(name: property, message: $0.message)
Expand Down
8 changes: 4 additions & 4 deletions Sources/PackageCollections/PackageCollections.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
See http://swift.org/CONTRIBUTORS.txt for Swift project authors
*/

import Basics
import PackageModel
import TSCBasic

Expand Down Expand Up @@ -100,12 +101,11 @@ public struct PackageCollections: PackageCollectionsProtocol {
if sources.isEmpty {
return callback(.success([]))
}
let lock = Lock()
var refreshResults = [Result<Model.Collection, Error>]()
let refreshResults = ThreadSafeArrayStore<Result<Model.Collection, Error>>()
sources.forEach { source in
self.refreshCollectionFromSource(source: source) { refreshResult in
lock.withLock { refreshResults.append(refreshResult) }
if refreshResults.count == (lock.withLock { sources.count }) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was the issue that prompted the PR. I saw a few PRs failing on PackageCollectionsTests.testHappyRefresh inconsistently, so I was hunting for a race condition (interestingly, TSAN did not report this). The issue here is that the lock is on sources instead of refreshResults. So, in addition to fixing this, I decided to create this new utility for arrays so we don't have to mess with locks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fixed this coding mistake in #3108, though I didn't anticipate it causing a race condition. Anyway, 👍

refreshResults.append(refreshResult)
if refreshResults.count == sources.count {
let errors = refreshResults.compactMap { $0.failure }
callback(errors.isEmpty ? .success(sources) : .failure(MultipleErrors(errors)))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ struct GitHubPackageMetadataProvider: PackageMetadataProvider {
let readmeURL = baseURL.appendingPathComponent("readme")

let sync = DispatchGroup()
var results = [URL: Result<HTTPClientResponse, Error>]()
let resultsLock = Lock()
let results = ThreadSafeKeyValueStore<URL, Result<HTTPClientResponse, Error>>()

// get the main data
sync.enter()
Expand All @@ -58,34 +57,22 @@ struct GitHubPackageMetadataProvider: PackageMetadataProvider {
let metadataOptions = self.makeRequestOptions(validResponseCodes: [200, 401, 403, 404])
httpClient.get(metadataURL, headers: metadataHeaders, options: metadataOptions) { result in
defer { sync.leave() }
resultsLock.withLock {
results[metadataURL] = result
}
results[metadataURL] = result
if case .success(let response) = result {
let apiLimit = response.headers.get("X-RateLimit-Limit").first.flatMap(Int.init) ?? -1
let apiRemaining = response.headers.get("X-RateLimit-Remaining").first.flatMap(Int.init) ?? -1
switch (response.statusCode, metadataHeaders.contains("Authorization"), apiRemaining) {
case (_, _, 0):
self.diagnosticsEngine?.emit(warning: "Exceeded API limits on \(metadataURL.host ?? metadataURL.absoluteString) (\(apiRemaining)/\(apiLimit)), consider configuring an API token for this service.")
resultsLock.withLock {
results[metadataURL] = .failure(Errors.apiLimitsExceeded(metadataURL, apiLimit))
}
results[metadataURL] = .failure(Errors.apiLimitsExceeded(metadataURL, apiLimit))
case (401, true, _):
resultsLock.withLock {
results[metadataURL] = .failure(Errors.invalidAuthToken(metadataURL))
}
results[metadataURL] = .failure(Errors.invalidAuthToken(metadataURL))
case (401, false, _):
resultsLock.withLock {
results[metadataURL] = .failure(Errors.permissionDenied(metadataURL))
}
results[metadataURL] = .failure(Errors.permissionDenied(metadataURL))
case (403, _, _):
resultsLock.withLock {
results[metadataURL] = .failure(Errors.permissionDenied(metadataURL))
}
results[metadataURL] = .failure(Errors.permissionDenied(metadataURL))
case (404, _, _):
resultsLock.withLock {
results[metadataURL] = .failure(NotFoundError("\(baseURL)"))
}
results[metadataURL] = .failure(NotFoundError("\(baseURL)"))
case (200, _, _):
if apiRemaining < self.configuration.apiLimitWarningThreshold {
self.diagnosticsEngine?.emit(warning: "Approaching API limits on \(metadataURL.host ?? metadataURL.absoluteString) (\(apiRemaining)/\(apiLimit)), consider configuring an API token for this service.")
Expand All @@ -98,15 +85,11 @@ struct GitHubPackageMetadataProvider: PackageMetadataProvider {
let options = self.makeRequestOptions(validResponseCodes: [200])
self.httpClient.get(url, headers: headers, options: options) { result in
defer { sync.leave() }
resultsLock.withLock {
results[url] = result
}
results[url] = result
}
}
default:
resultsLock.withLock {
results[metadataURL] = .failure(Errors.invalidResponse(metadataURL, "Invalid status code: \(response.statusCode)"))
}
results[metadataURL] = .failure(Errors.invalidResponse(metadataURL, "Invalid status code: \(response.statusCode)"))
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ final class SQLitePackageCollectionsStorage: PackageCollectionsStorage, Closable
private var state = State.idle
private let stateLock = Lock()

private var cache = [Model.CollectionIdentifier: Model.Collection]()
private let cacheLock = Lock()
private let cache = ThreadSafeKeyValueStore<Model.CollectionIdentifier, Model.Collection>()

init(location: SQLite.Location? = nil, diagnosticsEngine: DiagnosticsEngine? = nil) {
self.location = location ?? .path(localFileSystem.swiftPMCacheDirectory.appending(components: "package-collection.db"))
Expand Down Expand Up @@ -86,9 +85,7 @@ final class SQLitePackageCollectionsStorage: PackageCollectionsStorage, Closable
try statement.step()
}
// write to cache
self.cacheLock.withLock {
self.cache[collection.identifier] = collection
}
self.cache[collection.identifier] = collection
callback(.success(collection))
} catch {
callback(.failure(error))
Expand All @@ -110,9 +107,7 @@ final class SQLitePackageCollectionsStorage: PackageCollectionsStorage, Closable
try statement.step()
}
// write to cache
self.cacheLock.withLock {
self.cache[identifier] = nil
}
self.cache[identifier] = nil
callback(.success(()))
} catch {
callback(.failure(error))
Expand All @@ -123,7 +118,7 @@ final class SQLitePackageCollectionsStorage: PackageCollectionsStorage, Closable
func get(identifier: Model.CollectionIdentifier,
callback: @escaping (Result<Model.Collection, Error>) -> Void) {
// try read to cache
if let collection = (self.cacheLock.withLock { self.cache[identifier] }) {
if let collection = self.cache[identifier] {
return callback(.success(collection))
}

Expand Down Expand Up @@ -152,11 +147,7 @@ final class SQLitePackageCollectionsStorage: PackageCollectionsStorage, Closable
func list(identifiers: [Model.CollectionIdentifier]? = nil,
callback: @escaping (Result<[Model.Collection], Error>) -> Void) {
// try read to cache
let cached = self.cacheLock.withLock {
identifiers?.compactMap { identifier in
self.cache[identifier]
}
}
let cached = identifiers?.compactMap { self.cache[$0] }
if let cached = cached, cached.count > 0, cached.count == identifiers?.count {
return callback(.success(cached))
}
Expand Down Expand Up @@ -190,20 +181,17 @@ final class SQLitePackageCollectionsStorage: PackageCollectionsStorage, Closable
// decoding is a performance bottleneck (10+s for 1000 collections)
// workaround is to decode in parallel if list is large enough to justify it
let sync = DispatchGroup()
var collections: [Model.Collection]
let collections: ThreadSafeArrayStore<Model.Collection>
if blobs.count < Self.batchSize {
collections = blobs.compactMap { data -> Model.Collection? in
collections = .init(blobs.compactMap { data -> Model.Collection? in
try? self.decoder.decode(Model.Collection.self, from: data)
}
})
} else {
let lock = Lock()
collections = [Model.Collection]()
collections = .init()
blobs.forEach { data in
self.queue.async(group: sync) {
if let collection = try? self.decoder.decode(Model.Collection.self, from: data) {
lock.withLock {
collections.append(collection)
}
collections.append(collection)
}
}
}
Expand All @@ -213,7 +201,7 @@ final class SQLitePackageCollectionsStorage: PackageCollectionsStorage, Closable
if collections.count != blobs.count {
self.diagnosticsEngine?.emit(warning: "Some stored collections could not be deserialized. Please refresh the collections to resolve this issue.")
}
callback(.success(collections))
callback(.success(collections.get()))
}

} catch {
Expand Down Expand Up @@ -370,9 +358,7 @@ final class SQLitePackageCollectionsStorage: PackageCollectionsStorage, Closable

// for testing
internal func resetCache() {
self.cacheLock.withLock {
self.cache = [:]
}
self.cache.clear()
}

// MARK: - Private
Expand Down
Loading