Skip to content

[AutoDiff] Add AnyDifferentiable. #30928

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,105 @@
//
//===----------------------------------------------------------------------===//
//
// This file defines `AnyDerivative`, a type-erased wrapper for
// `Differentiable.TangentVector` associated type implementations.
// This file defines type-erased wrappers for `Differentiable`-conforming types
// and `Differentiable.TangentVector` associated type implementations.
//
//===----------------------------------------------------------------------===//

import Swift

//===----------------------------------------------------------------------===//
// `AnyDifferentiable`
//===----------------------------------------------------------------------===//

internal protocol _AnyDifferentiableBox {
// `Differentiable` requirements.
mutating func _move(along direction: AnyDerivative)

/// The underlying base value, type-erased to `Any`.
var _typeErasedBase: Any { get }

/// Returns the underlying value unboxed to the given type, if possible.
func _unboxed<U: Differentiable>(to type: U.Type) -> U?
}

internal struct _ConcreteDifferentiableBox<T: Differentiable>: _AnyDifferentiableBox
{
/// The underlying base value.
var _base: T

init(_ base: T) {
self._base = base
}

/// The underlying base value, type-erased to `Any`.
var _typeErasedBase: Any {
return _base
}

func _unboxed<U: Differentiable>(to type: U.Type) -> U? {
return (self as? _ConcreteDifferentiableBox<U>)?._base
}

mutating func _move(along direction: AnyDerivative) {
guard
let directionBase =
direction.base as? T.TangentVector
else {
_derivativeTypeMismatch(T.self, type(of: direction.base))
}
_base.move(along: directionBase)
}
}

public struct AnyDifferentiable: Differentiable {
internal var _box: _AnyDifferentiableBox

internal init(_box: _AnyDifferentiableBox) {
self._box = _box
}

/// The underlying base value.
public var base: Any {
return _box._typeErasedBase
}

/// Creates a type-erased derivative from the given derivative.
@differentiable
public init<T: Differentiable>(_ base: T) {
self._box = _ConcreteDifferentiableBox<T>(base)
}

@inlinable
@derivative(of: init)
internal static func _vjpInit<T: Differentiable>(
_ base: T
) -> (value: AnyDifferentiable, pullback: (AnyDerivative) -> T.TangentVector)
{
return (AnyDifferentiable(base), { v in v.base as! T.TangentVector })
}

@inlinable
@derivative(of: init)
internal static func _jvpInit<T: Differentiable>(
_ base: T
) -> (
value: AnyDifferentiable, differential: (T.TangentVector) -> AnyDerivative
) {
return (AnyDifferentiable(base), { dbase in AnyDerivative(dbase) })
}

public typealias TangentVector = AnyDerivative

public mutating func move(along direction: TangentVector) {
_box._move(along: direction)
}
}

//===----------------------------------------------------------------------===//
// `AnyDerivative`
//===----------------------------------------------------------------------===//

@usableFromInline
internal protocol _AnyDerivativeBox {
// `Equatable` requirements (implied by `AdditiveArithmetic`).
Expand Down Expand Up @@ -47,18 +139,6 @@ extension _AnyDerivativeBox {
}
}

@inline(never)
@usableFromInline
internal func _derivativeTypeMismatch(
_ x: Any.Type, _ y: Any.Type, file: StaticString = #file, line: UInt = #line
) -> Never {
preconditionFailure(
"""
Derivative type mismatch: \
\(String(reflecting: x)) and \(String(reflecting: y))
""", file: file, line: line)
}

@frozen
@usableFromInline
internal struct _ConcreteDerivativeBox<T>: _AnyDerivativeBox
Expand Down Expand Up @@ -290,3 +370,19 @@ public struct AnyDerivative: Differentiable & AdditiveArithmetic {
_box._move(along: direction._box)
}
}

//===----------------------------------------------------------------------===//
// Helpers
//===----------------------------------------------------------------------===//

@inline(never)
@usableFromInline
internal func _derivativeTypeMismatch(
_ x: Any.Type, _ y: Any.Type, file: StaticString = #file, line: UInt = #line
) -> Never {
preconditionFailure(
"""
Derivative type mismatch: \
\(String(reflecting: x)) and \(String(reflecting: y))
""", file: file, line: line)
}
2 changes: 1 addition & 1 deletion stdlib/public/Differentiation/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ add_swift_target_library(swift_Differentiation ${SWIFT_STDLIB_LIBRARY_BUILD_TYPE
Differentiable.swift
DifferentialOperators.swift
DifferentiationUtilities.swift
AnyDerivative.swift
AnyDifferentiable.swift
ArrayDifferentiation.swift

GYB_SOURCES
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
import _Differentiation
import StdlibUnittest

var AnyDerivativeTests = TestSuite("AnyDerivative")
var TypeErasureTests = TestSuite("DifferentiableTypeErasure")

struct Vector: Differentiable {
struct Vector: Differentiable, Equatable {
var x, y: Float
}
struct Generic<T: Differentiable>: Differentiable {
struct Generic<T: Differentiable & Equatable>: Differentiable, Equatable {
var x: T
}

Expand All @@ -22,28 +22,44 @@ extension AnyDerivative {
}
}

AnyDerivativeTests.test("Vector") {
var tan = AnyDerivative(Vector.TangentVector(x: 1, y: 1))
tan += tan
expectEqual(AnyDerivative(Vector.TangentVector(x: 2, y: 2)), tan)
expectEqual(AnyDerivative(Vector.TangentVector(x: 4, y: 4)), tan + tan)
expectEqual(AnyDerivative(Vector.TangentVector(x: 0, y: 0)), tan - tan)
expectEqual(AnyDerivative(Vector.TangentVector(x: 4, y: 4)), tan.moved(along: tan))
expectEqual(AnyDerivative(Vector.TangentVector(x: 2, y: 2)), tan)
TypeErasureTests.test("AnyDifferentiable operations") {
do {
var any = AnyDifferentiable(Vector(x: 1, y: 1))
let tan = AnyDerivative(Vector.TangentVector(x: 1, y: 1))
any.move(along: tan)
expectEqual(Vector(x: 2, y: 2), any.base as? Vector)
}

do {
var any = AnyDifferentiable(Generic<Float>(x: 1))
let tan = AnyDerivative(Generic<Float>.TangentVector(x: 1))
any.move(along: tan)
expectEqual(Generic<Float>(x: 2), any.base as? Generic<Float>)
}
}

AnyDerivativeTests.test("Generic") {
var tan = AnyDerivative(Generic<Float>.TangentVector(x: 1))
let cotan = AnyDerivative(Generic<Float>.TangentVector(x: 1))
tan += tan
expectEqual(AnyDerivative(Generic<Float>.TangentVector(x: 2)), tan)
expectEqual(AnyDerivative(Generic<Float>.TangentVector(x: 4)), tan + tan)
expectEqual(AnyDerivative(Generic<Float>.TangentVector(x: 0)), tan - tan)
expectEqual(AnyDerivative(Generic<Float>.TangentVector(x: 4)), tan.moved(along: tan))
expectEqual(AnyDerivative(Generic<Float>.TangentVector(x: 1)), cotan)
TypeErasureTests.test("AnyDerivative operations") {
do {
var tan = AnyDerivative(Vector.TangentVector(x: 1, y: 1))
tan += tan
expectEqual(AnyDerivative(Vector.TangentVector(x: 2, y: 2)), tan)
expectEqual(AnyDerivative(Vector.TangentVector(x: 4, y: 4)), tan + tan)
expectEqual(AnyDerivative(Vector.TangentVector(x: 0, y: 0)), tan - tan)
expectEqual(AnyDerivative(Vector.TangentVector(x: 4, y: 4)), tan.moved(along: tan))
expectEqual(AnyDerivative(Vector.TangentVector(x: 2, y: 2)), tan)
}

do {
var tan = AnyDerivative(Generic<Float>.TangentVector(x: 1))
tan += tan
expectEqual(AnyDerivative(Generic<Float>.TangentVector(x: 2)), tan)
expectEqual(AnyDerivative(Generic<Float>.TangentVector(x: 4)), tan + tan)
expectEqual(AnyDerivative(Generic<Float>.TangentVector(x: 0)), tan - tan)
expectEqual(AnyDerivative(Generic<Float>.TangentVector(x: 4)), tan.moved(along: tan))
}
}

AnyDerivativeTests.test("Zero") {
TypeErasureTests.test("AnyDerivative.zero") {
var zero = AnyDerivative.zero
zero += zero
zero -= zero
Expand All @@ -66,7 +82,17 @@ AnyDerivativeTests.test("Zero") {
expectEqual(tan, tan)
}

AnyDerivativeTests.test("Casting") {
TypeErasureTests.test("AnyDifferentiable casting") {
let any = AnyDifferentiable(Vector(x: 1, y: 1))
expectEqual(Vector(x: 1, y: 1), any.base as? Vector)

let genericAny = AnyDifferentiable(Generic<Float>(x: 1))
expectEqual(Generic<Float>(x: 1),
genericAny.base as? Generic<Float>)
expectEqual(nil, genericAny.base as? Generic<Double>)
}

TypeErasureTests.test("AnyDerivative casting") {
let tan = AnyDerivative(Vector.TangentVector(x: 1, y: 1))
expectEqual(Vector.TangentVector(x: 1, y: 1), tan.base as? Vector.TangentVector)

Expand All @@ -81,7 +107,34 @@ AnyDerivativeTests.test("Casting") {
expectEqual(nil, zero.base as? Generic<Float>.TangentVector)
}

AnyDerivativeTests.test("Derivatives") {
TypeErasureTests.test("AnyDifferentiable differentiation") {
// Test `AnyDifferentiable` initializer.
do {
let x: Float = 3
let v = AnyDerivative(Float(2))
let 𝛁x = pullback(at: x, in: { AnyDifferentiable($0) })(v)
let expectedVJP: Float = 2
expectEqual(expectedVJP, 𝛁x)
}

do {
let x = Vector(x: 4, y: 5)
let v = AnyDerivative(Vector.TangentVector(x: 2, y: 2))
let 𝛁x = pullback(at: x, in: { AnyDifferentiable($0) })(v)
let expectedVJP = Vector.TangentVector(x: 2, y: 2)
expectEqual(expectedVJP, 𝛁x)
}

do {
let x = Generic<Double>(x: 4)
let v = AnyDerivative(Generic<Double>.TangentVector(x: 2))
let 𝛁x = pullback(at: x, in: { AnyDifferentiable($0) })(v)
let expectedVJP = Generic<Double>.TangentVector(x: 2)
expectEqual(expectedVJP, 𝛁x)
}
}

TypeErasureTests.test("AnyDerivative differentiation") {
// Test `AnyDerivative` operations.
func tripleSum(_ x: AnyDerivative, _ y: AnyDerivative) -> AnyDerivative {
let sum = x + y
Expand Down