Skip to content

Commit 8768d06

Browse files
authored
Merge pull request #19294 from nkcsgexi/cherry-pick-byte-tree-deserialization
[swift-5.0-branch] cherry pick Alex's byte tree deserialization patches
2 parents 6bd4e3b + 61935fd commit 8768d06

13 files changed

+204
-120
lines changed

tools/SwiftSyntax/ByteTreeDeserialization.swift

Lines changed: 79 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,14 @@ import Foundation
1414

1515
// MARK: - ByteTree decoder protocols
1616

17+
struct ByteTreeUserInfoKey: Hashable {
18+
let rawValue: String
19+
20+
init(rawValue: String) {
21+
self.rawValue = rawValue
22+
}
23+
}
24+
1725
/// A type that can be deserialized from ByteTree into a scalar value that
1826
/// doesn't have any child nodes
1927
protocol ByteTreeScalarDecodable {
@@ -23,7 +31,8 @@ protocol ByteTreeScalarDecodable {
2331
/// - pointer: The pointer pointing to the start of the serialized data
2432
/// - size: The length of the serialized data in bytes
2533
/// - Returns: The deserialized value
26-
static func read(from pointer: UnsafeRawPointer, size: Int) -> Self
34+
static func read(from pointer: UnsafeRawPointer, size: Int,
35+
userInfo: [ByteTreeUserInfoKey: Any]) -> Self
2736
}
2837

2938
/// A type that can be deserialized from ByteTree into an object with child
@@ -39,7 +48,8 @@ protocol ByteTreeObjectDecodable {
3948
/// - numFields: The number of fields that are present in the serialized
4049
/// object
4150
/// - Returns: The deserialized object
42-
static func read(from reader: ByteTreeObjectReader, numFields: Int) -> Self
51+
static func read(from reader: ByteTreeObjectReader, numFields: Int,
52+
userInfo: [ByteTreeUserInfoKey: Any]) -> Self
4353
}
4454

4555
// MARK: - Reader objects
@@ -119,14 +129,30 @@ class ByteTreeObjectReader {
119129

120130
/// Reader for reading the ByteTree format into Swift objects
121131
class ByteTreeReader {
132+
enum DeserializationError: Error, CustomStringConvertible {
133+
case versionValidationFailed(ByteTreeReader.ProtocolVersion)
134+
135+
public var description: String {
136+
switch self {
137+
case .versionValidationFailed(let version):
138+
return "The serialized ByteTree version \(version) cannot be parsed " +
139+
"by this version of swiftSyntax"
140+
}
141+
}
142+
}
143+
122144
/// The type as which the protocol version is encoded in ByteTree
123145
typealias ProtocolVersion = UInt32
124146

125147
/// A pointer pointing to the next byte of serialized data to be read
126148
private var pointer: UnsafeRawPointer
127149

128-
private init(pointer: UnsafeRawPointer) {
150+
private var userInfo: [ByteTreeUserInfoKey: Any]
151+
152+
private init(pointer: UnsafeRawPointer,
153+
userInfo: [ByteTreeUserInfoKey: Any]) {
129154
self.pointer = pointer
155+
self.userInfo = userInfo
130156
}
131157

132158
// MARK: Public entrance function
@@ -137,22 +163,45 @@ class ByteTreeReader {
137163
/// - Parameters:
138164
/// - rootObjectType: The type of the root object in the deserialized tree
139165
/// - pointer: The memory location at which the serialized data resides
140-
/// - protocolVerisonValidation: A callback to determine if the data can be
166+
/// - protocolVersionValidation: A callback to determine if the data can be
141167
/// read, based on the format's protocol version. If the callback
142-
/// returns `false`, `nil` will be returned and reading aborded.
168+
/// returns `false`, an error will be thrown
143169
/// - Returns: The deserialized tree; throws if protocol version validation
144170
/// failed
145171
static func read<T: ByteTreeObjectDecodable>(
146172
_ rootObjectType: T.Type, from pointer: UnsafeRawPointer,
147-
protocolVerisonValidation: (ProtocolVersion) -> Bool
148-
) -> T? {
149-
let reader = ByteTreeReader(pointer: pointer)
150-
if !reader.readAndValidateProtocolVersion(protocolVerisonValidation) {
151-
return nil
152-
}
173+
userInfo: [ByteTreeUserInfoKey: Any],
174+
protocolVersionValidation: (ProtocolVersion) -> Bool
175+
) throws -> T {
176+
let reader = ByteTreeReader(pointer: pointer, userInfo: userInfo)
177+
try reader.readAndValidateProtocolVersion(protocolVersionValidation)
153178
return reader.read(rootObjectType)
154179
}
155180

181+
/// Deserialize an object tree from the ByteTree data at the given memory
182+
/// location.
183+
///
184+
/// - Parameters:
185+
/// - rootObjectType: The type of the root object in the deserialized tree
186+
/// - data: The data to deserialize
187+
/// - protocolVersionValidation: A callback to determine if the data can be
188+
/// read, based on the format's protocol version. If the callback
189+
/// returns `false`, an error will be thrown
190+
/// - Returns: The deserialized tree
191+
static func read<T: ByteTreeObjectDecodable>(
192+
_ rootObjectType: T.Type, from data: Data,
193+
userInfo: [ByteTreeUserInfoKey: Any],
194+
protocolVersionValidation versionValidate: (ProtocolVersion) -> Bool
195+
) throws -> T {
196+
return try data.withUnsafeBytes { (pointer: UnsafePointer<UInt8>) in
197+
let rawPointer = UnsafeRawPointer(pointer)
198+
return try ByteTreeReader.read(rootObjectType, from: rawPointer,
199+
userInfo: userInfo,
200+
protocolVersionValidation: versionValidate)
201+
}
202+
}
203+
204+
156205
// MARK: Internal read functions
157206

158207
/// Cast the current pointer location to the given type and advance `pointer`
@@ -181,12 +230,13 @@ class ByteTreeReader {
181230
/// protocol version can be read
182231
private func readAndValidateProtocolVersion(
183232
_ validationCallback: (ProtocolVersion) -> Bool
184-
) -> Bool {
233+
) throws {
185234
let protocolVersion = ProtocolVersion(littleEndian:
186235
readRaw(ProtocolVersion.self))
187236
let result = validationCallback(protocolVersion)
188-
pointer = pointer.advanced(by: MemoryLayout<ProtocolVersion>.size)
189-
return result
237+
if !result {
238+
throw DeserializationError.versionValidationFailed(protocolVersion)
239+
}
190240
}
191241

192242
/// Read the next field in the tree as an object of the specified type.
@@ -199,7 +249,7 @@ class ByteTreeReader {
199249
let numFields = readFieldLength()
200250
let objectReader = ByteTreeObjectReader(reader: self,
201251
numFields: numFields)
202-
return T.read(from: objectReader, numFields: numFields)
252+
return T.read(from: objectReader, numFields: numFields, userInfo: userInfo)
203253
}
204254

205255
/// Read the next field in the tree as a scalar of the specified type.
@@ -213,7 +263,7 @@ class ByteTreeReader {
213263
defer {
214264
pointer = pointer.advanced(by: fieldSize)
215265
}
216-
return T.read(from: pointer, size: fieldSize)
266+
return T.read(from: pointer, size: fieldSize, userInfo: userInfo)
217267
}
218268

219269
/// Discard the next scalar field, advancing the pointer to the next field
@@ -230,7 +280,9 @@ class ByteTreeReader {
230280
// Implementation for reading an integer from memory to be shared between
231281
// multiple types
232282
extension ByteTreeScalarDecodable where Self : FixedWidthInteger {
233-
static func read(from pointer: UnsafeRawPointer, size: Int) -> Self {
283+
static func read(from pointer: UnsafeRawPointer, size: Int,
284+
userInfo: [ByteTreeUserInfoKey: Any]
285+
) -> Self {
234286
assert(size == MemoryLayout<Self>.size)
235287
return pointer.bindMemory(to: Self.self, capacity: 1).pointee
236288
}
@@ -241,7 +293,8 @@ extension UInt16: ByteTreeScalarDecodable {}
241293
extension UInt32: ByteTreeScalarDecodable {}
242294

243295
extension String: ByteTreeScalarDecodable {
244-
static func read(from pointer: UnsafeRawPointer, size: Int) -> String {
296+
static func read(from pointer: UnsafeRawPointer, size: Int,
297+
userInfo: [ByteTreeUserInfoKey: Any]) -> String {
245298
let data = Data(bytes: pointer, count: size)
246299
return String(data: data, encoding: .utf8)!
247300
}
@@ -250,21 +303,24 @@ extension String: ByteTreeScalarDecodable {
250303
extension Optional: ByteTreeObjectDecodable
251304
where
252305
Wrapped: ByteTreeObjectDecodable {
253-
static func read(from reader: ByteTreeObjectReader, numFields: Int) ->
254-
Optional<Wrapped> {
306+
static func read(from reader: ByteTreeObjectReader, numFields: Int,
307+
userInfo: [ByteTreeUserInfoKey: Any]
308+
) -> Optional<Wrapped> {
255309
if numFields == 0 {
256310
return nil
257311
} else {
258-
return Wrapped.read(from: reader, numFields: numFields)
312+
return Wrapped.read(from: reader, numFields: numFields,
313+
userInfo: userInfo)
259314
}
260315
}
261316
}
262317

263318
extension Array: ByteTreeObjectDecodable
264319
where
265320
Element: ByteTreeObjectDecodable {
266-
static func read(from reader: ByteTreeObjectReader, numFields: Int) ->
267-
Array<Element> {
321+
static func read(from reader: ByteTreeObjectReader, numFields: Int,
322+
userInfo: [ByteTreeUserInfoKey: Any]
323+
) -> Array<Element> {
268324
return (0..<numFields).map {
269325
return reader.readField(Element.self, index: $0)
270326
}

tools/SwiftSyntax/RawSyntax.swift

Lines changed: 63 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,27 @@ extension CodingUserInfoKey {
2424
CodingUserInfoKey(rawValue: "SwiftSyntax.RawSyntax.OmittedNodeLookup")!
2525
}
2626

27+
extension ByteTreeUserInfoKey {
28+
/// Callback that will be called whenever a `RawSyntax` node is decoded.
29+
/// Value must have signature `(RawSyntax) -> Void`
30+
static let rawSyntaxDecodedCallback =
31+
ByteTreeUserInfoKey(rawValue: "SwiftSyntax.RawSyntax.DecodedCallback")
32+
/// Function that shall be used to look up nodes that were omitted in the
33+
/// syntax tree transfer.
34+
/// Value must have signature `(SyntaxNodeId) -> RawSyntax?`
35+
static let omittedNodeLookupFunction =
36+
ByteTreeUserInfoKey(rawValue: "SwiftSyntax.RawSyntax.OmittedNodeLookup")
37+
}
38+
39+
/// Box a value type into a reference type
40+
class Box<T> {
41+
let value: T
42+
43+
init(_ value: T) {
44+
self.value = value
45+
}
46+
}
47+
2748
/// An ID that uniquely identifies a syntax node and stays stable across multiple
2849
/// incremental parses
2950
public struct SyntaxNodeId: Hashable, Codable {
@@ -54,7 +75,7 @@ public struct SyntaxNodeId: Hashable, Codable {
5475
}
5576

5677
/// The data that is specific to a tree or token node
57-
fileprivate indirect enum RawSyntaxData {
78+
fileprivate enum RawSyntaxData {
5879
/// A tree node with a kind and an array of children
5980
case node(kind: SyntaxKind, layout: [RawSyntax?])
6081
/// A token with a token kind, leading trivia, and trailing trivia
@@ -72,11 +93,11 @@ struct RawSyntax: Codable {
7293
/// incremental parses
7394
let id: SyntaxNodeId
7495

75-
var _contentLength = AtomicCache<SourceLength>()
96+
var _contentLength = AtomicCache<Box<SourceLength>>()
7697

7798
/// The length of this node excluding its leading and trailing trivia
7899
var contentLength: SourceLength {
79-
return _contentLength.value() {
100+
return _contentLength.value({
80101
switch data {
81102
case .node(kind: _, layout: let layout):
82103
let firstElementIndex = layout.firstIndex(where: { $0 != nil })
@@ -97,11 +118,11 @@ struct RawSyntax: Codable {
97118
contentLength += element.trailingTriviaLength
98119
}
99120
}
100-
return contentLength
121+
return Box(contentLength)
101122
case .token(kind: let kind, leadingTrivia: _, trailingTrivia: _):
102-
return SourceLength(of: kind.text)
123+
return Box(SourceLength(of: kind.text))
103124
}
104-
}
125+
}).value
105126
}
106127

107128
init(kind: SyntaxKind, layout: [RawSyntax?], presence: SourcePresence,
@@ -372,43 +393,64 @@ extension RawSyntax: ByteTreeObjectDecodable {
372393
enum SyntaxType: UInt8, ByteTreeScalarDecodable {
373394
case token = 0
374395
case layout = 1
396+
case omitted = 2
375397

376-
static func read(from pointer: UnsafeRawPointer, size: Int) ->
377-
SyntaxType {
378-
let rawValue = UInt8.read(from: pointer, size: size)
398+
static func read(from pointer: UnsafeRawPointer, size: Int,
399+
userInfo: [ByteTreeUserInfoKey: Any]
400+
) -> SyntaxType {
401+
let rawValue = UInt8.read(from: pointer, size: size, userInfo: userInfo)
379402
guard let type = SyntaxType(rawValue: rawValue) else {
380403
fatalError("Unknown RawSyntax node type \(rawValue)")
381404
}
382405
return type
383406
}
384407
}
385408

386-
static func read(from reader: ByteTreeObjectReader, numFields: Int) ->
387-
RawSyntax {
409+
static func read(from reader: ByteTreeObjectReader, numFields: Int,
410+
userInfo: [ByteTreeUserInfoKey: Any]
411+
) -> RawSyntax {
412+
let syntaxNode: RawSyntax
388413
let type = reader.readField(SyntaxType.self, index: 0)
414+
let id = reader.readField(SyntaxNodeId.self, index: 1)
389415
switch type {
390416
case .token:
391-
let presence = reader.readField(SourcePresence.self, index: 1)
392-
let id = reader.readField(SyntaxNodeId.self, index: 2)
417+
let presence = reader.readField(SourcePresence.self, index: 2)
393418
let kind = reader.readField(TokenKind.self, index: 3)
394419
let leadingTrivia = reader.readField(Trivia.self, index: 4)
395420
let trailingTrivia = reader.readField(Trivia.self, index: 5)
396-
return RawSyntax(kind: kind, leadingTrivia: leadingTrivia,
397-
trailingTrivia: trailingTrivia,
398-
presence: presence, id: id)
421+
syntaxNode = RawSyntax(kind: kind, leadingTrivia: leadingTrivia,
422+
trailingTrivia: trailingTrivia,
423+
presence: presence, id: id)
399424
case .layout:
400-
let presence = reader.readField(SourcePresence.self, index: 1)
401-
let id = reader.readField(SyntaxNodeId.self, index: 2)
425+
let presence = reader.readField(SourcePresence.self, index: 2)
402426
let kind = reader.readField(SyntaxKind.self, index: 3)
403427
let layout = reader.readField([RawSyntax?].self, index: 4)
404-
return RawSyntax(kind: kind, layout: layout, presence: presence, id: id)
428+
syntaxNode = RawSyntax(kind: kind, layout: layout, presence: presence,
429+
id: id)
430+
case .omitted:
431+
guard let lookupFunc = userInfo[.omittedNodeLookupFunction] as?
432+
(SyntaxNodeId) -> RawSyntax? else {
433+
fatalError("omittedNodeLookupFunction is required when decoding an " +
434+
"incrementally transferred syntax tree")
435+
}
436+
guard let lookupNode = lookupFunc(id) else {
437+
fatalError("Node lookup for id \(id) failed")
438+
}
439+
syntaxNode = lookupNode
440+
}
441+
if let callback = userInfo[.rawSyntaxDecodedCallback] as?
442+
(RawSyntax) -> Void {
443+
callback(syntaxNode)
405444
}
445+
return syntaxNode
406446
}
407447
}
408448

409449
extension SyntaxNodeId: ByteTreeScalarDecodable {
410-
static func read(from pointer: UnsafeRawPointer, size: Int) -> SyntaxNodeId {
411-
let rawValue = UInt8.read(from: pointer, size: size)
450+
static func read(from pointer: UnsafeRawPointer, size: Int,
451+
userInfo: [ByteTreeUserInfoKey: Any]
452+
) -> SyntaxNodeId {
453+
let rawValue = UInt32.read(from: pointer, size: size, userInfo: userInfo)
412454
return SyntaxNodeId(rawValue: UInt(rawValue))
413455
}
414456
}

tools/SwiftSyntax/SourceLength.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
/// The length a syntax node spans in the source code. From any AbsolutePosition
1414
/// you reach a node's end location by either adding its UTF-8 length or by
1515
/// inserting `lines` newlines and then moving `columns` columns to the right.
16-
public final class SourceLength {
16+
public struct SourceLength {
1717
public let newlines: Int
1818
public let columnsAtLastLine: Int
1919
public let utf8Length: Int
@@ -90,4 +90,4 @@ extension AbsolutePosition {
9090
public static func +=(lhs: inout AbsolutePosition, rhs: SourceLength) {
9191
lhs = lhs + rhs
9292
}
93-
}
93+
}

tools/SwiftSyntax/SourcePresence.swift

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ public enum SourcePresence: String, Codable {
2525
}
2626

2727
extension SourcePresence: ByteTreeScalarDecodable {
28-
static func read(from pointer: UnsafeRawPointer, size: Int) -> SourcePresence {
28+
static func read(from pointer: UnsafeRawPointer, size: Int,
29+
userInfo: [ByteTreeUserInfoKey: Any]
30+
) -> SourcePresence {
2931
let rawValue = pointer.bindMemory(to: UInt8.self, capacity: 1).pointee
3032
switch rawValue {
3133
case 0: return .missing

0 commit comments

Comments
 (0)