Restrict the maximum nesting level in the parser to avoid stack overflows #1030

Merged
merged 2 commits on Oct 28, 2022
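In short: the parser now counts how deeply it has descended into bracketed constructs and, once a configurable maximumNestingLevel is exceeded, each parsing routine returns a mostly-missing node carrying the leftover tokens as unexpected nodes instead of recursing further. A minimal sketch of how a client might exercise the new Parser.parse(source:maximumNestingLevel:) entry point added in Parser.swift below; the deeply nested input string is hypothetical, chosen only to trigger the limit:

import SwiftParser
import SwiftSyntax

// Hypothetical driver: an absurdly nested expression that previously could
// overflow the parser's stack now produces a tree with missing/unexpected
// nodes instead of crashing.
var source = String(repeating: "(", count: 50_000) + "0" + String(repeating: ")", count: 50_000)
let tree: SourceFileSyntax = source.withUTF8 { buffer in
  Parser.parse(source: buffer, maximumNestingLevel: 256)
}
// SwiftParser trees are source-accurate, so the original text still round-trips.
print(tree.description.utf8.count == source.utf8.count)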
Changes from all commits
22 changes: 21 additions & 1 deletion Sources/SwiftParser/Declarations.swift
@@ -363,8 +363,18 @@ extension Parser {

@_spi(RawSyntax)
public mutating func parseGenericParameters() -> RawGenericParameterClauseSyntax {
if let remainingTokens = remainingTokensIfMaximumNestingLevelReached() {
return RawGenericParameterClauseSyntax(
remainingTokens,
leftAngleBracket: missingToken(.leftAngle),
genericParameterList: RawGenericParameterListSyntax(elements: [], arena: self.arena),
genericWhereClause: nil,
rightAngleBracket: missingToken(.rightAngle),
arena: self.arena
)
}

assert(self.currentToken.starts(with: "<"))
let langle = self.consumeAnyToken(remapping: .leftAngle)
var elements = [RawGenericParameterSyntax]()
do {
@@ -616,6 +626,16 @@ extension Parser {
extension Parser {
@_spi(RawSyntax)
public mutating func parseMemberDeclListItem() -> RawMemberDeclListItemSyntax? {
if let remainingTokens = remainingTokensIfMaximumNestingLevelReached() {
let item = RawMemberDeclListItemSyntax(
remainingTokens,
decl: RawDeclSyntax(RawMissingDeclSyntax(attributes: nil, modifiers: nil, arena: self.arena)),
semicolon: nil,
arena: self.arena
)
return item
}

let decl: RawDeclSyntax
if self.at(.poundSourceLocationKeyword) {
decl = RawDeclSyntax(self.parsePoundSourceLocationDirective())
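Each of these early exits goes through remainingTokensIfMaximumNestingLevelReached(), a helper defined in Parser.swift that is not part of this excerpt. Inferring only from the call sites above, a simplified sketch of the shape such a helper could take (an assumption, not the actual implementation; the real helper is likely more careful about where it stops consuming):

extension Parser {
  // Sketch/assumption: once the nesting limit has been exceeded, gather the
  // tokens remaining in the current construct into unexpected nodes so the
  // caller can return an otherwise-missing node and unwind without recursing.
  mutating func remainingTokensIfMaximumNestingLevelReached() -> RawUnexpectedNodesSyntax? {
    guard self.nestingLevel > self.maximumNestingLevel,
          self.currentToken.tokenKind != .eof else {
      return nil
    }
    var skipped: [RawSyntax] = []
    while self.currentToken.tokenKind != .eof {  // simplification: drain to end of file
      skipped.append(RawSyntax(self.consumeAnyToken()))
    }
    return RawUnexpectedNodesSyntax(elements: skipped, arena: self.arena)
  }
}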
9 changes: 9 additions & 0 deletions Sources/SwiftParser/Directives.swift
@@ -70,6 +70,15 @@ extension Parser {
addSemicolonIfNeeded: (_ lastElement: Element, _ newItemAtStartOfLine: Bool, _ parser: inout Parser) -> Element? = { _, _, _ in nil },
syntax: (inout Parser, [Element]) -> RawSyntax?
) -> RawIfConfigDeclSyntax {
if let remainingTokens = remainingTokensIfMaximumNestingLevelReached() {
return RawIfConfigDeclSyntax(
remainingTokens,
clauses: RawIfConfigClauseListSyntax(elements: [], arena: self.arena),
poundEndif: missingToken(.poundEndifKeyword),
arena: self.arena
)
}

var clauses = [RawIfConfigClauseSyntax]()
do {
var firstIteration = true
21 changes: 21 additions & 0 deletions Sources/SwiftParser/Expressions.swift
@@ -1757,6 +1757,16 @@ extension Parser {
/// dictionary-literal-items → dictionary-literal-item ','? | dictionary-literal-item ',' dictionary-literal-items
@_spi(RawSyntax)
public mutating func parseCollectionLiteral() -> RawExprSyntax {
if let remainingTokens = remainingTokensIfMaximumNestingLevelReached() {
return RawExprSyntax(RawArrayExprSyntax(
remainingTokens,
leftSquare: missingToken(.leftSquareBracket),
elements: RawArrayElementListSyntax(elements: [], arena: self.arena),
rightSquare: missingToken(.rightSquareBracket),
arena: self.arena
))
}

let (unexpectedBeforeLSquare, lsquare) = self.expect(.leftSquareBracket)

if let rsquare = self.consume(if: .rightSquareBracket) {
@@ -2177,6 +2187,17 @@ extension Parser {
/// tuple-element → expression | identifier ':' expression
@_spi(RawSyntax)
public mutating func parseArgumentListElements(pattern: PatternContext) -> [RawTupleExprElementSyntax] {
if let remainingTokens = remainingTokensIfMaximumNestingLevelReached() {
return [RawTupleExprElementSyntax(
remainingTokens,
label: nil,
colon: nil,
expression: RawExprSyntax(RawMissingExprSyntax(arena: self.arena)),
trailingComma: nil,
arena: self.arena
)]
}

guard !self.at(.rightParen) else {
return []
}
147 changes: 87 additions & 60 deletions Sources/SwiftParser/Lookahead.swift
@@ -335,9 +335,7 @@ extension Parser.Lookahead {

extension Parser.Lookahead {
mutating func skipUntil(_ t1: RawTokenKind, _ t2: RawTokenKind) {
return skip(initialState: .skipUntil(t1, t2))
}

mutating func skipUntilEndOfLine() {
@@ -347,70 +345,99 @@
}

mutating func skipSingle() {
return skip(initialState: .skipSingle)
}

private enum BracketedTokens: RawTokenKindSubset {
case leftParen
case leftBrace
case leftSquareBracket
case poundIfKeyword
case poundElseKeyword
case poundElseifKeyword

init?(lexeme: Lexer.Lexeme) {
switch lexeme.tokenKind {
case .leftParen: self = .leftParen
case .leftBrace: self = .leftBrace
case .leftSquareBracket: self = .leftSquareBracket
case .poundIfKeyword: self = .poundIfKeyword
case .poundElseKeyword: self = .poundElseKeyword
case .poundElseifKeyword: self = .poundElseifKeyword
default: return nil
}
}

var rawTokenKind: RawTokenKind {
switch self {
case .leftParen: return .leftParen
case .leftBrace: return .leftBrace
case .leftSquareBracket: return .leftSquareBracket
case .poundIfKeyword: return .poundIfKeyword
case .poundElseKeyword: return .poundElseKeyword
case .poundElseifKeyword: return .poundElseifKeyword
}
}
}

private enum SkippingState {
/// Equivalent to a call to `skipSingle`. Skip a single token.
/// If that token is bracketed, skip until the closing bracket
case skipSingle
/// Execute code after skipping bracketed tokens detected from `skipSingle`.
case skipSinglePost(start: BracketedTokens)
/// Skip until either `t1` or `t2`.
case skipUntil(_ t1: RawTokenKind, _ t2: RawTokenKind)
}

/// A non-recursive function to skip tokens.
private mutating func skip(initialState: SkippingState) {
var stack: [SkippingState] = [initialState]

while let state = stack.popLast() {
switch state {
case .skipSingle:
let t = self.at(anyIn: BracketedTokens.self)
switch t {
case (.leftParen, let handle)?:
self.eat(handle)
stack += [.skipSinglePost(start: .leftParen), .skipUntil(.rightParen, .rightBrace)]
case (.leftBrace, let handle)?:
self.eat(handle)
stack += [.skipSinglePost(start: .leftBrace), .skipUntil(.rightBrace, .rightBrace)]
case (.leftSquareBracket, let handle)?:
self.eat(handle)
stack += [.skipSinglePost(start: .leftSquareBracket), .skipUntil(.rightSquareBracket, .rightSquareBracket)]
case (.poundIfKeyword, let handle)?,
(.poundElseKeyword, let handle)?,
(.poundElseifKeyword, let handle)?:
self.eat(handle)
// skipUntil also implicitly stops at tok::pound_endif.
stack += [.skipSinglePost(start: t!.0), .skipUntil(.poundElseKeyword, .poundElseifKeyword)]
case nil:
self.consumeAnyToken()
}
case .skipSinglePost(start: let start):
switch start {
case .leftParen:
self.consume(if: .rightParen)
case .leftBrace:
self.consume(if: .rightBrace)
case .leftSquareBracket:
self.consume(if: .rightSquareBracket)
case .poundIfKeyword, .poundElseKeyword, .poundElseifKeyword:
if self.at(any: [.poundElseKeyword, .poundElseifKeyword]) {
stack += [.skipSingle]
} else {
self.consume(if: .poundElseifKeyword)
}
}
case .skipUntil(let t1, let t2):
if !self.at(any: [.eof, t1, t2, .poundEndifKeyword, .poundElseKeyword, .poundElseifKeyword]) {
stack += [.skipUntil(t1, t2), .skipSingle]
}
}
}
}
}
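The Lookahead change applies the same stack-overflow fix to token skipping: the mutually recursive skipSingle/skipUntil pair becomes a single loop driven by an explicit work stack of SkippingState values, so skipping arbitrarily deep brackets consumes heap memory rather than call-stack frames. A toy illustration of the same transformation on a simpler problem, unrelated to the parser's actual types (the function and input are made up for illustration):

enum SkipState {
  case skipSingle      // skip one item; if it opens "(", also skip to the matching ")"
  case skipUntilClose  // keep skipping items until the matching ")" is consumed
}

// Iterative skipping driven by an explicit work stack, mirroring the shape of
// skip(initialState:) above: instead of skipSingle and skipUntil calling each
// other recursively, pending work is pushed onto `stack` and popped in a loop.
func skipSingleItem(_ chars: [Character], from index: Int) -> Int {
  var i = index
  var stack: [SkipState] = [.skipSingle]
  while let state = stack.popLast() {
    switch state {
    case .skipSingle:
      if i < chars.count, chars[i] == "(" {
        i += 1
        stack.append(.skipUntilClose)  // finish the bracketed group first
      } else if i < chars.count {
        i += 1  // a plain character is a single item
      }
    case .skipUntilClose:
      if i < chars.count, chars[i] != ")" {
        // More items inside the brackets: skip one, then come back here.
        stack += [.skipUntilClose, .skipSingle]
      } else if i < chars.count {
        i += 1  // consume the closing ")"
      }
    }
  }
  return i
}

// skipSingleItem(Array("(a(bc)d)e"), from: 0) == 8: the whole bracketed group
// is skipped, "e" is not, and deeper nesting never grows the call stack.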
50 changes: 47 additions & 3 deletions Sources/SwiftParser/Parser.swift
@@ -26,11 +26,16 @@ extension Parser {
}

/// Parse the source code in the given string as a Swift source file.
/// If `maximumNestingLevel` is set, the parser stops once a nesting level
/// greater than this value is reached, to avoid overflowing the stack.
/// The nesting level is increased whenever a bracketed expression like `(`
/// or `{` is started.
public static func parse(
source: UnsafeBufferPointer<UInt8>,
maximumNestingLevel: Int? = nil,
parseTransition: IncrementalParseTransition? = nil
) -> SourceFileSyntax {
var parser = Parser(source, maximumNestingLevel: maximumNestingLevel)
// Extended lifetime is required because `SyntaxArena` in the parser must
// be alive until `Syntax(raw:)` retains the arena.
return withExtendedLifetime(parser) {
@@ -122,6 +127,23 @@ public struct Parser: TokenConsumer {
@_spi(RawSyntax)
public var currentToken: Lexer.Lexeme

/// The current nesting level, i.e. the number of tokens consumed so far
/// that start a new nesting level minus the number of tokens consumed
/// that end one.
public var nestingLevel: Int = 0

/// When this nesting level is exceeded, the parser should stop parsing.
public let maximumNestingLevel: Int

/// A default maximum nesting level that is used if the client didn't
/// explicitly specify one. Debug builds of the parser consume a lot more stack
/// space and thus have a lower default maximum nesting level.
#if DEBUG
public static let defaultMaximumNestingLevel = 25
#else
public static let defaultMaximumNestingLevel = 256
#endif

Inline review (Member): I know it's terrible, but should we do some kind of #if DEBUG check to decide whether to use the lower limit or higher limit?

Reply (Member, Author): That’s actually a good idea. The entire check is pretty horrible in itself, so a #if DEBUG check doesn’t make a huge difference anymore.

/// Initializes a Parser from the given input buffer.
///
/// The lexer will copy any string data it needs from the resulting buffer
@@ -133,7 +155,9 @@
/// arena is created automatically, and `input` copied into the
/// arena. If non-`nil`, `input` must be the registered source
/// buffer of `arena` or a slice of the source buffer.
public init(_ input: UnsafeBufferPointer<UInt8>, maximumNestingLevel: Int? = nil, arena: SyntaxArena? = nil) {
self.maximumNestingLevel = maximumNestingLevel ?? Self.defaultMaximumNestingLevel

var sourceBuffer: UnsafeBufferPointer<UInt8>
if let arena = arena {
self.arena = arena
@@ -150,6 +174,7 @@

@_spi(RawSyntax)
public mutating func missingToken(_ kind: RawTokenKind, text: SyntaxText? = nil) -> RawTokenSyntax {
adjustNestingLevel(for: kind)
return RawTokenSyntax(missing: kind, text: text, arena: self.arena)
}

@@ -158,6 +183,12 @@
/// - Returns: The token that was consumed.
@_spi(RawSyntax)
public mutating func consumeAnyToken() -> RawTokenSyntax {
adjustNestingLevel(for: self.currentToken.tokenKind)
return self.consumeAnyTokenWithoutAdjustingNestingLevel()
}

@_spi(RawSyntax)
public mutating func consumeAnyTokenWithoutAdjustingNestingLevel() -> RawTokenSyntax {
let tok = self.currentToken
self.currentToken = self.lexemes.advance()
return RawTokenSyntax(
@@ -168,6 +199,17 @@
arena: arena
)
}

private mutating func adjustNestingLevel(for tokenKind: RawTokenKind) {
switch tokenKind {
case .leftAngle, .leftBrace, .leftParen, .leftSquareBracket, .poundIfKeyword:
nestingLevel += 1
case .rightAngle, .rightBrace, .rightParen, .rightSquareBracket, .poundEndifKeyword:
nestingLevel -= 1
default:
break
}
}
}
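To make the bookkeeping concrete, here is how the counter evolves for a small, made-up input, using the token kinds handled in adjustNestingLevel(for:) above:

// Source tokens:   foo   (   [   {   }   ]   )
// nestingLevel:    0     1   2   3   2   1   0
//
// The parser only bails out once nestingLevel exceeds maximumNestingLevel
// (25 in debug builds, 256 in release builds, per the defaults above).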

// MARK: Inspecting Tokens
@@ -279,7 +321,7 @@ extension Parser {
if handle.unexpectedTokens > 0 {
var unexpectedTokens = [RawSyntax]()
for _ in 0..<handle.unexpectedTokens {
unexpectedTokens.append(RawSyntax(self.consumeAnyTokenWithoutAdjustingNestingLevel()))
}
unexpectedNodes = RawUnexpectedNodesSyntax(elements: unexpectedTokens, arena: self.arena)
} else {
@@ -512,6 +554,8 @@ extension Parser {
arena: self.arena
)

self.adjustNestingLevel(for: tokenKind)

// ... or a multi-character token with the first N characters being the one
// that we want to consume as a separate token.
// Careful: We need to reset the lexer to a point just before it saw the
12 changes: 12 additions & 0 deletions Sources/SwiftParser/Patterns.swift
@@ -179,6 +179,18 @@ extension Parser {
/// tuple-pattern-element-list → tuple-pattern-element | tuple-pattern-element ',' tuple-pattern-element-list
/// tuple-pattern-element → pattern | identifier ':' pattern
mutating func parsePatternTupleElements() -> RawTuplePatternElementListSyntax {
if let remainingTokens = remainingTokensIfMaximumNestingLevelReached() {
return RawTuplePatternElementListSyntax(elements: [
RawTuplePatternElementSyntax(
remainingTokens,
labelName: nil,
labelColon: nil,
pattern: RawPatternSyntax(RawMissingPatternSyntax(arena: self.arena)),
trailingComma: nil,
arena: self.arena
)
], arena: self.arena)
}
var elements = [RawTuplePatternElementSyntax]()
do {
var keepGoing = true