@@ -49,11 +49,43 @@ struct Pair<T : Differentiable, U : Differentiable> : Differentiable
49
49
}
50
50
51
51
LeakCheckingTests . test ( " BasicVarLeakChecking " ) {
52
- testWithLeakChecking ( expectedLeakCount : 0 ) {
52
+ testWithLeakChecking {
53
53
var model = ExampleLeakModel ( )
54
54
let x : Tracked < Float > = 1.0
55
55
_ = model. gradient ( at: x) { m, x in m. applied ( to: x) }
56
56
}
57
+
58
+ testWithLeakChecking {
59
+ var model = ExampleLeakModel ( )
60
+ let x : Tracked < Float > = 1.0
61
+
62
+ _ = model. gradient { m in m. applied ( to: x) }
63
+ for _ in 0 ..< 10 {
64
+ _ = model. gradient { m in m. applied ( to: x) }
65
+ }
66
+ }
67
+
68
+ testWithLeakChecking {
69
+ var model = ExampleLeakModel ( )
70
+ var x : Tracked < Float > = 1.0
71
+ _ = model. gradient { m in
72
+ x = x + x
73
+ var y = x + Tracked < Float > ( x. value)
74
+ return m. applied ( to: y)
75
+ }
76
+ }
77
+
78
+ // TODO: Fix memory leak.
79
+ testWithLeakChecking ( expectedLeakCount: 1 ) {
80
+ var model = ExampleLeakModel ( )
81
+ let x : Tracked < Float > = 1.0
82
+ _ = model. gradient { m in
83
+ var model = m
84
+ // Next line causes leak.
85
+ model. bias = x
86
+ return model. applied ( to: x)
87
+ }
88
+ }
57
89
}
58
90
59
91
LeakCheckingTests . test ( " ControlFlow " ) {
@@ -111,7 +143,7 @@ LeakCheckingTests.test("ControlFlow") {
111
143
testWithLeakChecking ( expectedLeakCount: 9 ) {
112
144
var model = ExampleLeakModel ( )
113
145
let x : Tracked < Float > = 1.0
114
- let _ = model. gradient ( at: x) { m, x in
146
+ _ = model. gradient ( at: x) { m, x in
115
147
let result : Tracked < Float >
116
148
if x > 0 {
117
149
result = m. applied ( to: x)
@@ -128,7 +160,7 @@ LeakCheckingTests.test("ControlFlow") {
128
160
testWithLeakChecking ( expectedLeakCount: 14 ) {
129
161
var model = ExampleLeakModel ( )
130
162
let x : Tracked < Float > = 1.0
131
- let _ = model. gradient ( at: x) { m, x in
163
+ _ = model. gradient ( at: x) { m, x in
132
164
var result : Tracked < Float > = x
133
165
if x > 0 {
134
166
result = result + m. applied ( to: x)
0 commit comments