Skip to content

Commit d4691b8

Browse files
committed
Add leak check tests.
Expose memory leak unrelated to control flow.
1 parent 9023577 commit d4691b8

File tree

1 file changed

+35
-3
lines changed

1 file changed

+35
-3
lines changed

test/AutoDiff/leakchecking.swift

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,43 @@ struct Pair<T : Differentiable, U : Differentiable> : Differentiable
4949
}
5050

5151
LeakCheckingTests.test("BasicVarLeakChecking") {
52-
testWithLeakChecking(expectedLeakCount: 0) {
52+
testWithLeakChecking {
5353
var model = ExampleLeakModel()
5454
let x: Tracked<Float> = 1.0
5555
_ = model.gradient(at: x) { m, x in m.applied(to: x) }
5656
}
57+
58+
testWithLeakChecking {
59+
var model = ExampleLeakModel()
60+
let x: Tracked<Float> = 1.0
61+
62+
_ = model.gradient { m in m.applied(to: x) }
63+
for _ in 0..<10 {
64+
_ = model.gradient { m in m.applied(to: x) }
65+
}
66+
}
67+
68+
testWithLeakChecking {
69+
var model = ExampleLeakModel()
70+
var x: Tracked<Float> = 1.0
71+
_ = model.gradient { m in
72+
x = x + x
73+
var y = x + Tracked<Float>(x.value)
74+
return m.applied(to: y)
75+
}
76+
}
77+
78+
// TODO: Fix memory leak.
79+
testWithLeakChecking(expectedLeakCount: 1) {
80+
var model = ExampleLeakModel()
81+
let x: Tracked<Float> = 1.0
82+
_ = model.gradient { m in
83+
var model = m
84+
// Next line causes leak.
85+
model.bias = x
86+
return model.applied(to: x)
87+
}
88+
}
5789
}
5890

5991
LeakCheckingTests.test("ControlFlow") {
@@ -111,7 +143,7 @@ LeakCheckingTests.test("ControlFlow") {
111143
testWithLeakChecking(expectedLeakCount: 9) {
112144
var model = ExampleLeakModel()
113145
let x: Tracked<Float> = 1.0
114-
let _ = model.gradient(at: x) { m, x in
146+
_ = model.gradient(at: x) { m, x in
115147
let result: Tracked<Float>
116148
if x > 0 {
117149
result = m.applied(to: x)
@@ -128,7 +160,7 @@ LeakCheckingTests.test("ControlFlow") {
128160
testWithLeakChecking(expectedLeakCount: 14) {
129161
var model = ExampleLeakModel()
130162
let x: Tracked<Float> = 1.0
131-
let _ = model.gradient(at: x) { m, x in
163+
_ = model.gradient(at: x) { m, x in
132164
var result: Tracked<Float> = x
133165
if x > 0 {
134166
result = result + m.applied(to: x)

0 commit comments

Comments
 (0)