Skip to content

Rewrite some of the AD tests with Tracked<Float> #27733

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Oct 17, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions test/AutoDiff/currying.swift
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
// RUN: %target-run-simple-swift

import StdlibUnittest
import DifferentiationUnittest

var CurryingAutodiffTests = TestSuite("CurryingAutodiff")

CurryingAutodiffTests.test("StructMember") {
CurryingAutodiffTests.testWithLeakChecking("StructMember") {
struct A {
@differentiable(wrt: (value))
func v(_ value: Float) -> Float { return value * value }
func v(_ value: Tracked<Float>) -> Tracked<Float> { return value * value }
}

let a = A()
// This implicitly constructs a function (A) -> (Float) -> Float
// This implicitly constructs a function (A) -> (Tracked<Float>) -> Tracked<Float>
// which gets called with a:
let g: @differentiable (Float) -> Float = a.v
let g: @differentiable (Tracked<Float>) -> Tracked<Float> = a.v


expectEqual(6.0, Float(3.0).gradient(in: g))
expectEqual(6.0, Tracked<Float>(3.0).gradient(in: g))
}

runAllTests()
35 changes: 18 additions & 17 deletions test/AutoDiff/derivative_registration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,70 +2,71 @@
// REQUIRES: executable_test

import StdlibUnittest
import DifferentiationUnittest

var DerivativeRegistrationTests = TestSuite("DerivativeRegistration")

@_semantics("autodiff.opaque")
func unary(x: Float) -> Float {
func unary(x: Tracked<Float>) -> Tracked<Float> {
return x
}
@differentiating(unary)
func _vjpUnary(x: Float) -> (value: Float, pullback: (Float) -> Float) {
func _vjpUnary(x: Tracked<Float>) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> Tracked<Float>) {
return (value: x, pullback: { v in v })
}
DerivativeRegistrationTests.test("UnaryFreeFunction") {
DerivativeRegistrationTests.testWithLeakChecking("UnaryFreeFunction") {
expectEqual(1, gradient(at: 3.0, in: unary))
}

@_semantics("autodiff.opaque")
func multiply(_ x: Float, _ y: Float) -> Float {
func multiply(_ x: Tracked<Float>, _ y: Tracked<Float>) -> Tracked<Float> {
return x * y
}
@differentiating(multiply)
func _vjpMultiply(_ x: Float, _ y: Float)
-> (value: Float, pullback: (Float) -> (Float, Float)) {
func _vjpMultiply(_ x: Tracked<Float>, _ y: Tracked<Float>)
-> (value: Tracked<Float>, pullback: (Tracked<Float>) -> (Tracked<Float>, Tracked<Float>)) {
return (x * y, { v in (v * y, v * x) })
}
DerivativeRegistrationTests.test("BinaryFreeFunction") {
DerivativeRegistrationTests.testWithLeakChecking("BinaryFreeFunction") {
expectEqual((3.0, 2.0), gradient(at: 2.0, 3.0, in: { x, y in multiply(x, y) }))
}

struct Wrapper : Differentiable {
var float: Float
var float: Tracked<Float>
}

extension Wrapper {
@_semantics("autodiff.opaque")
static func multiply(_ x: Float, _ y: Float) -> Float {
static func multiply(_ x: Tracked<Float>, _ y: Tracked<Float>) -> Tracked<Float> {
return x * y
}

@differentiating(multiply)
static func _vjpMultiply(_ x: Float, _ y: Float)
-> (value: Float, pullback: (Float) -> (Float, Float)) {
static func _vjpMultiply(_ x: Tracked<Float>, _ y: Tracked<Float>)
-> (value: Tracked<Float>, pullback: (Tracked<Float>) -> (Tracked<Float>, Tracked<Float>)) {
return (x * y, { v in (v * y, v * x) })
}
}
DerivativeRegistrationTests.test("StaticMethod") {
DerivativeRegistrationTests.testWithLeakChecking("StaticMethod") {
expectEqual((3.0, 2.0), gradient(at: 2.0, 3.0, in: { x, y in Wrapper.multiply(x, y) }))
}

extension Wrapper {
@_semantics("autodiff.opaque")
func multiply(_ x: Float) -> Float {
func multiply(_ x: Tracked<Float>) -> Tracked<Float> {
return float * x
}

@differentiating(multiply)
func _vjpMultiply(_ x: Float)
-> (value: Float, pullback: (Float) -> (Wrapper.TangentVector, Float)) {
func _vjpMultiply(_ x: Tracked<Float>)
-> (value: Tracked<Float>, pullback: (Tracked<Float>) -> (Wrapper.TangentVector, Tracked<Float>)) {
return (float * x, { v in
(Wrapper.TangentVector(float: v * x), v * self.float)
})
}
}
DerivativeRegistrationTests.test("InstanceMethod") {
let x: Float = 2
DerivativeRegistrationTests.testWithLeakChecking("InstanceMethod") {
let x: Tracked<Float> = 2
let wrapper = Wrapper(float: 3)
let (𝛁wrapper, 𝛁x) = wrapper.gradient(at: x) { wrapper, x in wrapper.multiply(x) }
expectEqual(Wrapper.TangentVector(float: 2), 𝛁wrapper)
Expand Down
68 changes: 35 additions & 33 deletions test/AutoDiff/protocol_requirement_autodiff.swift
Original file line number Diff line number Diff line change
@@ -1,48 +1,49 @@
// RUN: %target-run-simple-swift

import StdlibUnittest
import DifferentiationUnittest

var ProtocolRequirementAutodiffTests = TestSuite("ProtocolRequirementAutodiff")

// MARK: - Func requirements.

protocol DiffReq : Differentiable {
@differentiable(wrt: (self, x))
func f(_ x: Float) -> Float
func f(_ x: Tracked<Float>) -> Tracked<Float>
}

extension DiffReq where TangentVector : AdditiveArithmetic {
@inline(never) // Prevent specialization, to test all witness code.
func gradF(at x: Float) -> (Self.TangentVector, Float) {
func gradF(at x: Tracked<Float>) -> (Self.TangentVector, Tracked<Float>) {
return (valueWithPullback(at: x) { s, x in s.f(x) }).1(1)
}
}

struct Quadratic : DiffReq, VectorProtocol {
struct Quadratic : DiffReq, AdditiveArithmetic {
typealias TangentVector = Quadratic

@differentiable
let a: Float
let a: Tracked<Float>

@differentiable
let b: Float
let b: Tracked<Float>

@differentiable
let c: Float
let c: Tracked<Float>

init(_ a: Float, _ b: Float, _ c: Float) {
init(_ a: Tracked<Float>, _ b: Tracked<Float>, _ c: Tracked<Float>) {
self.a = a
self.b = b
self.c = c
}

@differentiable(wrt: (self, x))
func f(_ x: Float) -> Float {
func f(_ x: Tracked<Float>) -> Tracked<Float> {
return a * x * x + b * x + c
}
}

ProtocolRequirementAutodiffTests.test("func") {
ProtocolRequirementAutodiffTests.testWithLeakChecking("func") {
expectEqual((Quadratic(0, 0, 1), 12), Quadratic(11, 12, 13).gradF(at: 0))
expectEqual((Quadratic(1, 1, 1), 2 * 11 + 12),
Quadratic(11, 12, 13).gradF(at: 1))
Expand All @@ -54,48 +55,48 @@ ProtocolRequirementAutodiffTests.test("func") {

protocol FunctionsOfX: Differentiable {
@differentiable
init(x: Float)
init(x: Tracked<Float>)

@differentiable
var x: Float { get }
var x: Tracked<Float> { get }

@differentiable
var y: Float { get }
var y: Tracked<Float> { get }

@differentiable
var z: Float { get }
var z: Tracked<Float> { get }

@differentiable
subscript() -> Float { get }
subscript() -> Tracked<Float> { get }
}

struct TestFunctionsOfX: FunctionsOfX {
@differentiable
init(x: Float) {
init(x: Tracked<Float>) {
self.x = x
self.y = x * x
}

/// x = x
var x: Float
var x: Tracked<Float>

/// y = x * x
var y: Float
var y: Tracked<Float>

/// z = x * x + x
var z: Float {
var z: Tracked<Float> {
return y + x
}

@differentiable
subscript() -> Float {
subscript() -> Tracked<Float> {
return z
}
}

@inline(never) // Prevent specialization, to test all witness code.
func derivatives<F: FunctionsOfX>(at x: Float, in: F.Type)
-> (Float, Float, Float, Float)
func derivatives<F: FunctionsOfX>(at x: Tracked<Float>, in: F.Type)
-> (Tracked<Float>, Tracked<Float>, Tracked<Float>, Tracked<Float>)
{
let dxdx = gradient(at: x) { x in F(x: x).x }
let dydx = gradient(at: x) { x in F(x: x).y }
Expand All @@ -104,7 +105,7 @@ func derivatives<F: FunctionsOfX>(at x: Float, in: F.Type)
return (dxdx, dydx, dzdx, dsubscriptdx)
}

ProtocolRequirementAutodiffTests.test("constructor, accessor, subscript") {
ProtocolRequirementAutodiffTests.testWithLeakChecking("constructor, accessor, subscript") {
expectEqual(
(1.0, 4.0, 5.0, 5.0),
derivatives(at: 2.0, in: TestFunctionsOfX.self))
Expand All @@ -114,11 +115,11 @@ ProtocolRequirementAutodiffTests.test("constructor, accessor, subscript") {

protocol P : Differentiable {
@differentiable(wrt: (x, y))
func foo(_ x: Float, _ y: Double) -> Float
func foo(_ x: Tracked<Float>, _ y: Double) -> Tracked<Float>
}
struct S : P {
@differentiable(wrt: (x, y))
func foo(_ x: Float, _ y: Double) -> Float {
func foo(_ x: Tracked<Float>, _ y: Double) -> Tracked<Float> {
return x
}
}
Expand All @@ -127,23 +128,24 @@ struct S : P {

public protocol Distribution {
associatedtype Value
func logProbability(of value: Value) -> Float
func logProbability(of value: Value) -> Tracked<Float>
}

public protocol DifferentiableDistribution: Differentiable, Distribution {
@differentiable(wrt: self)
func logProbability(of value: Value) -> Float
func logProbability(of value: Value) -> Tracked<Float>
}

struct Foo: DifferentiableDistribution {
@differentiable(wrt: self)
func logProbability(of value: Float) -> Float {
func logProbability(of value: Tracked<Float>) -> Tracked<Float> {
.zero
}
}

@differentiable
func blah<T: DifferentiableDistribution>(_ x: T) -> Float where T.Value: AdditiveArithmetic {
func blah<T: DifferentiableDistribution>(_ x: T) -> Tracked<Float>
where T.Value: AdditiveArithmetic {
x.logProbability(of: .zero)
}

Expand All @@ -152,29 +154,29 @@ public protocol DoubleDifferentiableDistribution: DifferentiableDistribution
where Value: Differentiable {
@differentiable(wrt: self)
@differentiable(wrt: (self, value))
func logProbability(of value: Value) -> Float
func logProbability(of value: Value) -> Tracked<Float>
}

@differentiable
func blah2<T: DoubleDifferentiableDistribution>(_ x: T, _ value: T.Value) -> Float
func blah2<T: DoubleDifferentiableDistribution>(_ x: T, _ value: T.Value) -> Tracked<Float>
where T.Value: AdditiveArithmetic {
x.logProbability(of: value)
}

protocol DifferentiableFoo {
associatedtype T: Differentiable
@differentiable(wrt: x)
func foo(_ x: T) -> Float
func foo(_ x: T) -> Tracked<Float>
}

protocol MoreDifferentiableFoo: Differentiable, DifferentiableFoo {
@differentiable(wrt: (self, x))
func foo(_ x: T) -> Float
func foo(_ x: T) -> Tracked<Float>
}

struct MoreDifferentiableFooStruct: MoreDifferentiableFoo {
@differentiable(wrt: (self, x))
func foo(_ x: Float) -> Float {
func foo(_ x: Tracked<Float>) -> Tracked<Float> {
x
}
}
Expand Down
7 changes: 4 additions & 3 deletions test/AutoDiff/repeated_calls.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
// REQUIRES: executable_test

import StdlibUnittest
import DifferentiationUnittest

var RepeatedCallsTests = TestSuite("RepeatedCalls")

RepeatedCallsTests.test("Repeat") {
func mul2(_ x: Float) -> Float {
RepeatedCallsTests.testWithLeakChecking("Repeat") {
func mul2(_ x: Tracked<Float>) -> Tracked<Float> {
return 2 * x
}
func mul4(_ x: Float) -> Float {
func mul4(_ x: Tracked<Float>) -> Tracked<Float> {
return mul2(mul2(x))
}
expectEqual(4, gradient(at: 0, in: mul4))
Expand Down
15 changes: 8 additions & 7 deletions test/AutoDiff/separate_tangent_type.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,40 +7,41 @@ import Darwin.C
#else
import Glibc
#endif
import DifferentiationUnittest

var SeparateTangentTypeTests = TestSuite("SeparateTangentType")

struct DifferentiableSubset : Differentiable {
@differentiable(wrt: self)
var w: Float
var w: Tracked<Float>
@differentiable(wrt: self)
var b: Float
var b: Tracked<Float>
@noDerivative var flag: Bool

struct TangentVector : Differentiable, VectorProtocol {
typealias TangentVector = DifferentiableSubset.TangentVector
var w: Float
var b: Float
var w: Tracked<Float>
var b: Tracked<Float>
}
mutating func move(along v: TangentVector) {
w.move(along: v.w)
b.move(along: v.b)
}
}

SeparateTangentTypeTests.test("Trivial") {
SeparateTangentTypeTests.testWithLeakChecking("Trivial") {
let x = DifferentiableSubset(w: 0, b: 1, flag: false)
let pb = pullback(at: x) { x in x }
expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero)
}

SeparateTangentTypeTests.test("Initialization") {
SeparateTangentTypeTests.testWithLeakChecking("Initialization") {
let x = DifferentiableSubset(w: 0, b: 1, flag: false)
let pb = pullback(at: x) { x in DifferentiableSubset(w: 1, b: 2, flag: true) }
expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero)
}

SeparateTangentTypeTests.test("SomeArithmetics") {
SeparateTangentTypeTests.testWithLeakChecking("SomeArithmetics") {
let x = DifferentiableSubset(w: 0, b: 1, flag: false)
let pb = pullback(at: x) { x in DifferentiableSubset(w: x.w * x.w, b: x.b * x.b, flag: true) }
expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero)
Expand Down
Loading