Skip to content

Commit e9bfd14

Browse files
bgogulrxwei
andauthored
[AD] Rewrite some of the tests with Tracked<Float> (2) (#27767)
* Leak checking in superset_adjoint.swift * Leak checking in method.swift. * Add a valueWithGradient function. * Leak checking class_method.swift. * Leak checking for existential.swift. * Simplify versions of gradient. * Leak checking for e2e_differentiable_property.swift * Leak checking for custom_derivatives.swift * Some leak checking in array.swift * Some leak checking in forward_mode_runtime.swift * Formatting. * Formatting update to test/AutoDiff/custom_derivatives.swift Co-Authored-By: Richard Wei <[email protected]>
1 parent 5c03b64 commit e9bfd14

File tree

9 files changed

+273
-209
lines changed

9 files changed

+273
-209
lines changed

stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -257,18 +257,27 @@ extension Tracked where T : Differentiable & FloatingPoint, T == T.TangentVector
257257
}
258258
}
259259

260-
// Differential operators for `Tracked<Float>`.
261-
public func gradient(
262-
at x: Tracked<Float>, in f: @differentiable (Tracked<Float>) -> Tracked<Float>
263-
) -> Tracked<Float> {
264-
return pullback(at: x, in: f)(1)
260+
// Differential operators for `Tracked<T>`.
261+
262+
public func gradient<T, U>(
263+
at x: T, in f: @differentiable (T) -> Tracked<U>
264+
) -> T.TangentVector
265+
where U : FloatingPoint, U.TangentVector == U {
266+
return pullback(at: x, in: f)(Tracked<U>(1))
267+
}
268+
269+
public func gradient<T, U, R>(
270+
at x: T, _ y: U, in f: @differentiable (T, U) -> Tracked<R>
271+
) -> (T.TangentVector, U.TangentVector)
272+
where R : FloatingPoint, R.TangentVector == R {
273+
return pullback(at: x, y, in: f)(Tracked<R>(1))
265274
}
266275

267-
public func gradient(
268-
at x: Tracked<Float>, _ y: Tracked<Float>,
269-
in f: @differentiable (Tracked<Float>, Tracked<Float>) -> Tracked<Float>
270-
) -> (Tracked<Float>, Tracked<Float>) {
271-
return pullback(at: x, y, in: f)(1)
276+
public func valueWithGradient<T, U : FloatingPoint>(
277+
at x: T, in f: @differentiable (T) -> Tracked<U>
278+
) -> (value: Tracked<U>, gradient: T.TangentVector) {
279+
let (y, pullback) = valueWithPullback(at: x, in: f)
280+
return (y, pullback(Tracked<U>(1)))
272281
}
273282

274283
public extension Differentiable {

test/AutoDiff/array.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ ArrayAutoDiffTests.test("ArraySubscript") {
2828
gradient(at: [2, 3, 4, 5, 6, 7], in: sumFirstThree))
2929
}
3030

31-
ArrayAutoDiffTests.test("ArrayLiteral") {
31+
ArrayAutoDiffTests.testWithLeakChecking("ArrayLiteral") {
3232
func twoElementLiteral(_ x: Tracked<Float>, _ y: Tracked<Float>) -> [Tracked<Float>] {
3333
return [x, y]
3434
}
@@ -90,7 +90,7 @@ ArrayAutoDiffTests.test("ArrayConcat") {
9090
in: sumFirstThreeConcatted))
9191
}
9292

93-
ArrayAutoDiffTests.test("Array.init(repeating:count:)") {
93+
ArrayAutoDiffTests.testWithLeakChecking("Array.init(repeating:count:)") {
9494
@differentiable
9595
func repeating(_ x: Tracked<Float>) -> [Tracked<Float>] {
9696
Array(repeating: x, count: 10)

0 commit comments

Comments
 (0)