Skip to content

Commit b287015

Browse files
authored
Merge pull request #791 from CodaFi/derivative!
Parse @Derivative(of:wrt:) and @transpose(of:wrt:)
2 parents 50729b2 + 6f8efd8 commit b287015

File tree

3 files changed

+183
-0
lines changed

3 files changed

+183
-0
lines changed

Sources/SwiftParser/Attributes.swift

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ extension Parser {
3939
return RawSyntax(self.parseAvailabilityAttribute())
4040
case .differentiable:
4141
return RawSyntax(self.parseDifferentiableAttribute())
42+
case .derivative:
43+
return RawSyntax(self.parseDerivativeAttribute())
44+
case .transpose:
45+
return RawSyntax(self.parseTransposeAttribute())
4246
case .objc:
4347
return RawSyntax(self.parseObjectiveCAttribute())
4448
case ._specialize:
@@ -296,6 +300,84 @@ extension Parser {
296300
}
297301
}
298302

303+
extension Parser {
304+
mutating func parseDerivativeAttribute() -> RawAttributeSyntax {
305+
let (unexpectedBeforeAtSign, atSign) = self.expect(.atSign)
306+
let (unexpectedBeforeDerivative, derivative) = self.expectContextualKeyword("derivative")
307+
308+
let (unexpectedBeforeLeftParen, leftParen) = self.expect(.leftParen)
309+
let argument = self.parseDerivativeAttributeArguments()
310+
let (unexpectedBeforeRightParen, rightParen) = self.expect(.rightParen)
311+
312+
return RawAttributeSyntax(
313+
unexpectedBeforeAtSign,
314+
atSignToken: atSign,
315+
unexpectedBeforeDerivative,
316+
attributeName: derivative,
317+
unexpectedBeforeLeftParen,
318+
leftParen: leftParen,
319+
argument: RawSyntax(argument),
320+
unexpectedBeforeRightParen,
321+
rightParen: rightParen,
322+
tokenList: nil,
323+
arena: self.arena)
324+
}
325+
326+
mutating func parseTransposeAttribute() -> RawAttributeSyntax {
327+
let (unexpectedBeforeAtSign, atSign) = self.expect(.atSign)
328+
let (unexpectedBeforeTranspose, transpose) = self.expectContextualKeyword("transpose")
329+
330+
let (unexpectedBeforeLeftParen, leftParen) = self.expect(.leftParen)
331+
let argument = self.parseDerivativeAttributeArguments()
332+
let (unexpectedBeforeRightParen, rightParen) = self.expect(.rightParen)
333+
334+
return RawAttributeSyntax(
335+
unexpectedBeforeAtSign,
336+
atSignToken: atSign,
337+
unexpectedBeforeTranspose,
338+
attributeName: transpose,
339+
unexpectedBeforeLeftParen,
340+
leftParen: leftParen,
341+
argument: RawSyntax(argument),
342+
unexpectedBeforeRightParen,
343+
rightParen: rightParen,
344+
tokenList: nil,
345+
arena: self.arena)
346+
}
347+
348+
349+
mutating func parseDerivativeAttributeArguments() -> RawDerivativeRegistrationAttributeArgumentsSyntax {
350+
let (unexpectedBeforeOfLabel, ofLabel) = self.expectContextualKeyword("of")
351+
let (unexpectedBetweenOfLabelAndColon, colon) = self.expect(.colon)
352+
let originalDeclName = self.parseQualifiedDeclarationName()
353+
let period = self.consume(if: .period)
354+
let accessor: RawTokenSyntax?
355+
if period != nil {
356+
accessor = self.parseAnyIdentifier()
357+
} else {
358+
accessor = nil
359+
}
360+
let comma = self.consume(if: .comma)
361+
let diffParams: RawDifferentiabilityParamsClauseSyntax?
362+
if comma != nil {
363+
diffParams = self.parseDifferentiabilityParameters()
364+
} else {
365+
diffParams = nil
366+
}
367+
return RawDerivativeRegistrationAttributeArgumentsSyntax(
368+
unexpectedBeforeOfLabel,
369+
ofLabel: ofLabel,
370+
unexpectedBetweenOfLabelAndColon,
371+
colon: colon,
372+
originalDeclName: originalDeclName,
373+
period: period,
374+
accessorKind: accessor,
375+
comma: comma,
376+
diffParams: diffParams,
377+
arena: self.arena)
378+
}
379+
}
380+
299381
extension Parser {
300382
mutating func parseObjectiveCAttribute() -> RawAttributeSyntax {
301383
let (unexpectedBeforeAtSign, atSign) = self.expect(.atSign)

Sources/SwiftParser/Names.swift

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,42 @@ extension Parser {
127127
}
128128
}
129129

130+
extension Parser {
131+
mutating func parseQualifiedDeclarationName() -> RawQualifiedDeclNameSyntax {
132+
let type: RawTypeSyntax?
133+
let dot: RawTokenSyntax?
134+
if self.lookahead().canParseBaseTypeForQualifiedDeclName() {
135+
type = self.parseTypeIdentifier()
136+
dot = self.consumePrefix(".", as: .period)
137+
} else {
138+
type = nil
139+
dot = nil
140+
}
141+
142+
let (name, args) = self.parseDeclNameRef([
143+
.zeroArgCompoundNames,
144+
.keywordsUsingSpecialNames,
145+
.operators,
146+
])
147+
return RawQualifiedDeclNameSyntax(
148+
baseType: type,
149+
dot: dot,
150+
name: name,
151+
arguments: args,
152+
arena: self.arena)
153+
}
154+
}
155+
156+
extension Parser.Lookahead {
157+
func canParseBaseTypeForQualifiedDeclName() -> Bool {
158+
var lookahead = self.lookahead()
159+
guard lookahead.canParseSimpleTypeIdentifier() else {
160+
return false
161+
}
162+
return lookahead.currentToken.starts(with: ".")
163+
}
164+
}
165+
130166
extension Parser.Lookahead {
131167
func canParseArgumentLabelList() -> Bool {
132168
var lookahead = self.lookahead()

Tests/SwiftParserTest/Attributes.swift

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,4 +103,69 @@ final class AttributeTests: XCTestCase {
103103
@_Concurrency.MainActor(unsafe) public struct Image : SwiftUI.View {}
104104
""")
105105
}
106+
107+
func testDerivativeAttribute() {
108+
AssertParse(
109+
"""
110+
@inlinable
111+
@differentiable(reverse, wrt: self)
112+
public func differentiableMap<Result: Differentiable>(
113+
_ body: @differentiable(reverse) (Element) -> Result
114+
) -> [Result] {
115+
map(body)
116+
}
117+
""")
118+
119+
AssertParse(
120+
"""
121+
@inlinable
122+
@differentiable(reverse, wrt: (self, initialResult))
123+
public func differentiableReduce<Result: Differentiable>(
124+
_ initialResult: Result,
125+
_ nextPartialResult: @differentiable(reverse) (Result, Element) -> Result
126+
) -> Result {
127+
reduce(initialResult, nextPartialResult)
128+
}
129+
""")
130+
131+
AssertParse(
132+
"""
133+
@inlinable
134+
@derivative(of: differentiableReduce)
135+
internal func _vjpDifferentiableReduce<Result: Differentiable>(
136+
_ initialResult: Result,
137+
_ nextPartialResult: @differentiable(reverse) (Result, Element) -> Result
138+
) -> (
139+
value: Result,
140+
pullback: (Result.TangentVector)
141+
-> (Array.TangentVector, Result.TangentVector)
142+
) {}
143+
""")
144+
}
145+
146+
func testTransposeAttribute() {
147+
AssertParse(
148+
"""
149+
@transpose(of: +)
150+
func addTranspose(_ v: Float) -> (Float, Float) {
151+
return (v, v)
152+
}
153+
""")
154+
155+
AssertParse(
156+
"""
157+
@transpose(of: -, wrt: (0, 1))
158+
func subtractTranspose(_ v: Float) -> (Float, Float) {
159+
return (v, -v)
160+
}
161+
""")
162+
163+
AssertParse(
164+
"""
165+
@transpose(of: Float.-, wrt: (0, 1))
166+
func subtractTranspose(_ v: Float) -> (Float, Float) {
167+
return (v, -v)
168+
}
169+
""")
170+
}
106171
}

0 commit comments

Comments
 (0)