Skip to content

[AD] Rewrite some of the tests with Tracked<Float> (2) #27767

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Oct 18, 2019
Original file line number Diff line number Diff line change
Expand Up @@ -257,18 +257,27 @@ extension Tracked where T : Differentiable & FloatingPoint, T == T.TangentVector
}
}

// Differential operators for `Tracked<Float>`.
public func gradient(
at x: Tracked<Float>, in f: @differentiable (Tracked<Float>) -> Tracked<Float>
) -> Tracked<Float> {
return pullback(at: x, in: f)(1)
// Differential operators for `Tracked<T>`.

public func gradient<T, U>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Differential operators in this file can actually be deleted once Tracked conforms to FloatingPoint. Were you working in this direction?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC,Tracked: FloatingPoint conformance caused * operator lookup for @differentiating(*) to become ambiguous.

That seems workaround-able by using @differentiable(vjp: ...) for now. We should probably investigate fixing @differentiating(*) ambiguous lookup (and @differentiating original declaration lookup for initializers/subscripts/properties) sometime.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rxwei, yes it would be good to make Tracked conform to FloatingPoint, but have issues that Dan mentions. I had already filed a bug: https://bugs.swift.org/browse/TF-926

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Single gradient operator definitions for (...) -> Tracked<T> functions is great!
For reference: TF-927 tracks the confusing type-checker error we encountered this morning.

at x: T, in f: @differentiable (T) -> Tracked<U>
) -> T.TangentVector
where U : FloatingPoint, U.TangentVector == U {
return pullback(at: x, in: f)(Tracked<U>(1))
}

public func gradient<T, U, R>(
at x: T, _ y: U, in f: @differentiable (T, U) -> Tracked<R>
) -> (T.TangentVector, U.TangentVector)
where R : FloatingPoint, R.TangentVector == R {
return pullback(at: x, y, in: f)(Tracked<R>(1))
}

public func gradient(
at x: Tracked<Float>, _ y: Tracked<Float>,
in f: @differentiable (Tracked<Float>, Tracked<Float>) -> Tracked<Float>
) -> (Tracked<Float>, Tracked<Float>) {
return pullback(at: x, y, in: f)(1)
public func valueWithGradient<T, U : FloatingPoint>(
at x: T, in f: @differentiable (T) -> Tracked<U>
) -> (value: Tracked<U>, gradient: T.TangentVector) {
let (y, pullback) = valueWithPullback(at: x, in: f)
return (y, pullback(Tracked<U>(1)))
}

public extension Differentiable {
Expand Down
4 changes: 2 additions & 2 deletions test/AutoDiff/array.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ ArrayAutoDiffTests.test("ArraySubscript") {
gradient(at: [2, 3, 4, 5, 6, 7], in: sumFirstThree))
}

ArrayAutoDiffTests.test("ArrayLiteral") {
ArrayAutoDiffTests.testWithLeakChecking("ArrayLiteral") {
func twoElementLiteral(_ x: Tracked<Float>, _ y: Tracked<Float>) -> [Tracked<Float>] {
return [x, y]
}
Expand Down Expand Up @@ -90,7 +90,7 @@ ArrayAutoDiffTests.test("ArrayConcat") {
in: sumFirstThreeConcatted))
}

ArrayAutoDiffTests.test("Array.init(repeating:count:)") {
ArrayAutoDiffTests.testWithLeakChecking("Array.init(repeating:count:)") {
@differentiable
func repeating(_ x: Tracked<Float>) -> [Tracked<Float>] {
Array(repeating: x, count: 10)
Expand Down
Loading