Skip to content

Commit f50ddf9

Browse files
authored
Merge pull request #18888 from ahoppen/swiftsyntax-performance-improvments
[swiftSyntax] Performance improvements for deserialising ByteTrees
2 parents 4b625a3 + d41af61 commit f50ddf9

9 files changed

+96
-82
lines changed

tools/SwiftSyntax/ByteTreeDeserialization.swift

Lines changed: 46 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@ protocol ByteTreeScalarDecodable {
3232
/// - size: The length of the serialized data in bytes
3333
/// - Returns: The deserialized value
3434
static func read(from pointer: UnsafeRawPointer, size: Int,
35-
userInfo: [ByteTreeUserInfoKey: Any]) -> Self
35+
userInfo: UnsafePointer<[ByteTreeUserInfoKey: Any]>) -> Self
3636
}
3737

3838
/// A type that can be deserialized from ByteTree into an object with child
@@ -48,32 +48,34 @@ protocol ByteTreeObjectDecodable {
4848
/// - numFields: The number of fields that are present in the serialized
4949
/// object
5050
/// - Returns: The deserialized object
51-
static func read(from reader: ByteTreeObjectReader, numFields: Int,
52-
userInfo: [ByteTreeUserInfoKey: Any]) -> Self
51+
static func read(from reader: UnsafeMutablePointer<ByteTreeObjectReader>,
52+
numFields: Int,
53+
userInfo: UnsafePointer<[ByteTreeUserInfoKey: Any]>) -> Self
5354
}
5455

5556
// MARK: - Reader objects
5657

5758
/// Helper object for reading objects out of a ByteTree. Keeps track that fields
5859
/// are not read out of order and discards all trailing fields that were present
5960
/// in the binary format but were not handled when reading the object.
60-
class ByteTreeObjectReader {
61+
struct ByteTreeObjectReader {
6162
/// The reader that holds a reference to the data from which the object is
6263
/// read
63-
private let reader: ByteTreeReader
64+
private let reader: UnsafeMutablePointer<ByteTreeReader>
6465

6566
/// The number of fields this object is expected to have
6667
private let numFields: Int
6768

6869
/// The index of the field that is expected to be read next.
6970
private var nextIndex: Int = 0
7071

71-
fileprivate init(reader: ByteTreeReader, numFields: Int) {
72+
fileprivate init(reader: UnsafeMutablePointer<ByteTreeReader>,
73+
numFields: Int) {
7274
self.reader = reader
7375
self.numFields = numFields
7476
}
7577

76-
private func advanceAndValidateIndex(_ index: Int) {
78+
private mutating func advanceAndValidateIndex(_ index: Int) {
7779
assert(index == nextIndex, "Reading fields out of order")
7880
assert(index < numFields)
7981
nextIndex += 1
@@ -88,11 +90,11 @@ class ByteTreeObjectReader {
8890
/// - objectType: The type as which this field should be read
8991
/// - index: The index of this field
9092
/// - Returns: The decoded field
91-
func readField<FieldType: ByteTreeScalarDecodable>(
93+
mutating func readField<FieldType: ByteTreeScalarDecodable>(
9294
_ objectType: FieldType.Type, index: Int
9395
) -> FieldType {
9496
advanceAndValidateIndex(index)
95-
return reader.read(objectType)
97+
return reader.pointee.read(objectType)
9698
}
9799

98100
/// Read the field at the given index as the specified type. All indices must
@@ -103,23 +105,23 @@ class ByteTreeObjectReader {
103105
/// - objectType: The type as which this field should be read
104106
/// - index: The index of this field
105107
/// - Returns: The decoded field
106-
func readField<FieldType: ByteTreeObjectDecodable>(
108+
mutating func readField<FieldType: ByteTreeObjectDecodable>(
107109
_ objectType: FieldType.Type, index: Int
108110
) -> FieldType {
109111
advanceAndValidateIndex(index)
110-
return reader.read(objectType)
112+
return reader.pointee.read(objectType)
111113
}
112114

113115
/// Read and immediately discard the field at the specified index. This
114116
/// advances the reader by one field so that the next field can be read.
115117
///
116118
/// - Parameter index: The index of the field that shall be discarded
117-
func discardField(index: Int) {
119+
mutating func discardField(index: Int) {
118120
advanceAndValidateIndex(index)
119-
reader.discardField()
121+
reader.pointee.discardField()
120122
}
121123

122-
deinit {
124+
fileprivate mutating func finalize() {
123125
// Discard all fields that have not been read
124126
while nextIndex < numFields {
125127
discardField(index: nextIndex)
@@ -128,7 +130,7 @@ class ByteTreeObjectReader {
128130
}
129131

130132
/// Reader for reading the ByteTree format into Swift objects
131-
class ByteTreeReader {
133+
struct ByteTreeReader {
132134
enum DeserializationError: Error, CustomStringConvertible {
133135
case versionValidationFailed(ByteTreeReader.ProtocolVersion)
134136

@@ -147,10 +149,10 @@ class ByteTreeReader {
147149
/// A pointer pointing to the next byte of serialized data to be read
148150
private var pointer: UnsafeRawPointer
149151

150-
private var userInfo: [ByteTreeUserInfoKey: Any]
152+
private var userInfo: UnsafePointer<[ByteTreeUserInfoKey: Any]>
151153

152154
private init(pointer: UnsafeRawPointer,
153-
userInfo: [ByteTreeUserInfoKey: Any]) {
155+
userInfo: UnsafePointer<[ByteTreeUserInfoKey: Any]>) {
154156
self.pointer = pointer
155157
self.userInfo = userInfo
156158
}
@@ -170,10 +172,10 @@ class ByteTreeReader {
170172
/// failed
171173
static func read<T: ByteTreeObjectDecodable>(
172174
_ rootObjectType: T.Type, from pointer: UnsafeRawPointer,
173-
userInfo: [ByteTreeUserInfoKey: Any],
175+
userInfo: UnsafePointer<[ByteTreeUserInfoKey: Any]>,
174176
protocolVersionValidation: (ProtocolVersion) -> Bool
175177
) throws -> T {
176-
let reader = ByteTreeReader(pointer: pointer, userInfo: userInfo)
178+
var reader = ByteTreeReader(pointer: pointer, userInfo: userInfo)
177179
try reader.readAndValidateProtocolVersion(protocolVersionValidation)
178180
return reader.read(rootObjectType)
179181
}
@@ -190,7 +192,7 @@ class ByteTreeReader {
190192
/// - Returns: The deserialized tree
191193
static func read<T: ByteTreeObjectDecodable>(
192194
_ rootObjectType: T.Type, from data: Data,
193-
userInfo: [ByteTreeUserInfoKey: Any],
195+
userInfo: UnsafePointer<[ByteTreeUserInfoKey: Any]>,
194196
protocolVersionValidation versionValidate: (ProtocolVersion) -> Bool
195197
) throws -> T {
196198
return try data.withUnsafeBytes { (pointer: UnsafePointer<UInt8>) in
@@ -209,13 +211,13 @@ class ByteTreeReader {
209211
///
210212
/// - Parameter type: The type as which the current data should be read
211213
/// - Returns: The read value
212-
private func readRaw<T>(_ type: T.Type) -> T {
214+
private mutating func readRaw<T>(_ type: T.Type) -> T {
213215
let result = pointer.bindMemory(to: T.self, capacity: 1).pointee
214216
pointer = pointer.advanced(by: MemoryLayout<T>.size)
215217
return result
216218
}
217219

218-
private func readFieldLength() -> (isObject: Bool, length: Int) {
220+
private mutating func readFieldLength() -> (isObject: Bool, length: Int) {
219221
let raw = UInt32(littleEndian: readRaw(UInt32.self))
220222
let isObject = (raw & (UInt32(1) << 31)) != 0
221223
let length = Int(raw & ~(UInt32(1) << 31))
@@ -226,7 +228,7 @@ class ByteTreeReader {
226228
/// Read the number of fields in an object.
227229
///
228230
/// - Returns: The number of fields in the following object
229-
private func readObjectLength() -> Int {
231+
private mutating func readObjectLength() -> Int {
230232
let (isObject, length) = readFieldLength()
231233
assert(isObject)
232234
return length
@@ -235,7 +237,7 @@ class ByteTreeReader {
235237
/// Read the size of a scalar in bytes
236238
///
237239
/// - Returns: The size of the following scalar in bytes
238-
private func readScalarLength() -> Int {
240+
private mutating func readScalarLength() -> Int {
239241
let (isObject, length) = readFieldLength()
240242
assert(!isObject)
241243
return length
@@ -246,7 +248,7 @@ class ByteTreeReader {
246248
///
247249
/// - Parameter validationCallback: A callback that determines if the given
248250
/// protocol version can be read
249-
private func readAndValidateProtocolVersion(
251+
private mutating func readAndValidateProtocolVersion(
250252
_ validationCallback: (ProtocolVersion) -> Bool
251253
) throws {
252254
let protocolVersion = ProtocolVersion(littleEndian:
@@ -261,20 +263,24 @@ class ByteTreeReader {
261263
///
262264
/// - Parameter objectType: The type as which the next field shall be read
263265
/// - Returns: The deserialized object
264-
fileprivate func read<T: ByteTreeObjectDecodable>(
266+
fileprivate mutating func read<T: ByteTreeObjectDecodable>(
265267
_ objectType: T.Type
266268
) -> T {
267269
let numFields = readObjectLength()
268-
let objectReader = ByteTreeObjectReader(reader: self,
270+
var objectReader = ByteTreeObjectReader(reader: &self,
269271
numFields: numFields)
270-
return T.read(from: objectReader, numFields: numFields, userInfo: userInfo)
272+
defer {
273+
objectReader.finalize()
274+
}
275+
return T.read(from: &objectReader, numFields: numFields,
276+
userInfo: userInfo)
271277
}
272278

273279
/// Read the next field in the tree as a scalar of the specified type.
274280
///
275281
/// - Parameter scalarType: The type as which the field shall be read
276282
/// - Returns: The deserialized scalar
277-
fileprivate func read<T: ByteTreeScalarDecodable>(
283+
fileprivate mutating func read<T: ByteTreeScalarDecodable>(
278284
_ scalarType: T.Type
279285
) -> T {
280286
let fieldSize = readScalarLength()
@@ -285,7 +291,7 @@ class ByteTreeReader {
285291
}
286292

287293
/// Discard the next field, advancing the pointer to the next field
288-
fileprivate func discardField() {
294+
fileprivate mutating func discardField() {
289295
let (isObject, length) = readFieldLength()
290296
if isObject {
291297
// Discard object by discarding all its fields
@@ -305,7 +311,7 @@ class ByteTreeReader {
305311
// multiple types
306312
extension ByteTreeScalarDecodable where Self : FixedWidthInteger {
307313
static func read(from pointer: UnsafeRawPointer, size: Int,
308-
userInfo: [ByteTreeUserInfoKey: Any]
314+
userInfo: UnsafePointer<[ByteTreeUserInfoKey: Any]>
309315
) -> Self {
310316
assert(size == MemoryLayout<Self>.size)
311317
return pointer.bindMemory(to: Self.self, capacity: 1).pointee
@@ -318,7 +324,8 @@ extension UInt32: ByteTreeScalarDecodable {}
318324

319325
extension String: ByteTreeScalarDecodable {
320326
static func read(from pointer: UnsafeRawPointer, size: Int,
321-
userInfo: [ByteTreeUserInfoKey: Any]) -> String {
327+
userInfo: UnsafePointer<[ByteTreeUserInfoKey: Any]>
328+
) -> String {
322329
let data = Data(bytes: pointer, count: size)
323330
return String(data: data, encoding: .utf8)!
324331
}
@@ -327,8 +334,9 @@ extension String: ByteTreeScalarDecodable {
327334
extension Optional: ByteTreeObjectDecodable
328335
where
329336
Wrapped: ByteTreeObjectDecodable {
330-
static func read(from reader: ByteTreeObjectReader, numFields: Int,
331-
userInfo: [ByteTreeUserInfoKey: Any]
337+
static func read(from reader: UnsafeMutablePointer<ByteTreeObjectReader>,
338+
numFields: Int,
339+
userInfo: UnsafePointer<[ByteTreeUserInfoKey: Any]>
332340
) -> Optional<Wrapped> {
333341
if numFields == 0 {
334342
return nil
@@ -342,11 +350,12 @@ extension Optional: ByteTreeObjectDecodable
342350
extension Array: ByteTreeObjectDecodable
343351
where
344352
Element: ByteTreeObjectDecodable {
345-
static func read(from reader: ByteTreeObjectReader, numFields: Int,
346-
userInfo: [ByteTreeUserInfoKey: Any]
353+
static func read(from reader: UnsafeMutablePointer<ByteTreeObjectReader>,
354+
numFields: Int,
355+
userInfo: UnsafePointer<[ByteTreeUserInfoKey: Any]>
347356
) -> Array<Element> {
348357
return (0..<numFields).map {
349-
return reader.readField(Element.self, index: $0)
358+
return reader.pointee.readField(Element.self, index: $0)
350359
}
351360
}
352361
}

tools/SwiftSyntax/RawSyntax.swift

Lines changed: 31 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,15 @@ extension ByteTreeUserInfoKey {
3636
ByteTreeUserInfoKey(rawValue: "SwiftSyntax.RawSyntax.OmittedNodeLookup")
3737
}
3838

39+
/// Box a value type into a reference type
40+
class Box<T> {
41+
let value: T
42+
43+
init(_ value: T) {
44+
self.value = value
45+
}
46+
}
47+
3948
/// An ID that uniquely identifies a syntax node and stays stable across multiple
4049
/// incremental parses
4150
public struct SyntaxNodeId: Hashable, Codable {
@@ -66,7 +75,7 @@ public struct SyntaxNodeId: Hashable, Codable {
6675
}
6776

6877
/// The data that is specific to a tree or token node
69-
fileprivate indirect enum RawSyntaxData {
78+
fileprivate enum RawSyntaxData {
7079
/// A tree node with a kind and an array of children
7180
case node(kind: SyntaxKind, layout: [RawSyntax?])
7281
/// A token with a token kind, leading trivia, and trailing trivia
@@ -84,11 +93,11 @@ struct RawSyntax: Codable {
8493
/// incremental parses
8594
let id: SyntaxNodeId
8695

87-
var _contentLength = AtomicCache<SourceLength>()
96+
var _contentLength = AtomicCache<Box<SourceLength>>()
8897

8998
/// The length of this node excluding its leading and trailing trivia
9099
var contentLength: SourceLength {
91-
return _contentLength.value() {
100+
return _contentLength.value({
92101
switch data {
93102
case .node(kind: _, layout: let layout):
94103
let firstElementIndex = layout.firstIndex(where: { $0 != nil })
@@ -109,11 +118,11 @@ struct RawSyntax: Codable {
109118
contentLength += element.trailingTriviaLength
110119
}
111120
}
112-
return contentLength
121+
return Box(contentLength)
113122
case .token(kind: let kind, leadingTrivia: _, trailingTrivia: _):
114-
return SourceLength(of: kind.text)
123+
return Box(SourceLength(of: kind.text))
115124
}
116-
}
125+
}).value
117126
}
118127

119128
init(kind: SyntaxKind, layout: [RawSyntax?], presence: SourcePresence,
@@ -387,7 +396,7 @@ extension RawSyntax: ByteTreeObjectDecodable {
387396
case omitted = 2
388397

389398
static func read(from pointer: UnsafeRawPointer, size: Int,
390-
userInfo: [ByteTreeUserInfoKey: Any]
399+
userInfo: UnsafePointer<[ByteTreeUserInfoKey: Any]>
391400
) -> SyntaxType {
392401
let rawValue = UInt8.read(from: pointer, size: size, userInfo: userInfo)
393402
guard let type = SyntaxType(rawValue: rawValue) else {
@@ -397,29 +406,30 @@ extension RawSyntax: ByteTreeObjectDecodable {
397406
}
398407
}
399408

400-
static func read(from reader: ByteTreeObjectReader, numFields: Int,
401-
userInfo: [ByteTreeUserInfoKey: Any]
409+
static func read(from reader: UnsafeMutablePointer<ByteTreeObjectReader>,
410+
numFields: Int,
411+
userInfo: UnsafePointer<[ByteTreeUserInfoKey: Any]>
402412
) -> RawSyntax {
403413
let syntaxNode: RawSyntax
404-
let type = reader.readField(SyntaxType.self, index: 0)
405-
let id = reader.readField(SyntaxNodeId.self, index: 1)
414+
let type = reader.pointee.readField(SyntaxType.self, index: 0)
415+
let id = reader.pointee.readField(SyntaxNodeId.self, index: 1)
406416
switch type {
407417
case .token:
408-
let presence = reader.readField(SourcePresence.self, index: 2)
409-
let kind = reader.readField(TokenKind.self, index: 3)
410-
let leadingTrivia = reader.readField(Trivia.self, index: 4)
411-
let trailingTrivia = reader.readField(Trivia.self, index: 5)
418+
let presence = reader.pointee.readField(SourcePresence.self, index: 2)
419+
let kind = reader.pointee.readField(TokenKind.self, index: 3)
420+
let leadingTrivia = reader.pointee.readField(Trivia.self, index: 4)
421+
let trailingTrivia = reader.pointee.readField(Trivia.self, index: 5)
412422
syntaxNode = RawSyntax(kind: kind, leadingTrivia: leadingTrivia,
413423
trailingTrivia: trailingTrivia,
414424
presence: presence, id: id)
415425
case .layout:
416-
let presence = reader.readField(SourcePresence.self, index: 2)
417-
let kind = reader.readField(SyntaxKind.self, index: 3)
418-
let layout = reader.readField([RawSyntax?].self, index: 4)
426+
let presence = reader.pointee.readField(SourcePresence.self, index: 2)
427+
let kind = reader.pointee.readField(SyntaxKind.self, index: 3)
428+
let layout = reader.pointee.readField([RawSyntax?].self, index: 4)
419429
syntaxNode = RawSyntax(kind: kind, layout: layout, presence: presence,
420430
id: id)
421431
case .omitted:
422-
guard let lookupFunc = userInfo[.omittedNodeLookupFunction] as?
432+
guard let lookupFunc = userInfo.pointee[.omittedNodeLookupFunction] as?
423433
(SyntaxNodeId) -> RawSyntax? else {
424434
fatalError("omittedNodeLookupFunction is required when decoding an " +
425435
"incrementally transferred syntax tree")
@@ -429,7 +439,7 @@ extension RawSyntax: ByteTreeObjectDecodable {
429439
}
430440
syntaxNode = lookupNode
431441
}
432-
if let callback = userInfo[.rawSyntaxDecodedCallback] as?
442+
if let callback = userInfo.pointee[.rawSyntaxDecodedCallback] as?
433443
(RawSyntax) -> Void {
434444
callback(syntaxNode)
435445
}
@@ -439,7 +449,7 @@ extension RawSyntax: ByteTreeObjectDecodable {
439449

440450
extension SyntaxNodeId: ByteTreeScalarDecodable {
441451
static func read(from pointer: UnsafeRawPointer, size: Int,
442-
userInfo: [ByteTreeUserInfoKey: Any]
452+
userInfo: UnsafePointer<[ByteTreeUserInfoKey: Any]>
443453
) -> SyntaxNodeId {
444454
let rawValue = UInt32.read(from: pointer, size: size, userInfo: userInfo)
445455
return SyntaxNodeId(rawValue: UInt(rawValue))

0 commit comments

Comments
 (0)