Skip to content

[swiftSyntax] Performance improvements for deserialising ByteTrees #18888

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Aug 29, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 46 additions & 37 deletions tools/SwiftSyntax/ByteTreeDeserialization.swift
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ protocol ByteTreeScalarDecodable {
/// - size: The length of the serialized data in bytes
/// - Returns: The deserialized value
static func read(from pointer: UnsafeRawPointer, size: Int,
userInfo: [ByteTreeUserInfoKey: Any]) -> Self
userInfo: UnsafePointer<[ByteTreeUserInfoKey: Any]>) -> Self
}

/// A type that can be deserialized from ByteTree into an object with child
Expand All @@ -48,32 +48,34 @@ protocol ByteTreeObjectDecodable {
/// - numFields: The number of fields that are present in the serialized
/// object
/// - Returns: The deserialized object
static func read(from reader: ByteTreeObjectReader, numFields: Int,
userInfo: [ByteTreeUserInfoKey: Any]) -> Self
static func read(from reader: UnsafeMutablePointer<ByteTreeObjectReader>,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So why not inout ByteTreeObjectReader?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because inout ByteTreeObjectReader still has a retain/release count associated with it.

numFields: Int,
userInfo: UnsafePointer<[ByteTreeUserInfoKey: Any]>) -> Self
}

// MARK: - Reader objects

/// Helper object for reading objects out of a ByteTree. Checks that fields
/// are not read out of order and discards all trailing fields that were present
/// in the binary format but were not handled when reading the object.
class ByteTreeObjectReader {
struct ByteTreeObjectReader {
/// The reader that holds a reference to the data from which the object is
/// read
private let reader: ByteTreeReader
private let reader: UnsafeMutablePointer<ByteTreeReader>

/// The number of fields this object is expected to have
private let numFields: Int

/// The index of the field that is expected to be read next.
private var nextIndex: Int = 0

fileprivate init(reader: ByteTreeReader, numFields: Int) {
fileprivate init(reader: UnsafeMutablePointer<ByteTreeReader>,
numFields: Int) {
self.reader = reader
self.numFields = numFields
}

private func advanceAndValidateIndex(_ index: Int) {
private mutating func advanceAndValidateIndex(_ index: Int) {
assert(index == nextIndex, "Reading fields out of order")
assert(index < numFields)
nextIndex += 1
Expand All @@ -88,11 +90,11 @@ class ByteTreeObjectReader {
/// - objectType: The type as which this field should be read
/// - index: The index of this field
/// - Returns: The decoded field
func readField<FieldType: ByteTreeScalarDecodable>(
mutating func readField<FieldType: ByteTreeScalarDecodable>(
_ objectType: FieldType.Type, index: Int
) -> FieldType {
advanceAndValidateIndex(index)
return reader.read(objectType)
return reader.pointee.read(objectType)
}

/// Read the field at the given index as the specified type. All indices must
Expand All @@ -103,23 +105,23 @@ class ByteTreeObjectReader {
/// - objectType: The type as which this field should be read
/// - index: The index of this field
/// - Returns: The decoded field
func readField<FieldType: ByteTreeObjectDecodable>(
mutating func readField<FieldType: ByteTreeObjectDecodable>(
_ objectType: FieldType.Type, index: Int
) -> FieldType {
advanceAndValidateIndex(index)
return reader.read(objectType)
return reader.pointee.read(objectType)
}

/// Read and immediately discard the field at the specified index. This
/// advances the reader by one field so that the next field can be read.
///
/// - Parameter index: The index of the field that shall be discarded
func discardField(index: Int) {
mutating func discardField(index: Int) {
advanceAndValidateIndex(index)
reader.discardField()
reader.pointee.discardField()
}

deinit {
fileprivate mutating func finalize() {
// Discard all fields that have not been read
while nextIndex < numFields {
discardField(index: nextIndex)
Expand All @@ -128,7 +130,7 @@ class ByteTreeObjectReader {
}

/// Reader for reading the ByteTree format into Swift objects
class ByteTreeReader {
struct ByteTreeReader {
enum DeserializationError: Error, CustomStringConvertible {
case versionValidationFailed(ByteTreeReader.ProtocolVersion)

Expand All @@ -147,10 +149,10 @@ class ByteTreeReader {
/// A pointer pointing to the next byte of serialized data to be read
private var pointer: UnsafeRawPointer

private var userInfo: [ByteTreeUserInfoKey: Any]
private var userInfo: UnsafePointer<[ByteTreeUserInfoKey: Any]>

private init(pointer: UnsafeRawPointer,
userInfo: [ByteTreeUserInfoKey: Any]) {
userInfo: UnsafePointer<[ByteTreeUserInfoKey: Any]>) {
self.pointer = pointer
self.userInfo = userInfo
}
Expand All @@ -170,10 +172,10 @@ class ByteTreeReader {
/// failed
static func read<T: ByteTreeObjectDecodable>(
_ rootObjectType: T.Type, from pointer: UnsafeRawPointer,
userInfo: [ByteTreeUserInfoKey: Any],
userInfo: UnsafePointer<[ByteTreeUserInfoKey: Any]>,
protocolVersionValidation: (ProtocolVersion) -> Bool
) throws -> T {
let reader = ByteTreeReader(pointer: pointer, userInfo: userInfo)
var reader = ByteTreeReader(pointer: pointer, userInfo: userInfo)
try reader.readAndValidateProtocolVersion(protocolVersionValidation)
return reader.read(rootObjectType)
}
Expand All @@ -190,7 +192,7 @@ class ByteTreeReader {
/// - Returns: The deserialized tree
static func read<T: ByteTreeObjectDecodable>(
_ rootObjectType: T.Type, from data: Data,
userInfo: [ByteTreeUserInfoKey: Any],
userInfo: UnsafePointer<[ByteTreeUserInfoKey: Any]>,
protocolVersionValidation versionValidate: (ProtocolVersion) -> Bool
) throws -> T {
return try data.withUnsafeBytes { (pointer: UnsafePointer<UInt8>) in
Expand All @@ -209,13 +211,13 @@ class ByteTreeReader {
///
/// - Parameter type: The type as which the current data should be read
/// - Returns: The read value
private func readRaw<T>(_ type: T.Type) -> T {
private mutating func readRaw<T>(_ type: T.Type) -> T {
let result = pointer.bindMemory(to: T.self, capacity: 1).pointee
pointer = pointer.advanced(by: MemoryLayout<T>.size)
return result
}

private func readFieldLength() -> (isObject: Bool, length: Int) {
private mutating func readFieldLength() -> (isObject: Bool, length: Int) {
let raw = UInt32(littleEndian: readRaw(UInt32.self))
let isObject = (raw & (UInt32(1) << 31)) != 0
let length = Int(raw & ~(UInt32(1) << 31))
Expand All @@ -226,7 +228,7 @@ class ByteTreeReader {
/// Read the number of fields in an object.
///
/// - Returns: The number of fields in the following object
private func readObjectLength() -> Int {
private mutating func readObjectLength() -> Int {
let (isObject, length) = readFieldLength()
assert(isObject)
return length
Expand All @@ -235,7 +237,7 @@ class ByteTreeReader {
/// Read the size of a scalar in bytes
///
/// - Returns: The size of the following scalar in bytes
private func readScalarLength() -> Int {
private mutating func readScalarLength() -> Int {
let (isObject, length) = readFieldLength()
assert(!isObject)
return length
Expand All @@ -246,7 +248,7 @@ class ByteTreeReader {
///
/// - Parameter validationCallback: A callback that determines if the given
/// protocol version can be read
private func readAndValidateProtocolVersion(
private mutating func readAndValidateProtocolVersion(
_ validationCallback: (ProtocolVersion) -> Bool
) throws {
let protocolVersion = ProtocolVersion(littleEndian:
Expand All @@ -261,20 +263,24 @@ class ByteTreeReader {
///
/// - Parameter objectType: The type as which the next field shall be read
/// - Returns: The deserialized object
fileprivate func read<T: ByteTreeObjectDecodable>(
fileprivate mutating func read<T: ByteTreeObjectDecodable>(
_ objectType: T.Type
) -> T {
let numFields = readObjectLength()
let objectReader = ByteTreeObjectReader(reader: self,
var objectReader = ByteTreeObjectReader(reader: &self,
numFields: numFields)
return T.read(from: objectReader, numFields: numFields, userInfo: userInfo)
defer {
objectReader.finalize()
}
return T.read(from: &objectReader, numFields: numFields,
userInfo: userInfo)
}

/// Read the next field in the tree as a scalar of the specified type.
///
/// - Parameter scalarType: The type as which the field shall be read
/// - Returns: The deserialized scalar
fileprivate func read<T: ByteTreeScalarDecodable>(
fileprivate mutating func read<T: ByteTreeScalarDecodable>(
_ scalarType: T.Type
) -> T {
let fieldSize = readScalarLength()
Expand All @@ -285,7 +291,7 @@ class ByteTreeReader {
}

/// Discard the next field (scalar or object), advancing the pointer to the
/// field that follows it
fileprivate func discardField() {
fileprivate mutating func discardField() {
let (isObject, length) = readFieldLength()
if isObject {
// Discard object by discarding all its fields
Expand All @@ -305,7 +311,7 @@ class ByteTreeReader {
// multiple types
extension ByteTreeScalarDecodable where Self : FixedWidthInteger {
static func read(from pointer: UnsafeRawPointer, size: Int,
userInfo: [ByteTreeUserInfoKey: Any]
userInfo: UnsafePointer<[ByteTreeUserInfoKey: Any]>
) -> Self {
assert(size == MemoryLayout<Self>.size)
return pointer.bindMemory(to: Self.self, capacity: 1).pointee
Expand All @@ -318,7 +324,8 @@ extension UInt32: ByteTreeScalarDecodable {}

extension String: ByteTreeScalarDecodable {
static func read(from pointer: UnsafeRawPointer, size: Int,
userInfo: [ByteTreeUserInfoKey: Any]) -> String {
userInfo: UnsafePointer<[ByteTreeUserInfoKey: Any]>
) -> String {
let data = Data(bytes: pointer, count: size)
return String(data: data, encoding: .utf8)!
}
Expand All @@ -327,8 +334,9 @@ extension String: ByteTreeScalarDecodable {
extension Optional: ByteTreeObjectDecodable
where
Wrapped: ByteTreeObjectDecodable {
static func read(from reader: ByteTreeObjectReader, numFields: Int,
userInfo: [ByteTreeUserInfoKey: Any]
static func read(from reader: UnsafeMutablePointer<ByteTreeObjectReader>,
numFields: Int,
userInfo: UnsafePointer<[ByteTreeUserInfoKey: Any]>
) -> Optional<Wrapped> {
if numFields == 0 {
return nil
Expand All @@ -342,11 +350,12 @@ extension Optional: ByteTreeObjectDecodable
extension Array: ByteTreeObjectDecodable
where
Element: ByteTreeObjectDecodable {
static func read(from reader: ByteTreeObjectReader, numFields: Int,
userInfo: [ByteTreeUserInfoKey: Any]
static func read(from reader: UnsafeMutablePointer<ByteTreeObjectReader>,
numFields: Int,
userInfo: UnsafePointer<[ByteTreeUserInfoKey: Any]>
) -> Array<Element> {
return (0..<numFields).map {
return reader.readField(Element.self, index: $0)
return reader.pointee.readField(Element.self, index: $0)
}
}
}
52 changes: 31 additions & 21 deletions tools/SwiftSyntax/RawSyntax.swift
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@ extension ByteTreeUserInfoKey {
ByteTreeUserInfoKey(rawValue: "SwiftSyntax.RawSyntax.OmittedNodeLookup")
}

/// Boxes a value into a reference type.
///
/// The wrapped value is stored once at initialization and cannot be replaced
/// afterwards. The class is `final` because it is a plain wrapper that is not
/// meant to be subclassed.
final class Box<T> {
  /// The wrapped value.
  let value: T

  /// Creates a box that wraps `value`.
  ///
  /// - Parameter value: The value to store in the box.
  init(_ value: T) {
    self.value = value
  }
}

/// An ID that uniquely identifies a syntax node and stays stable across multiple
/// incremental parses
public struct SyntaxNodeId: Hashable, Codable {
Expand Down Expand Up @@ -66,7 +75,7 @@ public struct SyntaxNodeId: Hashable, Codable {
}

/// The data that is specific to a tree or token node
fileprivate indirect enum RawSyntaxData {
fileprivate enum RawSyntaxData {
/// A tree node with a kind and an array of children
case node(kind: SyntaxKind, layout: [RawSyntax?])
/// A token with a token kind, leading trivia, and trailing trivia
Expand All @@ -84,11 +93,11 @@ struct RawSyntax: Codable {
/// incremental parses
let id: SyntaxNodeId

var _contentLength = AtomicCache<SourceLength>()
var _contentLength = AtomicCache<Box<SourceLength>>()

/// The length of this node excluding its leading and trailing trivia
var contentLength: SourceLength {
return _contentLength.value() {
return _contentLength.value({
switch data {
case .node(kind: _, layout: let layout):
let firstElementIndex = layout.firstIndex(where: { $0 != nil })
Expand All @@ -109,11 +118,11 @@ struct RawSyntax: Codable {
contentLength += element.trailingTriviaLength
}
}
return contentLength
return Box(contentLength)
case .token(kind: let kind, leadingTrivia: _, trailingTrivia: _):
return SourceLength(of: kind.text)
return Box(SourceLength(of: kind.text))
}
}
}).value
}

init(kind: SyntaxKind, layout: [RawSyntax?], presence: SourcePresence,
Expand Down Expand Up @@ -387,7 +396,7 @@ extension RawSyntax: ByteTreeObjectDecodable {
case omitted = 2

static func read(from pointer: UnsafeRawPointer, size: Int,
userInfo: [ByteTreeUserInfoKey: Any]
userInfo: UnsafePointer<[ByteTreeUserInfoKey: Any]>
) -> SyntaxType {
let rawValue = UInt8.read(from: pointer, size: size, userInfo: userInfo)
guard let type = SyntaxType(rawValue: rawValue) else {
Expand All @@ -397,29 +406,30 @@ extension RawSyntax: ByteTreeObjectDecodable {
}
}

static func read(from reader: ByteTreeObjectReader, numFields: Int,
userInfo: [ByteTreeUserInfoKey: Any]
static func read(from reader: UnsafeMutablePointer<ByteTreeObjectReader>,
numFields: Int,
userInfo: UnsafePointer<[ByteTreeUserInfoKey: Any]>
) -> RawSyntax {
let syntaxNode: RawSyntax
let type = reader.readField(SyntaxType.self, index: 0)
let id = reader.readField(SyntaxNodeId.self, index: 1)
let type = reader.pointee.readField(SyntaxType.self, index: 0)
let id = reader.pointee.readField(SyntaxNodeId.self, index: 1)
switch type {
case .token:
let presence = reader.readField(SourcePresence.self, index: 2)
let kind = reader.readField(TokenKind.self, index: 3)
let leadingTrivia = reader.readField(Trivia.self, index: 4)
let trailingTrivia = reader.readField(Trivia.self, index: 5)
let presence = reader.pointee.readField(SourcePresence.self, index: 2)
let kind = reader.pointee.readField(TokenKind.self, index: 3)
let leadingTrivia = reader.pointee.readField(Trivia.self, index: 4)
let trailingTrivia = reader.pointee.readField(Trivia.self, index: 5)
syntaxNode = RawSyntax(kind: kind, leadingTrivia: leadingTrivia,
trailingTrivia: trailingTrivia,
presence: presence, id: id)
case .layout:
let presence = reader.readField(SourcePresence.self, index: 2)
let kind = reader.readField(SyntaxKind.self, index: 3)
let layout = reader.readField([RawSyntax?].self, index: 4)
let presence = reader.pointee.readField(SourcePresence.self, index: 2)
let kind = reader.pointee.readField(SyntaxKind.self, index: 3)
let layout = reader.pointee.readField([RawSyntax?].self, index: 4)
syntaxNode = RawSyntax(kind: kind, layout: layout, presence: presence,
id: id)
case .omitted:
guard let lookupFunc = userInfo[.omittedNodeLookupFunction] as?
guard let lookupFunc = userInfo.pointee[.omittedNodeLookupFunction] as?
(SyntaxNodeId) -> RawSyntax? else {
fatalError("omittedNodeLookupFunction is required when decoding an " +
"incrementally transferred syntax tree")
Expand All @@ -429,7 +439,7 @@ extension RawSyntax: ByteTreeObjectDecodable {
}
syntaxNode = lookupNode
}
if let callback = userInfo[.rawSyntaxDecodedCallback] as?
if let callback = userInfo.pointee[.rawSyntaxDecodedCallback] as?
(RawSyntax) -> Void {
callback(syntaxNode)
}
Expand All @@ -439,7 +449,7 @@ extension RawSyntax: ByteTreeObjectDecodable {

extension SyntaxNodeId: ByteTreeScalarDecodable {
static func read(from pointer: UnsafeRawPointer, size: Int,
userInfo: [ByteTreeUserInfoKey: Any]
userInfo: UnsafePointer<[ByteTreeUserInfoKey: Any]>
) -> SyntaxNodeId {
let rawValue = UInt32.read(from: pointer, size: size, userInfo: userInfo)
return SyntaxNodeId(rawValue: UInt(rawValue))
Expand Down
Loading