Skip to content

Commit 99356cd

Browse files
author
marcrasi
authored
[AutoDiff upstream] add more differentiation tests (#30933)
1 parent f069371 commit 99356cd

File tree

4 files changed

+1345
-0
lines changed

4 files changed

+1345
-0
lines changed
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
// RUN: %target-swift-frontend -emit-sil -verify %s
2+
3+
import _Differentiation
4+
5+
// Test supported `br`, `cond_br`, and `switch_enum` terminators.
6+
7+
@differentiable
8+
func branch(_ x: Float) -> Float {
9+
if x > 0 {
10+
return x
11+
} else if x < 10 {
12+
return x
13+
}
14+
return x
15+
}
16+
17+
enum Enum {
18+
case a(Float)
19+
case b(Float)
20+
}
21+
22+
@differentiable
23+
func enum_nonactive1(_ e: Enum, _ x: Float) -> Float {
24+
switch e {
25+
case .a: return x
26+
case .b: return x
27+
}
28+
}
29+
30+
@differentiable
31+
func enum_nonactive2(_ e: Enum, _ x: Float) -> Float {
32+
switch e {
33+
case let .a(a): return x + a
34+
case let .b(b): return x + b
35+
}
36+
}
37+
38+
// Test loops.
39+
40+
@differentiable
41+
func for_loop(_ x: Float) -> Float {
42+
var result: Float = x
43+
for _ in 0..<3 {
44+
result = result * x
45+
}
46+
return result
47+
}
48+
49+
@differentiable
50+
func while_loop(_ x: Float) -> Float {
51+
var result = x
52+
var i = 1
53+
while i < 3 {
54+
result = result * x
55+
i += 1
56+
}
57+
return result
58+
}
59+
60+
@differentiable
61+
func nested_loop(_ x: Float) -> Float {
62+
var outer = x
63+
for _ in 1..<3 {
64+
outer = outer * x
65+
66+
var inner = outer
67+
var i = 1
68+
while i < 3 {
69+
inner = inner / x
70+
i += 1
71+
}
72+
outer = inner
73+
}
74+
return outer
75+
}
76+
77+
// TF-433: Test throwing functions.
78+
79+
func rethrowing(_ x: () throws -> Void) rethrows -> Void {}
80+
81+
// expected-error @+1 {{function is not differentiable}}
82+
@differentiable
83+
// expected-note @+1 {{when differentiating this function definition}}
84+
func testTryApply(_ x: Float) -> Float {
85+
// expected-note @+1 {{cannot differentiate unsupported control flow}}
86+
rethrowing({})
87+
return x
88+
}
89+
90+
// expected-error @+1 {{function is not differentiable}}
91+
@differentiable
92+
// expected-note @+1 {{when differentiating this function definition}}
93+
func withoutDerivative<T : Differentiable, R: Differentiable>(
94+
at x: T, in body: (T) throws -> R
95+
) rethrows -> R {
96+
// expected-note @+1 {{cannot differentiate unsupported control flow}}
97+
try body(x)
98+
}
99+
100+
// Test unsupported differentiation of active enum values.
101+
102+
// expected-error @+1 {{function is not differentiable}}
103+
@differentiable
104+
// expected-note @+1 {{when differentiating this function definition}}
105+
func enum_active(_ x: Float) -> Float {
106+
// expected-note @+1 {{differentiating enum values is not yet supported}}
107+
let e: Enum
108+
if x > 0 {
109+
e = .a(x)
110+
} else {
111+
e = .b(x)
112+
}
113+
switch e {
114+
case let .a(a): return x + a
115+
case let .b(b): return x + b
116+
}
117+
}
118+
119+
enum Tree : Differentiable & AdditiveArithmetic {
120+
case leaf(Float)
121+
case branch(Float, Float)
122+
123+
typealias TangentVector = Self
124+
typealias AllDifferentiableVariables = Self
125+
static var zero: Self { .leaf(0) }
126+
127+
// expected-error @+1 {{function is not differentiable}}
128+
@differentiable
129+
// TODO(TF-956): Improve location of active enum non-differentiability errors
130+
// so that they are closer to the source of the non-differentiability.
131+
// expected-note @+2 {{when differentiating this function definition}}
132+
// expected-note @+1 {{differentiating enum values is not yet supported}}
133+
static func +(_ lhs: Self, _ rhs: Self) -> Self {
134+
switch (lhs, rhs) {
135+
case let (.leaf(x), .leaf(y)):
136+
return .leaf(x + y)
137+
case let (.branch(x1, x2), .branch(y1, y2)):
138+
return .branch(x1 + x2, y1 + y2)
139+
default:
140+
fatalError()
141+
}
142+
}
143+
144+
// expected-error @+1 {{function is not differentiable}}
145+
@differentiable
146+
// TODO(TF-956): Improve location of active enum non-differentiability errors
147+
// so that they are closer to the source of the non-differentiability.
148+
// expected-note @+2 {{when differentiating this function definition}}
149+
// expected-note @+1 {{differentiating enum values is not yet supported}}
150+
static func -(_ lhs: Self, _ rhs: Self) -> Self {
151+
switch (lhs, rhs) {
152+
case let (.leaf(x), .leaf(y)):
153+
return .leaf(x - y)
154+
case let (.branch(x1, x2), .branch(y1, y2)):
155+
return .branch(x1 - x2, y1 - y2)
156+
default:
157+
fatalError()
158+
}
159+
}
160+
}
161+
162+
// expected-error @+1 {{function is not differentiable}}
163+
@differentiable
164+
// expected-note @+1 {{when differentiating this function definition}}
165+
func loop_array(_ array: [Float]) -> Float {
166+
var result: Float = 1
167+
// TODO(TF-957): Improve non-differentiability errors for for-in loops
168+
// (`Collection.makeIterator` and `IteratorProtocol.next`).
169+
// expected-note @+1 {{cannot differentiate through a non-differentiable result; do you want to use 'withoutDerivative(at:)'?}}
170+
for x in array {
171+
result = result * x
172+
}
173+
return result
174+
}

0 commit comments

Comments
 (0)