Skip to content

Commit 6bf823b

Browse files
authored
[AutoDiff] Add control flow AD leak checking tests. (#25249)
Add leak checking tests. Expose memory leak unrelated to control flow (TF-550).
1 parent 045e192 commit 6bf823b

File tree

2 files changed

+175
-24
lines changed

2 files changed

+175
-24
lines changed

stdlib/private/DifferentiationUnittest/GenericLifetimeTracked.swift

Lines changed: 64 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,35 @@ public enum _GlobalLeakCount {
2020
/// automatic differentiation.
2121
public struct Tracked<T> {
2222
fileprivate class Box {
23-
fileprivate let value : T
23+
fileprivate var value : T
2424
init(_ value: T) {
2525
self.value = value
26-
_GlobalLeakCount.count += 1
26+
_GlobalLeakCount.count += 1
2727
}
2828
deinit {
2929
_GlobalLeakCount.count -= 1
3030
}
3131
}
32-
private let handle: Box
32+
private var handle: Box
33+
34+
@differentiable(
35+
vjp: _vjpInit
36+
where T : Differentiable, T == T.AllDifferentiableVariables,
37+
T == T.TangentVector
38+
)
3339
public init(_ value: T) {
3440
self.handle = Box(value)
3541
}
36-
public var value: T { return handle.value }
42+
43+
@differentiable(
44+
vjp: _vjpValue
45+
where T : Differentiable, T == T.AllDifferentiableVariables,
46+
T == T.TangentVector
47+
)
48+
public var value: T {
49+
get { handle.value }
50+
set { handle.value = newValue }
51+
}
3752
}
3853

3954
extension Tracked : ExpressibleByFloatLiteral where T : ExpressibleByFloatLiteral {
@@ -123,17 +138,37 @@ extension Tracked : Differentiable
123138
public typealias TangentVector = Tracked<T.TangentVector>
124139
}
125140

126-
@differentiable(vjp: _vjpAdd)
127-
public func + (_ a: Tracked<Float>, _ b: Tracked<Float>) -> Tracked<Float> {
128-
return Tracked<Float>(a.value + b.value)
141+
extension Tracked where T : Differentiable, T == T.AllDifferentiableVariables,
142+
T == T.TangentVector
143+
{
144+
@usableFromInline
145+
internal static func _vjpInit(_ value: T)
146+
-> (value: Self, pullback: (Self.TangentVector) -> (T.TangentVector)) {
147+
return (Tracked(value), { v in v.value })
148+
}
149+
150+
@usableFromInline
151+
internal func _vjpValue() -> (T, (T.TangentVector) -> Self.TangentVector) {
152+
return (value, { v in Tracked(v) })
153+
}
129154
}
130155

131-
@usableFromInline
132-
func _vjpAdd(_ a: Tracked<Float>, _ b: Tracked<Float>)
133-
-> (Tracked<Float>, (Tracked<Float>) -> (Tracked<Float>, Tracked<Float>)) {
134-
return (Tracked<Float>(a.value + b.value), { v in
135-
return (v, v)
136-
})
156+
extension Tracked where T : Differentiable, T == T.AllDifferentiableVariables,
157+
T == T.TangentVector
158+
{
159+
@usableFromInline
160+
@differentiating(+)
161+
internal static func _vjpAdd(lhs: Self, rhs: Self)
162+
-> (value: Self, pullback: (Self) -> (Self, Self)) {
163+
return (lhs + rhs, { v in (v, v) })
164+
}
165+
166+
@usableFromInline
167+
@differentiating(-)
168+
internal static func _vjpSubtract(lhs: Self, rhs: Self)
169+
-> (value: Self, pullback: (Self) -> (Self, Self)) {
170+
return (lhs - rhs, { v in (v, .zero - v) })
171+
}
137172
}
138173

139174
// Differential operators for `Tracked<Float>`.
@@ -151,4 +186,20 @@ public extension Differentiable {
151186
) -> (TangentVector, T.TangentVector) {
152187
return self.pullback(at: x, in: f)(1)
153188
}
189+
190+
@inlinable
191+
func valueWithGradient(
192+
in f: @differentiable (Self) -> Tracked<Float>
193+
) -> (value: Tracked<Float>, gradient: TangentVector) {
194+
let (y, pb) = self.valueWithPullback(in: f)
195+
return (y, pb(1))
196+
}
197+
198+
@inlinable
199+
func valueWithGradient<T : Differentiable>(
200+
at x: T, in f: @differentiable (Self, T) -> Tracked<Float>
201+
) -> (value: Tracked<Float>, gradient: (TangentVector, T.TangentVector)) {
202+
let (y, pb) = self.valueWithPullback(at: x, in: f)
203+
return (y, pb(1))
204+
}
154205
}

test/AutoDiff/leakchecking.swift

Lines changed: 111 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,22 @@
11
// RUN: %target-run-simple-swift-control-flow-differentiation
22
// REQUIRES: executable_test
33

4-
// A test that we can properly differentiate types that require refcounting.
4+
// Test differentiation-related memory leaks.
55

66
import StdlibUnittest
77
import DifferentiationUnittest
88

99
var LeakCheckingTests = TestSuite("LeakChecking")
1010

1111
/// Execute body, check expected leak count, and reset global leak count.
12-
func testWithLeakChecking(expectedLeakCount: Int = 0, _ body: () -> Void) {
12+
func testWithLeakChecking(
13+
expectedLeakCount: Int = 0, file: String = #file, line: UInt = #line,
14+
_ body: () -> Void
15+
) {
1316
body()
14-
expectEqual(expectedLeakCount, _GlobalLeakCount.count, "Leak detected.")
17+
expectEqual(
18+
expectedLeakCount, _GlobalLeakCount.count, "Leak detected.",
19+
file: file, line: line)
1520
_GlobalLeakCount.count = 0
1621
}
1722

@@ -23,27 +28,122 @@ struct ExampleLeakModel : Differentiable {
2328
}
2429
}
2530

31+
struct FloatPair : Differentiable & AdditiveArithmetic {
32+
var first, second: Tracked<Float>
33+
init(_ first: Tracked<Float>, _ second: Tracked<Float>) {
34+
self.first = first
35+
self.second = second
36+
}
37+
}
38+
39+
struct Pair<T : Differentiable, U : Differentiable> : Differentiable
40+
where T == T.AllDifferentiableVariables, T == T.TangentVector,
41+
U == U.AllDifferentiableVariables, U == U.TangentVector
42+
{
43+
var first: Tracked<T>
44+
var second: Tracked<U>
45+
init(_ first: Tracked<T>, _ second: Tracked<U>) {
46+
self.first = first
47+
self.second = second
48+
}
49+
}
50+
2651
LeakCheckingTests.test("BasicVarLeakChecking") {
27-
do {
52+
testWithLeakChecking {
53+
var model = ExampleLeakModel()
54+
let x: Tracked<Float> = 1.0
55+
_ = model.gradient(at: x) { m, x in m.applied(to: x) }
56+
}
57+
58+
testWithLeakChecking {
59+
var model = ExampleLeakModel()
60+
let x: Tracked<Float> = 1.0
61+
62+
_ = model.gradient { m in m.applied(to: x) }
63+
for _ in 0..<10 {
64+
_ = model.gradient { m in m.applied(to: x) }
65+
}
66+
}
67+
68+
testWithLeakChecking {
69+
var model = ExampleLeakModel()
70+
var x: Tracked<Float> = 1.0
71+
_ = model.gradient { m in
72+
x = x + x
73+
var y = x + Tracked<Float>(x.value)
74+
return m.applied(to: y)
75+
}
76+
}
77+
78+
// TODO: Fix memory leak.
79+
testWithLeakChecking(expectedLeakCount: 1) {
2880
var model = ExampleLeakModel()
2981
let x: Tracked<Float> = 1.0
30-
let _ = model.gradient(at: x) { m, x in m.applied(to: x) }
82+
_ = model.gradient { m in
83+
var model = m
84+
// Next line causes leak.
85+
model.bias = x
86+
return model.applied(to: x)
87+
}
3188
}
32-
expectEqual(0, _GlobalLeakCount.count, "Leak detected.")
3389
}
3490

3591
LeakCheckingTests.test("ControlFlow") {
36-
// TODO: Add more `var` + control flow tests.
37-
// Porting tests from test/AutoDiff/control_flow.swift requires more support
38-
// for `Tracked<Float>`.
92+
// FIXME: Fix control flow AD memory leaks.
93+
// See related FIXME comments in adjoint value/buffer propagation in
94+
// lib/SILOptimizer/Mandatory/Differentiation.cpp.
95+
testWithLeakChecking(expectedLeakCount: 105) {
96+
func cond_nestedtuple_var(_ x: Tracked<Float>) -> Tracked<Float> {
97+
// Convoluted function returning `x + x`.
98+
var y = (x + x, x - x)
99+
var z = (y, x)
100+
if x > 0 {
101+
var w = (x, x)
102+
y.0 = w.1
103+
y.1 = w.0
104+
z.0.0 = z.0.0 - y.0
105+
z.0.1 = z.0.1 + y.0
106+
} else {
107+
z = ((y.0 - x, y.1 + x), x)
108+
}
109+
return y.0 + y.1 - z.0.0 + z.0.1
110+
}
111+
expectEqual((8, 2), Tracked<Float>(4).valueWithGradient(in: cond_nestedtuple_var))
112+
expectEqual((-20, 2), Tracked<Float>(-10).valueWithGradient(in: cond_nestedtuple_var))
113+
expectEqual((-2674, 2), Tracked<Float>(-1337).valueWithGradient(in: cond_nestedtuple_var))
114+
}
115+
116+
// FIXME: Fix control flow AD memory leaks.
117+
// See related FIXME comments in adjoint value/buffer propagation in
118+
// lib/SILOptimizer/Mandatory/Differentiation.cpp.
119+
testWithLeakChecking(expectedLeakCount: 379) {
120+
func cond_nestedstruct_var(_ x: Tracked<Float>) -> Tracked<Float> {
121+
// Convoluted function returning `x + x`.
122+
var y = FloatPair(x + x, x - x)
123+
var z = Pair(Tracked(y), x)
124+
if x > 0 {
125+
var w = FloatPair(x, x)
126+
y.first = w.second
127+
y.second = w.first
128+
z.first = Tracked(FloatPair(z.first.value.first - y.first,
129+
z.first.value.second + y.first))
130+
} else {
131+
z = Pair(Tracked(FloatPair(y.first - x, y.second + x)), x)
132+
}
133+
return y.first + y.second - z.first.value.first + z.first.value.second
134+
}
135+
expectEqual((8, 2), Tracked<Float>(4).valueWithGradient(in: cond_nestedstruct_var))
136+
expectEqual((-20, 2), Tracked<Float>(-10).valueWithGradient(in: cond_nestedstruct_var))
137+
expectEqual((-2674, 2), Tracked<Float>(-1337).valueWithGradient(in: cond_nestedstruct_var))
138+
}
39139

40140
// FIXME: Fix control flow AD memory leaks.
41141
// See related FIXME comments in adjoint value/buffer propagation in
42142
// lib/SILOptimizer/Mandatory/Differentiation.cpp.
43143
testWithLeakChecking(expectedLeakCount: 9) {
44144
var model = ExampleLeakModel()
45145
let x: Tracked<Float> = 1.0
46-
let _ = model.gradient(at: x) { m, x in
146+
_ = model.gradient(at: x) { m, x in
47147
let result: Tracked<Float>
48148
if x > 0 {
49149
result = m.applied(to: x)
@@ -60,7 +160,7 @@ LeakCheckingTests.test("ControlFlow") {
60160
testWithLeakChecking(expectedLeakCount: 14) {
61161
var model = ExampleLeakModel()
62162
let x: Tracked<Float> = 1.0
63-
let _ = model.gradient(at: x) { m, x in
163+
_ = model.gradient(at: x) { m, x in
64164
var result: Tracked<Float> = x
65165
if x > 0 {
66166
result = result + m.applied(to: x)

0 commit comments

Comments
 (0)