Skip to content

Parse @derivative(of:wrt:) and @transpose(of:wrt:) #791

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions Sources/SwiftParser/Attributes.swift
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ extension Parser {
return RawSyntax(self.parseAvailabilityAttribute())
case .differentiable:
return RawSyntax(self.parseDifferentiableAttribute())
case .derivative:
return RawSyntax(self.parseDerivativeAttribute())
case .transpose:
return RawSyntax(self.parseTransposeAttribute())
case .objc:
return RawSyntax(self.parseObjectiveCAttribute())
case ._specialize:
Expand Down Expand Up @@ -296,6 +300,84 @@ extension Parser {
}
}

extension Parser {
mutating func parseDerivativeAttribute() -> RawAttributeSyntax {
let (unexpectedBeforeAtSign, atSign) = self.expect(.atSign)
let (unexpectedBeforeDerivative, derivative) = self.expectContextualKeyword("derivative")

let (unexpectedBeforeLeftParen, leftParen) = self.expect(.leftParen)
let argument = self.parseDerivativeAttributeArguments()
let (unexpectedBeforeRightParen, rightParen) = self.expect(.rightParen)

return RawAttributeSyntax(
unexpectedBeforeAtSign,
atSignToken: atSign,
unexpectedBeforeDerivative,
attributeName: derivative,
unexpectedBeforeLeftParen,
leftParen: leftParen,
argument: RawSyntax(argument),
unexpectedBeforeRightParen,
rightParen: rightParen,
tokenList: nil,
arena: self.arena)
}

mutating func parseTransposeAttribute() -> RawAttributeSyntax {
let (unexpectedBeforeAtSign, atSign) = self.expect(.atSign)
let (unexpectedBeforeTranspose, transpose) = self.expectContextualKeyword("transpose")

let (unexpectedBeforeLeftParen, leftParen) = self.expect(.leftParen)
let argument = self.parseDerivativeAttributeArguments()
let (unexpectedBeforeRightParen, rightParen) = self.expect(.rightParen)

return RawAttributeSyntax(
unexpectedBeforeAtSign,
atSignToken: atSign,
unexpectedBeforeTranspose,
attributeName: transpose,
unexpectedBeforeLeftParen,
leftParen: leftParen,
argument: RawSyntax(argument),
unexpectedBeforeRightParen,
rightParen: rightParen,
tokenList: nil,
arena: self.arena)
}


mutating func parseDerivativeAttributeArguments() -> RawDerivativeRegistrationAttributeArgumentsSyntax {
let (unexpectedBeforeOfLabel, ofLabel) = self.expectContextualKeyword("of")
let (unexpectedBetweenOfLabelAndColon, colon) = self.expect(.colon)
let originalDeclName = self.parseQualifiedDeclarationName()
let period = self.consume(if: .period)
let accessor: RawTokenSyntax?
if period != nil {
accessor = self.parseAnyIdentifier()
} else {
accessor = nil
}
let comma = self.consume(if: .comma)
let diffParams: RawDifferentiabilityParamsClauseSyntax?
if comma != nil {
diffParams = self.parseDifferentiabilityParameters()
} else {
diffParams = nil
}
return RawDerivativeRegistrationAttributeArgumentsSyntax(
unexpectedBeforeOfLabel,
ofLabel: ofLabel,
unexpectedBetweenOfLabelAndColon,
colon: colon,
originalDeclName: originalDeclName,
period: period,
accessorKind: accessor,
comma: comma,
diffParams: diffParams,
arena: self.arena)
}
}

extension Parser {
mutating func parseObjectiveCAttribute() -> RawAttributeSyntax {
let (unexpectedBeforeAtSign, atSign) = self.expect(.atSign)
Expand Down
36 changes: 36 additions & 0 deletions Sources/SwiftParser/Names.swift
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,42 @@ extension Parser {
}
}

extension Parser {
mutating func parseQualifiedDeclarationName() -> RawQualifiedDeclNameSyntax {
let type: RawTypeSyntax?
let dot: RawTokenSyntax?
if self.lookahead().canParseBaseTypeForQualifiedDeclName() {
type = self.parseTypeIdentifier()
dot = self.consumePrefix(".", as: .period)
} else {
type = nil
dot = nil
}

let (name, args) = self.parseDeclNameRef([
.zeroArgCompoundNames,
.keywordsUsingSpecialNames,
.operators,
])
return RawQualifiedDeclNameSyntax(
baseType: type,
dot: dot,
name: name,
arguments: args,
arena: self.arena)
}
}

extension Parser.Lookahead {
func canParseBaseTypeForQualifiedDeclName() -> Bool {
var lookahead = self.lookahead()
guard lookahead.canParseSimpleTypeIdentifier() else {
return false
}
return lookahead.currentToken.starts(with: ".")
}
}

extension Parser.Lookahead {
func canParseArgumentLabelList() -> Bool {
var lookahead = self.lookahead()
Expand Down
65 changes: 65 additions & 0 deletions Tests/SwiftParserTest/Attributes.swift
Original file line number Diff line number Diff line change
Expand Up @@ -103,4 +103,69 @@ final class AttributeTests: XCTestCase {
@_Concurrency.MainActor(unsafe) public struct Image : SwiftUI.View {}
""")
}

func testDerivativeAttribute() {
AssertParse(
"""
@inlinable
@differentiable(reverse, wrt: self)
public func differentiableMap<Result: Differentiable>(
_ body: @differentiable(reverse) (Element) -> Result
) -> [Result] {
map(body)
}
""")

AssertParse(
"""
@inlinable
@differentiable(reverse, wrt: (self, initialResult))
public func differentiableReduce<Result: Differentiable>(
_ initialResult: Result,
_ nextPartialResult: @differentiable(reverse) (Result, Element) -> Result
) -> Result {
reduce(initialResult, nextPartialResult)
}
""")

AssertParse(
"""
@inlinable
@derivative(of: differentiableReduce)
internal func _vjpDifferentiableReduce<Result: Differentiable>(
_ initialResult: Result,
_ nextPartialResult: @differentiable(reverse) (Result, Element) -> Result
) -> (
value: Result,
pullback: (Result.TangentVector)
-> (Array.TangentVector, Result.TangentVector)
) {}
""")
}

func testTransposeAttribute() {
AssertParse(
"""
@transpose(of: +)
func addTranspose(_ v: Float) -> (Float, Float) {
return (v, v)
}
""")

AssertParse(
"""
@transpose(of: -, wrt: (0, 1))
func subtractTranspose(_ v: Float) -> (Float, Float) {
return (v, -v)
}
""")

AssertParse(
"""
@transpose(of: Float.-, wrt: (0, 1))
func subtractTranspose(_ v: Float) -> (Float, Float) {
return (v, -v)
}
""")
}
}