Skip to content

Commit 4ddd488

Browse files
committed
[AutoDiff upstream] Test forward-mode differentiation diagnostics.
1 parent 1b98ca3 commit 4ddd488

File tree

1 file changed

+150
-0
lines changed

1 file changed

+150
-0
lines changed
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
// RUN: %target-swift-frontend -enable-experimental-forward-mode-differentiation -emit-sil -verify %s
2+
3+
// Test forward-mode differentiation transform diagnostics.
4+
5+
// TODO: Move these tests back into `autodiff_diagnostics.swift` once
6+
// forward mode reaches feature parity with reverse mode.
7+
8+
import _Differentiation
9+
10+
//===----------------------------------------------------------------------===//
11+
// Basic function
12+
//===----------------------------------------------------------------------===//
13+
14+
@differentiable
15+
func basic(_ x: Float) -> Float {
16+
return x
17+
}
18+
19+
//===----------------------------------------------------------------------===//
20+
// Control flow
21+
//===----------------------------------------------------------------------===//
22+
23+
// expected-error @+1 {{function is not differentiable}}
24+
@differentiable
25+
// expected-note @+2 {{when differentiating this function definition}}
26+
// expected-note @+1 {{forward-mode differentiation does not yet support control flow}}
27+
func cond(_ x: Float) -> Float {
28+
if x > 0 {
29+
return x * x
30+
}
31+
return x + x
32+
}
33+
34+
//===----------------------------------------------------------------------===//
35+
// Non-varied results
36+
//===----------------------------------------------------------------------===//
37+
38+
@differentiable
39+
func nonVariedResult(_ x: Float) -> Float {
40+
// TODO(TF-788): Re-enable non-varied result warning.
41+
// xpected-warning @+1 {{result does not depend on differentiation arguments and will always have a zero derivative; do you want to use 'withoutDerivative(at:)'?}} {{10-10=withoutDerivative(at:}} {{15-15=)}}
42+
return 0
43+
}
44+
45+
//===----------------------------------------------------------------------===//
46+
// Multiple results
47+
//===----------------------------------------------------------------------===//
48+
49+
// TODO(TF-983): Support differentiation of multiple results.
50+
/*
51+
func multipleResults(_ x: Float) -> (Float, Float) {
52+
return (x, x)
53+
}
54+
@differentiable
55+
func usesMultipleResults(_ x: Float) -> Float {
56+
let tuple = multipleResults(x)
57+
return tuple.0 + tuple.1
58+
}
59+
*/
60+
61+
//===----------------------------------------------------------------------===//
62+
// `inout` parameter differentiation
63+
//===----------------------------------------------------------------------===//
64+
65+
// expected-error @+1 {{function is not differentiable}}
66+
@differentiable
67+
// expected-note @+1 {{when differentiating this function definition}}
68+
func activeInoutParamNonactiveInitialResult(_ x: Float) -> Float {
69+
var result: Float = 1
70+
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
71+
result += x
72+
return result
73+
}
74+
75+
// expected-error @+1 {{function is not differentiable}}
76+
@differentiable
77+
// expected-note @+1 {{when differentiating this function definition}}
78+
func activeInoutParamTuple(_ x: Float) -> Float {
79+
var tuple = (x, x)
80+
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
81+
tuple.0 *= x
82+
return x * tuple.0
83+
}
84+
85+
// expected-error @+1 {{function is not differentiable}}
86+
@differentiable
87+
// expected-note @+2 {{when differentiating this function definition}}
88+
// expected-note @+1 {{forward-mode differentiation does not yet support control flow}}
89+
func activeInoutParamControlFlow(_ array: [Float]) -> Float {
90+
var result: Float = 1
91+
for i in withoutDerivative(at: array).indices {
92+
result += array[i]
93+
}
94+
return result
95+
}
96+
97+
struct Mut: Differentiable {}
98+
extension Mut {
99+
@differentiable(wrt: x)
100+
mutating func mutatingMethod(_ x: Mut) {}
101+
}
102+
103+
// FIXME(TF-984): Forward-mode crash due to unset tangent buffer.
104+
/*
105+
@differentiable(wrt: x)
106+
func nonActiveInoutParam(_ nonactive: inout Mut, _ x: Mut) -> Mut {
107+
return nonactive.mutatingMethod(x)
108+
}
109+
*/
110+
111+
// FIXME(TF-984): Forward-mode crash due to unset tangent buffer.
112+
/*
113+
@differentiable(wrt: x)
114+
func activeInoutParamMutatingMethod(_ x: Mut) -> Mut {
115+
var result = x
116+
result = result.mutatingMethod(result)
117+
return result
118+
}
119+
*/
120+
121+
// FIXME(TF-984): Forward-mode crash due to unset tangent buffer.
122+
/*
123+
@differentiable(wrt: x)
124+
func activeInoutParamMutatingMethodVar(_ nonactive: inout Mut, _ x: Mut) -> Mut {
125+
var result = nonactive
126+
result = result.mutatingMethod(x)
127+
return result
128+
}
129+
*/
130+
131+
// FIXME(TF-984): Forward-mode crash due to unset tangent buffer.
132+
/*
133+
@differentiable(wrt: x)
134+
func activeInoutParamMutatingMethodTuple(_ nonactive: inout Mut, _ x: Mut) -> Mut {
135+
var result = (nonactive, x)
136+
let result2 = result.0.mutatingMethod(result.0)
137+
return result2
138+
}
139+
*/
140+
141+
//===----------------------------------------------------------------------===//
142+
// Subset parameter differentiation thunks
143+
//===----------------------------------------------------------------------===//
144+
145+
// FIXME(SR-13046): Non-differentiability diagnostic crash due to invalid source location.
146+
/*
147+
func testNoDerivativeParameter(_ f: @differentiable (Float, @noDerivative Float) -> Float) -> Float {
148+
return derivative(at: 2, 3) { (x, y) in f(x * x, y) }
149+
}
150+
*/

0 commit comments

Comments
 (0)