Skip to content

Commit ae14833

Browse files
authored
Refactor registry checksum TOFU logic (#6190)
Motivation: Checksum TOFU logic is scattered in `RegistryClient`, making it difficult to reason about and maintain. Modifications: Refactor checksum TOFU logic into `PackageVersionChecksumTOFU`.
1 parent 083572e commit ae14833

File tree

8 files changed

+1376
-723
lines changed

8 files changed

+1376
-723
lines changed

Sources/PackageFingerprint/PackageFingerprintStorage.swift

Lines changed: 65 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
//
33
// This source file is part of the Swift open source project
44
//
5-
// Copyright (c) 2021-2022 Apple Inc. and the Swift project authors
5+
// Copyright (c) 2021-2023 Apple Inc. and the Swift project authors
66
// Licensed under Apache License v2.0 with Runtime Library Exception
77
//
88
// See http://swift.org/LICENSE.txt for license information
@@ -17,59 +17,83 @@ import PackageModel
1717
import struct TSCUtility.Version
1818

1919
public protocol PackageFingerprintStorage {
20-
func get(package: PackageIdentity,
21-
version: Version,
22-
observabilityScope: ObservabilityScope,
23-
callbackQueue: DispatchQueue,
24-
callback: @escaping (Result<[Fingerprint.Kind: Fingerprint], Error>) -> Void)
20+
func get(
21+
package: PackageIdentity,
22+
version: Version,
23+
observabilityScope: ObservabilityScope,
24+
callbackQueue: DispatchQueue,
25+
callback: @escaping (Result<[Fingerprint.Kind: Fingerprint], Error>) -> Void
26+
)
2527

26-
func put(package: PackageIdentity,
27-
version: Version,
28-
fingerprint: Fingerprint,
29-
observabilityScope: ObservabilityScope,
30-
callbackQueue: DispatchQueue,
31-
callback: @escaping (Result<Void, Error>) -> Void)
28+
func put(
29+
package: PackageIdentity,
30+
version: Version,
31+
fingerprint: Fingerprint,
32+
observabilityScope: ObservabilityScope,
33+
callbackQueue: DispatchQueue,
34+
callback: @escaping (Result<Void, Error>) -> Void
35+
)
3236

33-
func get(package: PackageReference,
34-
version: Version,
35-
observabilityScope: ObservabilityScope,
36-
callbackQueue: DispatchQueue,
37-
callback: @escaping (Result<[Fingerprint.Kind: Fingerprint], Error>) -> Void)
37+
func get(
38+
package: PackageReference,
39+
version: Version,
40+
observabilityScope: ObservabilityScope,
41+
callbackQueue: DispatchQueue,
42+
callback: @escaping (Result<[Fingerprint.Kind: Fingerprint], Error>) -> Void
43+
)
3844

39-
func put(package: PackageReference,
40-
version: Version,
41-
fingerprint: Fingerprint,
42-
observabilityScope: ObservabilityScope,
43-
callbackQueue: DispatchQueue,
44-
callback: @escaping (Result<Void, Error>) -> Void)
45+
func put(
46+
package: PackageReference,
47+
version: Version,
48+
fingerprint: Fingerprint,
49+
observabilityScope: ObservabilityScope,
50+
callbackQueue: DispatchQueue,
51+
callback: @escaping (Result<Void, Error>) -> Void
52+
)
4553
}
4654

47-
public extension PackageFingerprintStorage {
48-
func get(package: PackageIdentity,
49-
version: Version,
50-
kind: Fingerprint.Kind,
51-
observabilityScope: ObservabilityScope,
52-
callbackQueue: DispatchQueue,
53-
callback: @escaping (Result<Fingerprint, Error>) -> Void) {
54-
self.get(package: package, version: version, observabilityScope: observabilityScope, callbackQueue: callbackQueue) { result in
55+
extension PackageFingerprintStorage {
56+
public func get(
57+
package: PackageIdentity,
58+
version: Version,
59+
kind: Fingerprint.Kind,
60+
observabilityScope: ObservabilityScope,
61+
callbackQueue: DispatchQueue,
62+
callback: @escaping (Result<Fingerprint, Error>) -> Void
63+
) {
64+
self.get(
65+
package: package,
66+
version: version,
67+
observabilityScope: observabilityScope,
68+
callbackQueue: callbackQueue
69+
) { result in
5570
self.get(kind: kind, result, callback: callback)
5671
}
5772
}
5873

59-
func get(package: PackageReference,
60-
version: Version,
61-
kind: Fingerprint.Kind,
62-
observabilityScope: ObservabilityScope,
63-
callbackQueue: DispatchQueue,
64-
callback: @escaping (Result<Fingerprint, Error>) -> Void) {
65-
self.get(package: package, version: version, observabilityScope: observabilityScope, callbackQueue: callbackQueue) { result in
74+
public func get(
75+
package: PackageReference,
76+
version: Version,
77+
kind: Fingerprint.Kind,
78+
observabilityScope: ObservabilityScope,
79+
callbackQueue: DispatchQueue,
80+
callback: @escaping (Result<Fingerprint, Error>) -> Void
81+
) {
82+
self.get(
83+
package: package,
84+
version: version,
85+
observabilityScope: observabilityScope,
86+
callbackQueue: callbackQueue
87+
) { result in
6688
self.get(kind: kind, result, callback: callback)
6789
}
6890
}
6991

70-
private func get(kind: Fingerprint.Kind,
71-
_ fingerprintsResult: Result<[Fingerprint.Kind: Fingerprint], Error>,
72-
callback: @escaping (Result<Fingerprint, Error>) -> Void) {
92+
private func get(
93+
kind: Fingerprint.Kind,
94+
_ fingerprintsResult: Result<[Fingerprint.Kind: Fingerprint], Error>,
95+
callback: @escaping (Result<Fingerprint, Error>) -> Void
96+
) {
7397
callback(fingerprintsResult.tryMap { fingerprints in
7498
guard let fingerprint = fingerprints[kind] else {
7599
throw PackageFingerprintStorageError.notFound

Sources/PackageRegistry/CMakeLists.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# This source file is part of the Swift open source project
22
#
3-
# Copyright (c) 2021 Apple Inc. and the Swift project authors
3+
# Copyright (c) 2021-2023 Apple Inc. and the Swift project authors
44
# Licensed under Apache License v2.0 with Runtime Library Exception
55
#
66
# See http://swift.org/LICENSE.txt for license information
@@ -10,7 +10,8 @@ add_library(PackageRegistry STATIC
1010
Registry.swift
1111
RegistryConfiguration.swift
1212
RegistryClient.swift
13-
RegistryDownloadsManager.swift)
13+
RegistryDownloadsManager.swift
14+
ChecksumTOFU.swift)
1415
target_link_libraries(PackageRegistry PUBLIC
1516
Basics
1617
PackageFingerprint
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the Swift open source project
4+
//
5+
// Copyright (c) 2023 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See http://swift.org/LICENSE.txt for license information
9+
// See http://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
import Dispatch
14+
15+
import Basics
16+
import PackageFingerprint
17+
import PackageModel
18+
19+
import struct TSCUtility.Version
20+
21+
struct PackageVersionChecksumTOFU {
22+
private let fingerprintStorage: PackageFingerprintStorage?
23+
private let fingerprintCheckingMode: FingerprintCheckingMode
24+
25+
private let registryClient: RegistryClient
26+
27+
init(
28+
fingerprintStorage: PackageFingerprintStorage?,
29+
fingerprintCheckingMode: FingerprintCheckingMode,
30+
registryClient: RegistryClient
31+
) {
32+
self.fingerprintStorage = fingerprintStorage
33+
self.fingerprintCheckingMode = fingerprintCheckingMode
34+
self.registryClient = registryClient
35+
}
36+
37+
func check(
38+
registry: Registry,
39+
package: PackageIdentity.RegistryIdentity,
40+
version: Version,
41+
checksum: String,
42+
timeout: DispatchTimeInterval?,
43+
observabilityScope: ObservabilityScope,
44+
callbackQueue: DispatchQueue,
45+
completion: @escaping (Result<Void, Error>) -> Void
46+
) {
47+
self.getExpectedChecksum(
48+
registry: registry,
49+
package: package,
50+
version: version,
51+
timeout: timeout,
52+
observabilityScope: observabilityScope,
53+
callbackQueue: callbackQueue
54+
) { result in
55+
completion(
56+
result.tryMap { expectedChecksum in
57+
if checksum != expectedChecksum {
58+
switch self.fingerprintCheckingMode {
59+
case .strict:
60+
throw RegistryError.invalidChecksum(expected: expectedChecksum, actual: checksum)
61+
case .warn:
62+
observabilityScope
63+
.emit(
64+
warning: "The checksum \(checksum) does not match previously recorded value \(expectedChecksum)"
65+
)
66+
}
67+
}
68+
}
69+
)
70+
}
71+
}
72+
73+
private func getExpectedChecksum(
74+
registry: Registry,
75+
package: PackageIdentity.RegistryIdentity,
76+
version: Version,
77+
timeout: DispatchTimeInterval?,
78+
observabilityScope: ObservabilityScope,
79+
callbackQueue: DispatchQueue,
80+
completion: @escaping (Result<String, Error>) -> Void
81+
) {
82+
// We either use a previously recorded checksum, or fetch it from the registry.
83+
self.readFromStorage(
84+
package: package,
85+
version: version,
86+
observabilityScope: observabilityScope,
87+
callbackQueue: callbackQueue
88+
) { result in
89+
switch result {
90+
case .success(.some(let savedChecksum)):
91+
completion(.success(savedChecksum))
92+
default:
93+
// Try fetching checksum from registry if:
94+
// - No storage available
95+
// - Checksum not found in storage
96+
// - Reading from storage resulted in error
97+
self.registryClient.getRawPackageVersionMetadata(
98+
registry: registry,
99+
package: package,
100+
version: version,
101+
timeout: timeout,
102+
observabilityScope: observabilityScope,
103+
callbackQueue: callbackQueue
104+
) { result in
105+
switch result {
106+
case .success(let metadata):
107+
guard let sourceArchive = metadata.resources
108+
.first(where: { $0.name == "source-archive" })
109+
else {
110+
return completion(.failure(RegistryError.missingSourceArchive))
111+
}
112+
113+
guard let checksum = sourceArchive.checksum else {
114+
return completion(.failure(RegistryError.invalidSourceArchive))
115+
}
116+
117+
self.writeToStorage(
118+
registry: registry,
119+
package: package,
120+
version: version,
121+
checksum: checksum,
122+
observabilityScope: observabilityScope,
123+
callbackQueue: callbackQueue
124+
) { writeResult in
125+
completion(writeResult.tryMap { _ in checksum })
126+
}
127+
case .failure(RegistryError.failedRetrievingReleaseInfo(_, _, _, let error)):
128+
completion(.failure(RegistryError.failedRetrievingReleaseChecksum(
129+
registry: registry,
130+
package: package.underlying,
131+
version: version,
132+
error: error
133+
)))
134+
case .failure(let error):
135+
completion(.failure(RegistryError.failedRetrievingReleaseChecksum(
136+
registry: registry,
137+
package: package.underlying,
138+
version: version,
139+
error: error
140+
)))
141+
}
142+
}
143+
}
144+
}
145+
}
146+
147+
private func readFromStorage(
148+
package: PackageIdentity.RegistryIdentity,
149+
version: Version,
150+
observabilityScope: ObservabilityScope,
151+
callbackQueue: DispatchQueue,
152+
completion: @escaping (Result<String?, Error>) -> Void
153+
) {
154+
guard let fingerprintStorage = self.fingerprintStorage else {
155+
return completion(.success(nil))
156+
}
157+
158+
fingerprintStorage.get(
159+
package: package.underlying,
160+
version: version,
161+
kind: .registry,
162+
observabilityScope: observabilityScope,
163+
callbackQueue: callbackQueue
164+
) { result in
165+
switch result {
166+
case .success(let fingerprint):
167+
completion(.success(fingerprint.value))
168+
case .failure(PackageFingerprintStorageError.notFound):
169+
completion(.success(nil))
170+
case .failure(let error):
171+
observabilityScope
172+
.emit(error: "Failed to get registry fingerprint for \(package) \(version) from storage: \(error)")
173+
completion(.failure(error))
174+
}
175+
}
176+
}
177+
178+
private func writeToStorage(
179+
registry: Registry,
180+
package: PackageIdentity.RegistryIdentity,
181+
version: Version,
182+
checksum: String,
183+
observabilityScope: ObservabilityScope,
184+
callbackQueue: DispatchQueue,
185+
completion: @escaping (Result<Void, Error>) -> Void
186+
) {
187+
guard let fingerprintStorage = self.fingerprintStorage else {
188+
return completion(.success(()))
189+
}
190+
191+
fingerprintStorage.put(
192+
package: package.underlying,
193+
version: version,
194+
fingerprint: .init(origin: .registry(registry.url), value: checksum),
195+
observabilityScope: observabilityScope,
196+
callbackQueue: callbackQueue
197+
) { result in
198+
switch result {
199+
case .success:
200+
completion(.success(()))
201+
case .failure(PackageFingerprintStorageError.conflict(_, let existing)):
202+
switch self.fingerprintCheckingMode {
203+
case .strict:
204+
completion(.failure(RegistryError.checksumChanged(latest: checksum, previous: existing.value)))
205+
case .warn:
206+
observabilityScope
207+
.emit(
208+
warning: "The checksum \(checksum) from \(registry.url.absoluteString) does not match previously recorded value \(existing.value) from \(String(describing: existing.origin.url?.absoluteString))"
209+
)
210+
completion(.success(()))
211+
}
212+
case .failure(let error):
213+
completion(.failure(error))
214+
}
215+
}
216+
}
217+
}

0 commit comments

Comments
 (0)