Skip to content

Commit 59ea40a

Browse files
committed
Format differentiation attributes.
This change adds correct formatting for `@differentiable`, `@derivative`, and `@transpose`. Fixes SR-12002.
1 parent d4f58d3 commit 59ea40a

File tree

3 files changed

+179
-0
lines changed

3 files changed

+179
-0
lines changed

Sources/SwiftFormatPrettyPrint/TokenStreamCreator.swift

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1933,6 +1933,56 @@ private final class TokenStreamCreator: SyntaxVisitor {
19331933
return .visitChildren
19341934
}
19351935

1936+
override func visit(_ node: DifferentiableAttributeArgumentsSyntax) -> SyntaxVisitorContinueKind {
1937+
// This node encapsulates the entire list of arguments in a `@differentiable(...)` attribute.
1938+
if let vjp = node.maybeVJP {
1939+
before(vjp.firstToken, tokens: .break(.same), .open)
1940+
after(vjp.lastToken, tokens: .close)
1941+
}
1942+
if let jvp = node.maybeJVP {
1943+
before(jvp.firstToken, tokens: .break(.same), .open)
1944+
after(jvp.lastToken, tokens: .close)
1945+
}
1946+
if let whereClause = node.whereClause {
1947+
before(whereClause.firstToken, tokens: .break(.same), .open)
1948+
after(whereClause.lastToken, tokens: .close)
1949+
}
1950+
return .visitChildren
1951+
}
1952+
1953+
override func visit(_ node: DifferentiableAttributeFuncSpecifierSyntax)
1954+
-> SyntaxVisitorContinueKind
1955+
{
1956+
// This node encapsulates the `vjp:` or `jvp:` label and decl name in a `@differentiable`
1957+
// attribute.
1958+
after(node.colon, tokens: .break(.continue, newlines: .elective(ignoresDiscretionary: true)))
1959+
return .visitChildren
1960+
}
1961+
1962+
override func visit(_ node: DerivativeRegistrationAttributeArgumentsSyntax)
1963+
-> SyntaxVisitorContinueKind
1964+
{
1965+
// This node encapsulates the entire list of arguments in a `@derivative(...)` or
1966+
// `@transpose(...)` attribute.
1967+
before(node.ofLabel, tokens: .open)
1968+
after(node.colon, tokens: .break(.continue, newlines: .elective(ignoresDiscretionary: true)))
1969+
after(node.comma, tokens: .close)
1970+
1971+
if let diffParams = node.diffParams {
1972+
before(diffParams.firstToken, tokens: .break(.same), .open)
1973+
after(diffParams.lastToken, tokens: .close)
1974+
}
1975+
1976+
return .visitChildren
1977+
}
1978+
1979+
override func visit(_ node: DifferentiationParamsClauseSyntax) -> SyntaxVisitorContinueKind {
1980+
// This node encapsulates the `wrt:` label and value/variable in a `@differentiable`,
1981+
// `@derivative`, or `@transpose` attribute.
1982+
after(node.colon, tokens: .break(.continue, newlines: .elective(ignoresDiscretionary: true)))
1983+
return .visitChildren
1984+
}
1985+
19361986
// MARK: - Nodes representing unknown or malformed syntax
19371987

19381988
override func visit(_ node: UnknownDeclSyntax) -> SyntaxVisitorContinueKind {
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
class DifferentiationAttributeTests: PrettyPrintTestCase {
2+
func testDifferentiable() {
3+
let input =
4+
"""
5+
@differentiable(wrt: x, vjp: d where T: D)
6+
func foo<T>(_ x: T) -> T {}
7+
8+
@differentiable(wrt: x, vjp: deriv where T: D)
9+
func foo<T>(_ x: T) -> T {}
10+
11+
@differentiable(wrt: x, vjp: derivativeFoo where T: Differentiable)
12+
func foo<T>(_ x: T) -> T {}
13+
14+
@differentiable(wrt: theVariableNamedX, vjp: derivativeFoo where T: Differentiable)
15+
func foo<T>(_ theVariableNamedX: T) -> T {}
16+
"""
17+
18+
let expected =
19+
"""
20+
@differentiable(wrt: x, vjp: d where T: D)
21+
func foo<T>(_ x: T) -> T {}
22+
23+
@differentiable(
24+
wrt: x, vjp: deriv where T: D
25+
)
26+
func foo<T>(_ x: T) -> T {}
27+
28+
@differentiable(
29+
wrt: x, vjp: derivativeFoo
30+
where T: Differentiable
31+
)
32+
func foo<T>(_ x: T) -> T {}
33+
34+
@differentiable(
35+
wrt: theVariableNamedX,
36+
vjp: derivativeFoo
37+
where T: Differentiable
38+
)
39+
func foo<T>(_ theVariableNamedX: T) -> T {}
40+
41+
"""
42+
43+
assertPrettyPrintEqual(input: input, expected: expected, linelength: 43)
44+
}
45+
46+
func testDerivative() {
47+
let input =
48+
"""
49+
@derivative(of: foo, wrt: x)
50+
func deriv<T>(_ x: T) {}
51+
52+
@derivative(of: foobar, wrt: x)
53+
func deriv<T>(_ x: T) {}
54+
55+
@derivative(of: foobarbaz, wrt: theVariableNamedX)
56+
func deriv<T>(_ theVariableNamedX: T) {}
57+
"""
58+
59+
let expected =
60+
"""
61+
@derivative(of: foo, wrt: x)
62+
func deriv<T>(_ x: T) {}
63+
64+
@derivative(
65+
of: foobar, wrt: x
66+
)
67+
func deriv<T>(_ x: T) {}
68+
69+
@derivative(
70+
of: foobarbaz,
71+
wrt: theVariableNamedX
72+
)
73+
func deriv<T>(
74+
_ theVariableNamedX: T
75+
) {}
76+
77+
"""
78+
79+
assertPrettyPrintEqual(input: input, expected: expected, linelength: 28)
80+
}
81+
82+
func testTranspose() {
83+
let input =
84+
"""
85+
@transpose(of: foo, wrt: 0)
86+
func trans<T>(_ v: T) {}
87+
88+
@transpose(of: foobar, wrt: 0)
89+
func trans<T>(_ v: T) {}
90+
91+
@transpose(of: someReallyLongName, wrt: 0)
92+
func trans<T>(_ theVariableNamedV: T) {}
93+
"""
94+
95+
let expected =
96+
"""
97+
@transpose(of: foo, wrt: 0)
98+
func trans<T>(_ v: T) {}
99+
100+
@transpose(
101+
of: foobar, wrt: 0
102+
)
103+
func trans<T>(_ v: T) {}
104+
105+
@transpose(
106+
of: someReallyLongName,
107+
wrt: 0
108+
)
109+
func trans<T>(
110+
_ theVariableNamedV: T
111+
) {}
112+
113+
"""
114+
115+
assertPrettyPrintEqual(input: input, expected: expected, linelength: 27)
116+
}
117+
}

Tests/SwiftFormatPrettyPrintTests/XCTestManifests.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,17 @@ extension DictionaryDeclTests {
185185
]
186186
}
187187

188+
extension DifferentiationAttributeTests {
189+
// DO NOT MODIFY: This is autogenerated, use:
190+
// `swift test --generate-linuxmain`
191+
// to regenerate.
192+
static let __allTests__DifferentiationAttributeTests = [
193+
("testDerivative", testDerivative),
194+
("testDifferentiable", testDifferentiable),
195+
("testTranspose", testTranspose),
196+
]
197+
}
198+
188199
extension EnumDeclTests {
189200
// DO NOT MODIFY: This is autogenerated, use:
190201
// `swift test --generate-linuxmain`
@@ -765,6 +776,7 @@ public func __allTests() -> [XCTestCaseEntry] {
765776
testCase(CommentTests.__allTests__CommentTests),
766777
testCase(DeinitializerDeclTests.__allTests__DeinitializerDeclTests),
767778
testCase(DictionaryDeclTests.__allTests__DictionaryDeclTests),
779+
testCase(DifferentiationAttributeTests.__allTests__DifferentiationAttributeTests),
768780
testCase(EnumDeclTests.__allTests__EnumDeclTests),
769781
testCase(ExtensionDeclTests.__allTests__ExtensionDeclTests),
770782
testCase(ForInStmtTests.__allTests__ForInStmtTests),

0 commit comments

Comments
 (0)