Skip to content

Commit b573b95

Browse files
authored
Merge pull request #183 from allevato/no-mo-vjp-jvp
Minor updates to `@differentiable`.
2 parents 98fc920 + 005acdd commit b573b95

File tree

2 files changed

+19
-18
lines changed

2 files changed

+19
-18
lines changed

Sources/SwiftFormatPrettyPrint/TokenStreamCreator.swift

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2035,10 +2035,18 @@ fileprivate final class TokenStreamCreator: SyntaxVisitor {
20352035

20362036
override func visit(_ node: DifferentiableAttributeArgumentsSyntax) -> SyntaxVisitorContinueKind {
20372037
// This node encapsulates the entire list of arguments in a `@differentiable(...)` attribute.
2038-
after(node.diffParamsComma, tokens: .break(.same))
2039-
20402038
var needsBreakBeforeWhereClause = false
20412039

2040+
if let diffParamsComma = node.diffParamsComma {
2041+
after(diffParamsComma, tokens: .break(.same))
2042+
} else if node.diffParams != nil {
2043+
// If there were diff params but no comma following them, then we have "wrt: foo where ..."
2044+
// and we need a break before the where clause.
2045+
needsBreakBeforeWhereClause = true
2046+
}
2047+
2048+
// TODO: These properties will likely go away in a future version since the parser no longer
2049+
// reads the `vjp:` and `jvp:` arguments to `@differentiable`.
20422050
if let vjp = node.maybeVJP {
20432051
before(vjp.firstToken, tokens: .open)
20442052
after(vjp.lastToken, tokens: .close)
@@ -2051,6 +2059,7 @@ fileprivate final class TokenStreamCreator: SyntaxVisitor {
20512059
after(jvp.trailingComma, tokens: .break(.same))
20522060
needsBreakBeforeWhereClause = true
20532061
}
2062+
20542063
if let whereClause = node.whereClause {
20552064
if needsBreakBeforeWhereClause {
20562065
before(whereClause.firstToken, tokens: .break(.same))
@@ -2066,6 +2075,8 @@ fileprivate final class TokenStreamCreator: SyntaxVisitor {
20662075
{
20672076
// This node encapsulates the `vjp:` or `jvp:` label and decl name in a `@differentiable`
20682077
// attribute.
2078+
// TODO: This node will likely go away in a future version since the parser no longer reads the
2079+
// `vjp:` and `jvp:` arguments to `@differentiable`.
20692080
after(node.colon, tokens: .break(.continue, newlines: .elective(ignoresDiscretionary: true)))
20702081
return .visitChildren
20712082
}

Tests/SwiftFormatPrettyPrintTests/DifferentiationAttributeTests.swift

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,38 +2,28 @@ final class DifferentiationAttributeTests: PrettyPrintTestCase {
22
func testDifferentiable() {
33
let input =
44
"""
5-
@differentiable(wrt: x, vjp: d where T: D)
5+
@differentiable(wrt: x where T: D)
66
func foo<T>(_ x: T) -> T {}
77
8-
@differentiable(wrt: x, vjp: deriv where T: D)
8+
@differentiable(wrt: x where T: Differentiable)
99
func foo<T>(_ x: T) -> T {}
1010
11-
@differentiable(wrt: x, vjp: derivativeFoo where T: Differentiable)
12-
func foo<T>(_ x: T) -> T {}
13-
14-
@differentiable(wrt: theVariableNamedX, vjp: derivativeFoo where T: Differentiable)
11+
@differentiable(wrt: theVariableNamedX where T: Differentiable)
1512
func foo<T>(_ theVariableNamedX: T) -> T {}
1613
"""
1714

1815
let expected =
1916
"""
20-
@differentiable(wrt: x, vjp: d where T: D)
17+
@differentiable(wrt: x where T: D)
2118
func foo<T>(_ x: T) -> T {}
2219
2320
@differentiable(
24-
wrt: x, vjp: deriv where T: D
25-
)
26-
func foo<T>(_ x: T) -> T {}
27-
28-
@differentiable(
29-
wrt: x, vjp: derivativeFoo
30-
where T: Differentiable
21+
wrt: x where T: Differentiable
3122
)
3223
func foo<T>(_ x: T) -> T {}
3324
3425
@differentiable(
35-
wrt: theVariableNamedX,
36-
vjp: derivativeFoo
26+
wrt: theVariableNamedX
3727
where T: Differentiable
3828
)
3929
func foo<T>(_ theVariableNamedX: T) -> T {}

0 commit comments

Comments
 (0)