Skip to content

Make observabilityScope and callbackQueue API args instead of instance members #6601

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,7 @@ struct JSONPackageCollectionProvider: PackageCollectionProvider {
self.httpClient = customHTTPClient ?? Self.makeDefaultHTTPClient()
self.signatureValidator = customSignatureValidator ?? PackageCollectionSigning(
trustedRootCertsDir: configuration.trustedRootCertsDir ?? (try? fileSystem.swiftPMConfigurationDirectory.appending("trust-root-certs").asURL) ?? AbsolutePath.root.asURL,
additionalTrustedRootCerts: sourceCertPolicy.allRootCerts.map { Array($0) },
observabilityScope: observabilityScope,
callbackQueue: .sharedConcurrent
additionalTrustedRootCerts: sourceCertPolicy.allRootCerts.map { Array($0) }
)
self.sourceCertPolicy = sourceCertPolicy
self.decoder = JSONDecoder.makeWithDefaults()
Expand Down Expand Up @@ -161,7 +159,12 @@ struct JSONPackageCollectionProvider: PackageCollectionProvider {
// Check the signature
let signatureResults = ThreadSafeArrayStore<Result<Void, Error>>()
certPolicyKeys.forEach { certPolicyKey in
self.signatureValidator.validate(signedCollection: signedCollection, certPolicyKey: certPolicyKey) { result in
self.signatureValidator.validate(
signedCollection: signedCollection,
certPolicyKey: certPolicyKey,
observabilityScope: self.observabilityScope,
callbackQueue: .sharedConcurrent
) { result in
let count = signatureResults.append(result)
if count == certPolicyKeys.count {
if signatureResults.compactMap({ $0.success }).first != nil {
Expand Down
102 changes: 58 additions & 44 deletions Sources/PackageCollectionsSigning/CertificatePolicy.swift
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,13 @@ protocol CertificatePolicy {
/// - validationTime: Overrides the timestamp used for checking certificate expiry (e.g., for testing).
/// By default the current time is used.
/// - callback: The callback to invoke when the result is available.
func validate(certChain: [Certificate], validationTime: Date, callback: @escaping (Result<Void, Error>) -> Void)
func validate(
certChain: [Certificate],
validationTime: Date,
observabilityScope: ObservabilityScope,
callbackQueue: DispatchQueue,
callback: @escaping (Result<Void, Error>) -> Void
)
}

extension CertificatePolicy {
Expand All @@ -70,8 +76,19 @@ extension CertificatePolicy {
/// element of the array, with its issuer the next element and so on, and the root CA
/// certificate is last.
/// - callback: The callback to invoke when the result is available.
func validate(certChain: [Certificate], callback: @escaping (Result<Void, Error>) -> Void) {
self.validate(certChain: certChain, validationTime: Date(), callback: callback)
func validate(
certChain: [Certificate],
observabilityScope: ObservabilityScope,
callbackQueue: DispatchQueue,
callback: @escaping (Result<Void, Error>) -> Void
) {
self.validate(
certChain: certChain,
validationTime: Date(),
observabilityScope: observabilityScope,
callbackQueue: callbackQueue,
callback: callback
)
}

func verify(
Expand Down Expand Up @@ -136,9 +153,7 @@ struct DefaultCertificatePolicy: CertificatePolicy {
let expectedSubjectUserID: String?
let expectedSubjectOrganizationalUnit: String?

private let callbackQueue: DispatchQueue
private let httpClient: HTTPClient
private let observabilityScope: ObservabilityScope

/// Initializes a `DefaultCertificatePolicy`.
///
Expand All @@ -155,28 +170,30 @@ struct DefaultCertificatePolicy: CertificatePolicy {
trustedRootCertsDir: URL?,
additionalTrustedRootCerts: [Certificate]?,
expectedSubjectUserID: String? = nil,
expectedSubjectOrganizationalUnit: String? = nil,
observabilityScope: ObservabilityScope,
callbackQueue: DispatchQueue
expectedSubjectOrganizationalUnit: String? = nil
) {
var trustedRoots = [Certificate]()
if let trustedRootCertsDir {
trustedRoots
.append(contentsOf: Self.loadCerts(at: trustedRootCertsDir, observabilityScope: observabilityScope))
.append(contentsOf: Self.loadCerts(at: trustedRootCertsDir))
}
if let additionalTrustedRootCerts {
trustedRoots.append(contentsOf: additionalTrustedRootCerts)
}
self.trustedRoots = trustedRoots
self.expectedSubjectUserID = expectedSubjectUserID
self.expectedSubjectOrganizationalUnit = expectedSubjectOrganizationalUnit
self.callbackQueue = callbackQueue
self.httpClient = HTTPClient.makeDefault()
self.observabilityScope = observabilityScope
}

func validate(certChain: [Certificate], validationTime: Date, callback: @escaping (Result<Void, Error>) -> Void) {
let wrappedCallback: (Result<Void, Error>) -> Void = { result in self.callbackQueue.async { callback(result) } }
func validate(
certChain: [Certificate],
validationTime: Date,
observabilityScope: ObservabilityScope,
callbackQueue: DispatchQueue,
callback: @escaping (Result<Void, Error>) -> Void
) {
let wrappedCallback: (Result<Void, Error>) -> Void = { result in callbackQueue.async { callback(result) } }

guard !certChain.isEmpty else {
return wrappedCallback(.failure(CertificatePolicyError.emptyCertChain))
Expand Down Expand Up @@ -207,8 +224,8 @@ struct DefaultCertificatePolicy: CertificatePolicy {
certChain: certChain,
trustedRoots: self.trustedRoots,
policies: policies,
observabilityScope: self.observabilityScope,
callbackQueue: self.callbackQueue,
observabilityScope: observabilityScope,
callbackQueue: callbackQueue,
callback: callback
)
}
Expand All @@ -223,9 +240,7 @@ struct ADPSwiftPackageCollectionCertificatePolicy: CertificatePolicy {
let expectedSubjectUserID: String?
let expectedSubjectOrganizationalUnit: String?

private let callbackQueue: DispatchQueue
private let httpClient: HTTPClient
private let observabilityScope: ObservabilityScope

/// Initializes a `ADPSwiftPackageCollectionCertificatePolicy`.
///
Expand All @@ -242,28 +257,30 @@ struct ADPSwiftPackageCollectionCertificatePolicy: CertificatePolicy {
trustedRootCertsDir: URL?,
additionalTrustedRootCerts: [Certificate]?,
expectedSubjectUserID: String? = nil,
expectedSubjectOrganizationalUnit: String? = nil,
observabilityScope: ObservabilityScope,
callbackQueue: DispatchQueue
expectedSubjectOrganizationalUnit: String? = nil
) {
var trustedRoots = [Certificate]()
if let trustedRootCertsDir {
trustedRoots
.append(contentsOf: Self.loadCerts(at: trustedRootCertsDir, observabilityScope: observabilityScope))
.append(contentsOf: Self.loadCerts(at: trustedRootCertsDir))
}
if let additionalTrustedRootCerts {
trustedRoots.append(contentsOf: additionalTrustedRootCerts)
}
self.trustedRoots = trustedRoots
self.expectedSubjectUserID = expectedSubjectUserID
self.expectedSubjectOrganizationalUnit = expectedSubjectOrganizationalUnit
self.callbackQueue = callbackQueue
self.httpClient = HTTPClient.makeDefault()
self.observabilityScope = observabilityScope
}

func validate(certChain: [Certificate], validationTime: Date, callback: @escaping (Result<Void, Error>) -> Void) {
let wrappedCallback: (Result<Void, Error>) -> Void = { result in self.callbackQueue.async { callback(result) } }
func validate(
certChain: [Certificate],
validationTime: Date,
observabilityScope: ObservabilityScope,
callbackQueue: DispatchQueue,
callback: @escaping (Result<Void, Error>) -> Void
) {
let wrappedCallback: (Result<Void, Error>) -> Void = { result in callbackQueue.async { callback(result) } }

guard !certChain.isEmpty else {
return wrappedCallback(.failure(CertificatePolicyError.emptyCertChain))
Expand Down Expand Up @@ -296,8 +313,8 @@ struct ADPSwiftPackageCollectionCertificatePolicy: CertificatePolicy {
certChain: certChain,
trustedRoots: self.trustedRoots,
policies: policies,
observabilityScope: self.observabilityScope,
callbackQueue: self.callbackQueue,
observabilityScope: observabilityScope,
callbackQueue: callbackQueue,
callback: callback
)
}
Expand All @@ -312,9 +329,7 @@ struct ADPAppleDistributionCertificatePolicy: CertificatePolicy {
let expectedSubjectUserID: String?
let expectedSubjectOrganizationalUnit: String?

private let callbackQueue: DispatchQueue
private let httpClient: HTTPClient
private let observabilityScope: ObservabilityScope

/// Initializes a `ADPAppleDistributionCertificatePolicy`.
///
Expand All @@ -331,28 +346,30 @@ struct ADPAppleDistributionCertificatePolicy: CertificatePolicy {
trustedRootCertsDir: URL?,
additionalTrustedRootCerts: [Certificate]?,
expectedSubjectUserID: String? = nil,
expectedSubjectOrganizationalUnit: String? = nil,
observabilityScope: ObservabilityScope,
callbackQueue: DispatchQueue
expectedSubjectOrganizationalUnit: String? = nil
) {
var trustedRoots = [Certificate]()
if let trustedRootCertsDir {
trustedRoots
.append(contentsOf: Self.loadCerts(at: trustedRootCertsDir, observabilityScope: observabilityScope))
.append(contentsOf: Self.loadCerts(at: trustedRootCertsDir))
}
if let additionalTrustedRootCerts {
trustedRoots.append(contentsOf: additionalTrustedRootCerts)
}
self.trustedRoots = trustedRoots
self.expectedSubjectUserID = expectedSubjectUserID
self.expectedSubjectOrganizationalUnit = expectedSubjectOrganizationalUnit
self.callbackQueue = callbackQueue
self.httpClient = HTTPClient.makeDefault()
self.observabilityScope = observabilityScope
}

func validate(certChain: [Certificate], validationTime: Date, callback: @escaping (Result<Void, Error>) -> Void) {
let wrappedCallback: (Result<Void, Error>) -> Void = { result in self.callbackQueue.async { callback(result) } }
func validate(
certChain: [Certificate],
validationTime: Date,
observabilityScope: ObservabilityScope,
callbackQueue: DispatchQueue,
callback: @escaping (Result<Void, Error>) -> Void
) {
let wrappedCallback: (Result<Void, Error>) -> Void = { result in callbackQueue.async { callback(result) } }

guard !certChain.isEmpty else {
return wrappedCallback(.failure(CertificatePolicyError.emptyCertChain))
Expand Down Expand Up @@ -385,8 +402,8 @@ struct ADPAppleDistributionCertificatePolicy: CertificatePolicy {
certChain: certChain,
trustedRoots: self.trustedRoots,
policies: policies,
observabilityScope: self.observabilityScope,
callbackQueue: self.callbackQueue,
observabilityScope: observabilityScope,
callbackQueue: callbackQueue,
callback: callback
)
}
Expand Down Expand Up @@ -620,18 +637,15 @@ enum CertificateStores {
// MARK: - Utils

extension CertificatePolicy {
fileprivate static func loadCerts(at directory: URL, observabilityScope: ObservabilityScope) -> [Certificate] {
fileprivate static func loadCerts(at directory: URL) -> [Certificate] {
var certs = [Certificate]()
if let enumerator = FileManager.default.enumerator(at: directory, includingPropertiesForKeys: nil) {
for case let fileURL as URL in enumerator {
do {
let certData = try Data(contentsOf: fileURL)
certs.append(try Certificate(derEncoded: Array(certData)))
} catch {
observabilityScope.emit(
warning: "The certificate \(fileURL) is invalid",
underlyingError: error
)
// do nothing
}
}
}
Expand Down
Loading