Skip to content

Commit 6f8efd8

Browse files
committed
Parse @transpose(of:wrt:)
1 parent 1fa4aa4 commit 6f8efd8

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

Sources/SwiftParser/Attributes.swift

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ extension Parser {
4141
return RawSyntax(self.parseDifferentiableAttribute())
4242
case .derivative:
4343
return RawSyntax(self.parseDerivativeAttribute())
44+
case .transpose:
45+
return RawSyntax(self.parseTransposeAttribute())
4446
case .objc:
4547
return RawSyntax(self.parseObjectiveCAttribute())
4648
case ._specialize:
@@ -321,6 +323,29 @@ extension Parser {
321323
arena: self.arena)
322324
}
323325

326+
mutating func parseTransposeAttribute() -> RawAttributeSyntax {
327+
let (unexpectedBeforeAtSign, atSign) = self.expect(.atSign)
328+
let (unexpectedBeforeTranspose, transpose) = self.expectContextualKeyword("transpose")
329+
330+
let (unexpectedBeforeLeftParen, leftParen) = self.expect(.leftParen)
331+
let argument = self.parseDerivativeAttributeArguments()
332+
let (unexpectedBeforeRightParen, rightParen) = self.expect(.rightParen)
333+
334+
return RawAttributeSyntax(
335+
unexpectedBeforeAtSign,
336+
atSignToken: atSign,
337+
unexpectedBeforeTranspose,
338+
attributeName: transpose,
339+
unexpectedBeforeLeftParen,
340+
leftParen: leftParen,
341+
argument: RawSyntax(argument),
342+
unexpectedBeforeRightParen,
343+
rightParen: rightParen,
344+
tokenList: nil,
345+
arena: self.arena)
346+
}
347+
348+
324349
mutating func parseDerivativeAttributeArguments() -> RawDerivativeRegistrationAttributeArgumentsSyntax {
325350
let (unexpectedBeforeOfLabel, ofLabel) = self.expectContextualKeyword("of")
326351
let (unexpectedBetweenOfLabelAndColon, colon) = self.expect(.colon)

Tests/SwiftParserTest/Attributes.swift

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,4 +142,30 @@ final class AttributeTests: XCTestCase {
142142
) {}
143143
""")
144144
}
145+
146+
func testTransposeAttribute() {
147+
AssertParse(
148+
"""
149+
@transpose(of: +)
150+
func addTranspose(_ v: Float) -> (Float, Float) {
151+
return (v, v)
152+
}
153+
""")
154+
155+
AssertParse(
156+
"""
157+
@transpose(of: -, wrt: (0, 1))
158+
func subtractTranspose(_ v: Float) -> (Float, Float) {
159+
return (v, -v)
160+
}
161+
""")
162+
163+
AssertParse(
164+
"""
165+
@transpose(of: Float.-, wrt: (0, 1))
166+
func subtractTranspose(_ v: Float) -> (Float, Float) {
167+
return (v, -v)
168+
}
169+
""")
170+
}
145171
}

0 commit comments

Comments
 (0)