Skip to content

Fix two assertion failures related to invalid UTF-8 #1286

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Sources/SwiftParser/Lexer/Cursor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2102,7 +2102,7 @@ extension Lexer.Cursor {
/// valid operator start, advance the cursor by what can be considered a
/// lexeme.
mutating func lexUnknown() -> UnknownCharactersClassification {
assert(self.peekScalar()?.isValidIdentifierStartCodePoint == false && self.peekScalar()?.isOperatorStartCodePoint == false)
assert(!(self.peekScalar()?.isValidIdentifierStartCodePoint ?? false) && !(self.peekScalar()?.isOperatorStartCodePoint ?? false))
var tmp = self
if tmp.advance(if: { Unicode.Scalar($0).isValidIdentifierContinuationCodePoint }) {
// If this is a valid identifier continuation, but not a valid identifier
Expand Down
33 changes: 17 additions & 16 deletions Sources/SwiftSyntax/SourceLocation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,9 @@ public final class SourceLocationConverter {
public init(file: String, source: String) {
self.file = file
self.source = Array(source.utf8)
(self.lines, endOfFile) = computeLines(source)
(self.lines, endOfFile) = self.source.withUnsafeBufferPointer { buf in
return computeLines(SyntaxText(buffer: buf))
}
assert(source.utf8.count == endOfFile.utf8Offset)
}

Expand Down Expand Up @@ -397,7 +399,7 @@ fileprivate func computeLines(
return (lines, position)
}

fileprivate func computeLines(_ source: String) -> ([AbsolutePosition], AbsolutePosition) {
fileprivate func computeLines(_ source: SyntaxText) -> ([AbsolutePosition], AbsolutePosition) {
var lines: [AbsolutePosition] = []
// First line starts from the beginning.
lines.append(.startOfFile)
Expand All @@ -412,26 +414,25 @@ fileprivate func computeLines(_ source: String) -> ([AbsolutePosition], Absolute
return (lines, position)
}

fileprivate extension String {
fileprivate extension SyntaxText {
/// Walks and passes to `body` the `SourceLength` for every detected line,
/// with the newline character included.
/// - Returns: The leftover `SourceLength` at the end of the walk.
func forEachLineLength(
prefix: SourceLength = .zero,
body: (SourceLength) -> ()
) -> SourceLength {
let utf8 = self.utf8
let startIndex = utf8.startIndex
let endIndex = utf8.endIndex
// let startIndex = utf8.startIndex
// let endIndex = utf8.endIndex
var curIdx = startIndex
var lineLength = prefix
let advanceLengthByOne = { () -> () in
lineLength += SourceLength(utf8Length: 1)
curIdx = utf8.index(after: curIdx)
curIdx = self.index(after: curIdx)
}

while curIdx < endIndex {
let char = utf8[curIdx]
let char = self[curIdx]
advanceLengthByOne()

/// From https://docs.swift.org/swift-book/ReferenceManual/LexicalStructure.html#grammar_line-break
Expand All @@ -441,7 +442,7 @@ fileprivate extension String {
let isNewline = { () -> Bool in
if char == 10 { return true }
if char == 13 {
if curIdx < endIndex && utf8[curIdx] == 10 { advanceLengthByOne() }
if curIdx < endIndex && self[curIdx] == 10 { advanceLengthByOne() }
return true
}
return false
Expand All @@ -456,11 +457,11 @@ fileprivate extension String {
}

func containsSwiftNewline() -> Bool {
return utf8.contains { $0 == 10 || $0 == 13 }
return self.contains { $0 == 10 || $0 == 13 }
}
}

fileprivate extension TriviaPiece {
fileprivate extension RawTriviaPiece {
/// Walks and passes to `body` the `SourceLength` for every detected line,
/// with the newline character included.
/// - Returns: The leftover `SourceLength` at the end of the walk.
Expand Down Expand Up @@ -495,7 +496,7 @@ fileprivate extension TriviaPiece {
let .docLineComment(text):
// Line comments are not supposed to contain newlines.
assert(!text.containsSwiftNewline(), "line comment created that contained a new-line character")
lineLength += SourceLength(utf8Length: text.utf8.count)
lineLength += SourceLength(utf8Length: text.count)
case let .blockComment(text),
let .docBlockComment(text),
let .unexpectedText(text):
Expand All @@ -505,7 +506,7 @@ fileprivate extension TriviaPiece {
}
}

fileprivate extension Trivia {
fileprivate extension Array where Element == RawTriviaPiece {
/// Walks and passes to `body` the `SourceLength` for every detected line,
/// with the newline character included.
/// - Returns: The leftover `SourceLength` at the end of the walk.
Expand All @@ -530,9 +531,9 @@ fileprivate extension TokenSyntax {
body: (SourceLength) -> ()
) -> SourceLength {
var curPrefix = prefix
curPrefix = self.leadingTrivia.forEachLineLength(prefix: curPrefix, body: body)
curPrefix = self.text.forEachLineLength(prefix: curPrefix, body: body)
curPrefix = self.trailingTrivia.forEachLineLength(prefix: curPrefix, body: body)
curPrefix = self.tokenView.leadingRawTriviaPieces.forEachLineLength(prefix: curPrefix, body: body)
curPrefix = self.tokenView.rawText.forEachLineLength(prefix: curPrefix, body: body)
curPrefix = self.tokenView.trailingRawTriviaPieces.forEachLineLength(prefix: curPrefix, body: body)
return curPrefix
}
}
19 changes: 11 additions & 8 deletions Sources/swift-parser-cli/swift-parser-cli.swift
Original file line number Diff line number Diff line change
Expand Up @@ -134,15 +134,19 @@ class VerifyRoundTrip: ParsableCommand {
) throws {
let tree = Parser.parse(source: source)

_ = ParseDiagnosticsGenerator.diagnostics(for: tree)
var diags = ParseDiagnosticsGenerator.diagnostics(for: tree)

let resultTree: Syntax
if foldSequences {
resultTree = foldAllSequences(tree).0
let folded = foldAllSequences(tree)
resultTree = folded.0
diags += folded.1
} else {
resultTree = Syntax(tree)
}

_ = DiagnosticsFormatter.annotatedSource(tree: tree, diags: diags)

if resultTree.syntaxTextBytes != [UInt8](source) {
throw Error.roundTripFailed
}
Expand Down Expand Up @@ -207,6 +211,9 @@ class PrintDiags: ParsableCommand {
source.withUnsafeBufferPointer { sourceBuffer in
let tree = Parser.parse(source: sourceBuffer)
var diags = ParseDiagnosticsGenerator.diagnostics(for: tree)
if foldSequences {
diags += foldAllSequences(tree).1
}
let annotatedSource = DiagnosticsFormatter.annotatedSource(
tree: tree,
diags: diags,
Expand All @@ -215,10 +222,6 @@ class PrintDiags: ParsableCommand {

print(annotatedSource)

if foldSequences {
diags += foldAllSequences(tree).1
}

if diags.isEmpty {
print("No diagnostics produced")
}
Expand Down Expand Up @@ -424,7 +427,7 @@ class Reduce: ParsableCommand {
if verbose {
printerr("Reduced from \(source.count) to \(reduced.count) characters in \(checks) iterations")
}
let reducedString = String(decoding: reduced, as: UTF8.self)
print(reducedString)

FileHandle.standardOutput.write(Data(reduced))
}
}
142 changes: 92 additions & 50 deletions Tests/SwiftParserTest/LexerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,45 @@ import XCTest
@_spi(RawSyntax) import SwiftSyntax
@_spi(RawSyntax) import SwiftParser

fileprivate func lex(_ sourceBytes: [UInt8]) -> [Lexer.Lexeme] {
return sourceBytes.withUnsafeBufferPointer { buf in
var lexemes = [Lexer.Lexeme]()
for token in Lexer.tokenize(buf, from: 0) {
lexemes.append(token)

if token.rawTokenKind == .eof {
break
}
}
return lexemes
}
}

/// `LexemeSpec` heavily relies on string literals to represent the expected
/// values for trivia and text. While this is good for most cases, string
/// literals can't contain invalid UTF-8. Thus, we need a different assert
/// function working on byte arrays to test source code containing invalid UTF-8.
fileprivate func AssertRawBytesLexeme(
_ lexeme: Lexer.Lexeme,
kind: RawTokenKind,
leadingTrivia: [UInt8] = [],
text: [UInt8],
trailingTrivia: [UInt8] = [],
file: StaticString = #file,
line: UInt = #line
) {
XCTAssertEqual(lexeme.rawTokenKind, kind, file: file, line: line)
leadingTrivia.withUnsafeBufferPointer { leadingTrivia in
XCTAssertEqual(lexeme.leadingTriviaText, SyntaxText(buffer: leadingTrivia), file: file, line: line)
}
text.withUnsafeBufferPointer { text in
XCTAssertEqual(lexeme.tokenText, SyntaxText(buffer: text), file: file, line: line)
}
trailingTrivia.withUnsafeBufferPointer { trailingTrivia in
XCTAssertEqual(lexeme.trailingTriviaText, SyntaxText(buffer: trailingTrivia), file: file, line: line)
}
}

public class LexerTests: XCTestCase {
func testIdentifiers() {
AssertLexemes(
Expand Down Expand Up @@ -756,76 +795,79 @@ public class LexerTests: XCTestCase {

func testBOMAtStartOfFile() throws {
let sourceBytes: [UInt8] = [0xef, 0xbb, 0xbf]
let lexemes = sourceBytes.withUnsafeBufferPointer { buf in
var lexemes = [Lexer.Lexeme]()
for token in Lexer.tokenize(buf, from: 0) {
lexemes.append(token)
let lexemes = lex(sourceBytes)

if token.rawTokenKind == .eof {
break
}
}
return lexemes
guard lexemes.count == 1 else {
return XCTFail("Expected 1 lexeme, got \(lexemes.count)")
}

XCTAssertEqual(lexemes.count, 1)
let lexeme = try XCTUnwrap(lexemes.first)
XCTAssertEqual(lexeme.rawTokenKind, .eof)

let bomBytes: [UInt8] = [0xef, 0xbb, 0xbf]
bomBytes.withUnsafeBufferPointer { bomBytes in
XCTAssertEqual(lexeme.leadingTriviaText, SyntaxText(buffer: bomBytes))
}
AssertRawBytesLexeme(
lexemes[0],
kind: .eof,
leadingTrivia: sourceBytes,
text: []
)
}

func testBOMInTheMiddleOfIdentifier() throws {
let sourceBytes: [UInt8] = [UInt8(ascii: "a"), 0xef, 0xbb, 0xbf, UInt8(ascii: "b")]
let lexemes = sourceBytes.withUnsafeBufferPointer { buf in
var lexemes = [Lexer.Lexeme]()
for token in Lexer.tokenize(buf, from: 0) {
lexemes.append(token)
let lexemes = lex(sourceBytes)

if token.rawTokenKind == .eof {
break
}
}
return lexemes
guard lexemes.count == 2 else {
return XCTFail("Expected 2 lexemes, got \(lexemes.count)")
}

XCTAssertEqual(lexemes.count, 2)
let lexeme = try XCTUnwrap(lexemes.first)
XCTAssertEqual(lexeme.rawTokenKind, .identifier)

sourceBytes.withUnsafeBufferPointer { sourceBytes in
XCTAssertEqual(lexeme.tokenText, SyntaxText(buffer: sourceBytes))
}
AssertRawBytesLexeme(
lexemes[0],
kind: .identifier,
text: sourceBytes
)
}

func testBOMAsLeadingTriviaInSourceFile() throws {
let sourceBytes: [UInt8] = [UInt8(ascii: "1"), UInt8(ascii: " "), UInt8(ascii: "+"), UInt8(ascii: " "), 0xef, 0xbb, 0xbf, UInt8(ascii: "2")]
let lexemes = sourceBytes.withUnsafeBufferPointer { buf in
var lexemes = [Lexer.Lexeme]()
for token in Lexer.tokenize(buf, from: 0) {
lexemes.append(token)
let lexemes = lex(sourceBytes)

if token.rawTokenKind == .eof {
break
}
}
return lexemes
guard lexemes.count == 4 else {
return XCTFail("Expected 4 lexemes, got \(lexemes.count)")
}

guard lexemes.count == 4 else {
return XCTFail("Expected 4 lexemes")
AssertRawBytesLexeme(
lexemes[1],
kind: .binaryOperator,
text: [UInt8(ascii: "+")],
trailingTrivia: [UInt8(ascii: " "), 0xef, 0xbb, 0xbf]
)
}

func testInvalidUtf8() {
let sourceBytes: [UInt8] = [0xef, 0xfb, 0xbd, 0x0a]

let lexemes = lex(sourceBytes)
guard lexemes.count == 1 else {
return XCTFail("Expected 1 lexeme, got \(lexemes.count)")
}
let lexeme = lexemes[1]
XCTAssertEqual(lexeme.rawTokenKind, .binaryOperator)
AssertRawBytesLexeme(
lexemes[0],
kind: .eof,
leadingTrivia: sourceBytes,
text: []
)
}

func testInvalidUtf8_2() {
let sourceBytes: [UInt8] = [0xfd]

let expectedTrailingTrivia: [UInt8] = [UInt8(ascii: " "), 0xef, 0xbb, 0xbf]
expectedTrailingTrivia.withUnsafeBufferPointer { expectedTrailingTrivia in
XCTAssertEqual(lexeme.trailingTriviaText, SyntaxText(buffer: expectedTrailingTrivia))
XCTAssertEqual(lexeme.tokenText, "+")
let lexemes = lex(sourceBytes)
guard lexemes.count == 1 else {
return XCTFail("Expected 1 lexeme, got \(lexemes.count)")
}
AssertRawBytesLexeme(
lexemes[0],
kind: .eof,
leadingTrivia: sourceBytes,
text: []
)
}

func testInterpolatedString() {
Expand Down
46 changes: 46 additions & 0 deletions Tests/SwiftSyntaxTest/SourceLocationConverterTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2023 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//

import XCTest
@_spi(RawSyntax) import SwiftSyntax

final class SourceLocationConverterTests: XCTestCase {
func testInvalidUtf8() {
let eofToken = withExtendedLifetime(SyntaxArena()) { arena in
let leadingTriviaText = [UInt8(0xfd)].withUnsafeBufferPointer { buf in
arena.intern(SyntaxText(buffer: buf))
}

let nodeWithInvalidUtf8 = RawTokenSyntax(
kind: .eof,
text: "",
leadingTriviaPieces: [
.unexpectedText(leadingTriviaText)
],
presence: .present,
arena: arena
)

return Syntax(raw: nodeWithInvalidUtf8.raw).cast(TokenSyntax.self)
}

let tree = SourceFileSyntax(statements: [], eofToken: eofToken)

// This used to violate the following assertion in the SourceLocationConverter's
// initializer, because we were using `String` which was lossy when handling the
// invalid UTF-8:
// ```
// assert(tree.byteSize == endOfFile.utf8Offset)
// ```
_ = SourceLocationConverter(file: "", tree: tree)
}
}