[AutoDiff] [stdlib] Made arrays differentiable #23183

Merged
merged 21 commits into from
Apr 17, 2019
4a5001c
Made arrays differentiable.
eaplatanios Mar 8, 2019
24f026e
Addressed Marc's comment.
eaplatanios Mar 8, 2019
18d3154
Added documentation for the array differentiable conformance.
eaplatanios Mar 8, 2019
4711081
Addressed some of Richard's comments.
eaplatanios Mar 8, 2019
d5e1ed9
Had to make 'zipLongest' public due to its use in '@inlinable' functi…
eaplatanios Mar 8, 2019
6ec4b8c
Minor bug fix.
eaplatanios Mar 8, 2019
0aa8024
Addressed Richard's review comments.
eaplatanios Mar 8, 2019
668890b
Addressed some of the code review feedback.
eaplatanios Mar 9, 2019
37ebb0c
Fixed the array differentiable conformance derivation.
eaplatanios Mar 9, 2019
8a278d9
Addressed Marc's comment about 'moved' and 'tangentVector'.
eaplatanios Mar 9, 2019
7a56ef8
Style fixes.
eaplatanios Mar 9, 2019
f4b8b8c
Added some tests for array differentiation.
eaplatanios Mar 9, 2019
b2614c1
Added a simple array identity differentiation test.
eaplatanios Mar 9, 2019
cc6dc5d
Merge remote-tracking branch 'upstream/tensorflow' into array-differe…
eaplatanios Mar 28, 2019
da9c160
Merge remote-tracking branch 'upstream/tensorflow' into array-differe…
eaplatanios Mar 29, 2019
6fcd912
Added an 'Array.subscript' VJP.
eaplatanios Mar 29, 2019
0346b54
Merge branch 'array-differentiable' of https://github.com/eaplatanios…
Apr 17, 2019
2b656d7
update to use a single generic type for all array diffble assoc types
Apr 17, 2019
7d98750
Merge branch 'tensorflow' into eaplatanios-array-differentiable
Apr 17, 2019
d3365ea
add missing comment
Apr 17, 2019
7f8f273
address comments
Apr 17, 2019
218 changes: 218 additions & 0 deletions stdlib/public/core/Array.swift
@@ -693,6 +693,8 @@ extension Array: RandomAccessCollection, MutableCollection {
/// bridged `NSArray` instance as its storage, in which case writing is
/// O(*n*), where *n* is the length of the array.
@inlinable
// SWIFT_ENABLE_TENSORFLOW
@differentiable(wrt: self, vjp: _vjpSubscript where Element : Differentiable)
public subscript(index: Int) -> Element {
get {
// This call may be hoisted or eliminated by the optimizer. If
@@ -1301,6 +1303,8 @@ extension Array: RangeReplaceableCollection {
// operator in the same expression.
extension Array {
@inlinable
// SWIFT_ENABLE_TENSORFLOW
@differentiable(vjp: _vjpPlus where Element : Differentiable)
public static func + (lhs: Array, rhs: Array) -> Array {
var lhs = lhs
lhs.append(contentsOf: rhs)
@@ -1865,3 +1869,217 @@ internal struct _ArrayAnyHashableBox<Element: Hashable>
return true
}
}

// SWIFT_ENABLE_TENSORFLOW
extension Array where Element : Differentiable {
/// The view of an array as the differentiable product manifold of `Element`
/// multiplied with itself `count` times.
@_fixed_layout
public struct DifferentiableView : Differentiable & KeyPathIterable {
private var _base: [Element]

/// The viewed array.
// I'm implementing this as a computed property instead of directly
// exposing `_base` because the `@differentiable` annotation does not make
// the stored property actually differentiable. I think this is a bug.
// Maybe it's related to `@_fixed_layout`?
// TODO: Determine if that is a bug, and fix.
Review comment (Contributor): Interesting. How about renaming the current base to _base, and renaming array/_vjpArray to base/_vjpBase?

Reply: done

public var base: [Element] {
@differentiable(wrt: self, vjp: _vjpBase)
get { return _base }
_modify { yield &_base }
}

@usableFromInline
func _vjpBase() ->
([Element], (Array<Element>.CotangentVector) -> CotangentVector) {
return (base, { $0 })
}

/// Creates a differentiable view of the given array.
@differentiable(wrt: base, vjp: _vjpInit)
public init(_ base: [Element]) { self._base = base }

@usableFromInline
static func _vjpInit(_ base: [Element]) ->
(Array.DifferentiableView, (CotangentVector) -> CotangentVector) {
return (Array.DifferentiableView(base), { $0 })
}

// MARK: - Differentiable conformance.

public typealias TangentVector =
Array<Element.TangentVector>.DifferentiableView
public typealias CotangentVector =
Array<Element.CotangentVector>.DifferentiableView
public typealias AllDifferentiableVariables =
Array<Element.AllDifferentiableVariables>.DifferentiableView

public var allDifferentiableVariables: AllDifferentiableVariables {
get {
return AllDifferentiableVariables(
base.map { $0.allDifferentiableVariables })
}
set {
precondition(
base.count == newValue.base.count,
"cannot set Array.DifferentiableView.AllDifferentiableVariables " +
"with count \(base.count) to " +
"Array.DifferentiableView.AllDifferentiableVariables with " +
"different count \(newValue.base.count)")
for i in base.indices {
base[i].allDifferentiableVariables = newValue.base[i]
}
}
}

public func moved(along direction: TangentVector) -> DifferentiableView {
precondition(
base.count == direction.base.count,
"cannot move Array.DifferentiableView with count \(base.count) along " +
"direction with different count \(direction.base.count)")
return DifferentiableView(
zip(base, direction.base).map { $0.moved(along: $1) })
}

public func tangentVector(from cotangentVector: CotangentVector) ->
TangentVector {
precondition(
base.count == cotangentVector.base.count,
"cannot use Array.DifferentiableView with count \(base.count) to " +
"get tangentVector from cotangentVector with different count " +
"\(cotangentVector.base.count)")
return TangentVector(zip(base, cotangentVector.base).map {
(selfElement, cotangentVectorElement) in
selfElement.tangentVector(from: cotangentVectorElement)
})
}
}
}
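The elementwise behaviour of `moved(along:)` above can be sketched in plain Swift. This is only an illustration under stated assumptions: `movedElementwise` is a hypothetical helper, not part of this PR, and it assumes that moving a `Float` along a tangent amounts to ordinary addition.

```swift
// Hypothetical helper (not in the PR): moves each element of `point` along
// the matching element of `direction`, assuming `Float` movement is `+`.
func movedElementwise(_ point: [Float], along direction: [Float]) -> [Float] {
    // Same guard as the PR's precondition: counts must agree.
    precondition(
        point.count == direction.count,
        "cannot move array with count \(point.count) along " +
        "direction with different count \(direction.count)")
    return zip(point, direction).map { $0 + $1 }
}

let shifted = movedElementwise([1, 2, 3], along: [0.5, 0.5, 0.5])
// shifted == [1.5, 2.5, 3.5]
```

The count check mirrors the precondition in the diff: a point in the product manifold can only move along a direction with the same number of components.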

extension Array.DifferentiableView : Equatable where Element : Equatable {
public static func == (
lhs: Array.DifferentiableView,
rhs: Array.DifferentiableView
) -> Bool {
return lhs.base == rhs.base
}
}

/// Makes `Array.DifferentiableView` additive as the product space.
///
/// Note that `Array.DifferentiableView([])` is the zero in the product spaces
/// of all counts.
extension Array.DifferentiableView : AdditiveArithmetic
where Element : AdditiveArithmetic {

public static var zero: Array.DifferentiableView {
return Array.DifferentiableView([])
}

public static func + (
lhs: Array.DifferentiableView,
rhs: Array.DifferentiableView
) -> Array.DifferentiableView {
precondition(
lhs.base.count == 0 || rhs.base.count == 0 ||
lhs.base.count == rhs.base.count,
"cannot add Array.DifferentiableViews with different counts: " +
"\(lhs.base.count) and \(rhs.base.count)")
if lhs.base.count == 0 {
return rhs
}
if rhs.base.count == 0 {
return lhs
}
return Array.DifferentiableView(zip(lhs.base, rhs.base).map(+))
}

public static func - (
lhs: Array.DifferentiableView,
rhs: Array.DifferentiableView
) -> Array.DifferentiableView {
precondition(
lhs.base.count == 0 || rhs.base.count == 0 ||
lhs.base.count == rhs.base.count,
"cannot subtract Array.DifferentiableViews with different counts: " +
"\(lhs.base.count) and \(rhs.base.count)")
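if lhs.base.count == 0 {
// An empty view acts as zero, so `zero - rhs` must negate `rhs`.
return Array.DifferentiableView(rhs.base.map { Element.zero - $0 })
}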
if rhs.base.count == 0 {
return lhs
}
return Array.DifferentiableView(zip(lhs.base, rhs.base).map(-))
}
}
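The "empty view is zero for every count" convention used by `+` above can be tried out with a small stand-in type. `ProductVector` is a hypothetical name introduced for illustration only; it is specialized to `Float` and is not the stdlib's `Array.DifferentiableView`.

```swift
// Hypothetical stand-in for `Array.DifferentiableView`, fixed to `Float`.
struct ProductVector: Equatable {
    var base: [Float]

    // The empty vector plays the role of zero in product spaces of all counts.
    static var zero: ProductVector { ProductVector(base: []) }

    static func + (lhs: ProductVector, rhs: ProductVector) -> ProductVector {
        // Adding zero (the empty vector) returns the other operand unchanged.
        if lhs.base.isEmpty { return rhs }
        if rhs.base.isEmpty { return lhs }
        precondition(
            lhs.base.count == rhs.base.count,
            "cannot add vectors with different counts")
        return ProductVector(base: zip(lhs.base, rhs.base).map { $0 + $1 })
    }
}

let v = ProductVector(base: [1, 2, 3])
let sum = v + ProductVector(base: [10, 20, 30])  // elementwise: [11, 22, 33]
let unchanged = ProductVector.zero + v           // zero is an identity
```

Treating the empty array as zero lets `zero` be a static property that needs no knowledge of the count, at the cost of special-casing empty operands in every arithmetic operator.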

/// Makes `Array` differentiable as the product manifold of `Element`
/// multiplied with itself `count` times.
extension Array : Differentiable where Element : Differentiable {
// In an ideal world, `TangentVector`, `CotangentVector`, and
// `AllDifferentiableVariables` would all be `Array`s. Unfortunately, we
// can't conform `Array` to `AdditiveArithmetic` for `TangentVector` and
// `CotangentVector`, because `Array` already has a static `+` method with
// different semantics from `AdditiveArithmetic` `+`. So we use
// `Array.DifferentiableView` for all these associated types.
public typealias TangentVector =
Array<Element.TangentVector>.DifferentiableView
public typealias CotangentVector =
Array<Element.CotangentVector>.DifferentiableView
public typealias AllDifferentiableVariables =
Array<Element.AllDifferentiableVariables>.DifferentiableView

public var allDifferentiableVariables: AllDifferentiableVariables {
get {
return DifferentiableView(self).allDifferentiableVariables
}
set {
var view = DifferentiableView(self)
view.allDifferentiableVariables = newValue
self = view.base
}
}

public func moved(along direction: TangentVector) -> Array {
return DifferentiableView(self).moved(along: direction).base
}

public func tangentVector(from cotangentVector: CotangentVector) ->
TangentVector {
return DifferentiableView(self).tangentVector(from: cotangentVector)
}
}
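The comment above about `Array`'s existing `+` can be made concrete with standard Swift. This is a small sketch and does not use any API added by this PR: it contrasts concatenation with the elementwise sum that a tangent vector type would need.

```swift
let a: [Float] = [1, 2]
let b: [Float] = [3, 4]

// `Array`'s existing `+` concatenates its operands.
let concatenated = a + b                     // [1, 2, 3, 4]

// Vector-space addition, as needed for tangent vectors, is elementwise.
let elementwise = zip(a, b).map { $0 + $1 }  // [4, 6]
```

Because the two meanings of `+` disagree, the PR wraps the array in `Array.DifferentiableView` and gives the wrapper the elementwise operator instead.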

extension Array where Element : Differentiable {
public func _vjpSubscript(index: Int) ->
(Element, (Element.CotangentVector) -> CotangentVector)
{
func pullback(_ gradientIn: Element.CotangentVector) -> CotangentVector {
var gradientOut = Array<Element.CotangentVector>(
repeating: .zero,
count: count)
gradientOut[index] = gradientIn
return CotangentVector(gradientOut)
}
return (self[index], pullback)
}

public static func _vjpPlus(_ lhs: [Element], _ rhs: [Element]) ->
([Element], (CotangentVector) -> (CotangentVector, CotangentVector)) {
func pullback(_ gradientIn: CotangentVector) ->
(CotangentVector, CotangentVector) {
precondition(
gradientIn.base.count == lhs.count + rhs.count,
"+ should receive gradient with count equal to sum of operand " +
"counts, but counts are: gradient \(gradientIn.base.count), " +
"lhs \(lhs.count), rhs \(rhs.count)")
return (
CotangentVector(Array<Element.CotangentVector>(
gradientIn.base[0..<lhs.count])),
CotangentVector(Array<Element.CotangentVector>(
gradientIn.base[lhs.count...])))
}
return (lhs + rhs, pullback)
}
}
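The two pullbacks above can be sketched without the differentiation attributes, on plain `[Float]` values. `subscriptPullback` and `concatenationPullback` are hypothetical helpers written for illustration; they mirror the closures in `_vjpSubscript` and `_vjpPlus` but are not part of the PR.

```swift
// Hypothetical mirror of the `_vjpSubscript` pullback: scatter the incoming
// gradient into a zero array at the subscripted index.
func subscriptPullback(count: Int, index: Int) -> (Float) -> [Float] {
    return { gradientIn in
        var gradientOut = [Float](repeating: 0, count: count)
        gradientOut[index] = gradientIn
        return gradientOut
    }
}

// Hypothetical mirror of the `_vjpPlus` pullback: split the incoming gradient
// back into the parts that belong to each operand of the concatenation.
func concatenationPullback(
    lhsCount: Int, rhsCount: Int
) -> ([Float]) -> ([Float], [Float]) {
    return { gradientIn in
        precondition(
            gradientIn.count == lhsCount + rhsCount,
            "gradient count must equal the sum of the operand counts")
        return (Array(gradientIn[0..<lhsCount]), Array(gradientIn[lhsCount...]))
    }
}

let pickSecond = subscriptPullback(count: 4, index: 1)
let scattered = pickSecond(5)  // [0, 5, 0, 0]

let split = concatenationPullback(lhsCount: 2, rhsCount: 2)
let (lhsGradient, rhsGradient) = split([1, 1, 1, 0])  // ([1, 1], [1, 0])
```

The split result matches the first expectation in the `ArrayConcat` test of this PR, where summing the first three elements of a concatenation of two 2-element arrays yields gradients `[1, 1]` and `[1, 0]`.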
102 changes: 102 additions & 0 deletions test/AutoDiff/array.swift
@@ -0,0 +1,102 @@
// RUN: %target-run-simple-swift

import StdlibUnittest

var ArrayAutodiffTests = TestSuite("ArrayAutodiff")

typealias FloatArrayGrad = Array<Float>.CotangentVector

ArrayAutodiffTests.test("ArrayIdentity") {
func arrayIdentity(_ x: [Float]) -> [Float] {
return x
}

let backprop = pullback(at: [5, 6, 7, 8], in: arrayIdentity)
expectEqual(
FloatArrayGrad([1, 2, 3, 4]),
backprop(FloatArrayGrad([1, 2, 3, 4])))
}

ArrayAutodiffTests.test("ArraySubscript") {
func sumFirstThree(_ array: [Float]) -> Float {
return array[0] + array[1] + array[2]
}

expectEqual(
FloatArrayGrad([1, 1, 1, 0, 0, 0]),
gradient(at: [2, 3, 4, 5, 6, 7], in: sumFirstThree))
}

ArrayAutodiffTests.test("ArrayConcat") {
struct TwoArrays : Differentiable {
let a: [Float]
let b: [Float]
}

func sumFirstThreeConcatted(_ arrs: TwoArrays) -> Float {
let c = arrs.a + arrs.b
return c[0] + c[1] + c[2]
}

expectEqual(
TwoArrays.CotangentVector(
a: FloatArrayGrad([1, 1]),
b: FloatArrayGrad([1, 0])),
gradient(
at: TwoArrays(a: [0, 0], b: [0, 0]),
in: sumFirstThreeConcatted))
expectEqual(
TwoArrays.CotangentVector(
a: FloatArrayGrad([1, 1, 1, 0]),
b: FloatArrayGrad([0, 0])),
gradient(
at: TwoArrays(a: [0, 0, 0, 0], b: [0, 0]),
in: sumFirstThreeConcatted))
expectEqual(
TwoArrays.CotangentVector(
a: FloatArrayGrad([]),
b: FloatArrayGrad([1, 1, 1, 0])),
gradient(
at: TwoArrays(a: [], b: [0, 0, 0, 0]),
in: sumFirstThreeConcatted))
}

ArrayAutodiffTests.test("Array.DifferentiableView.init") {
@differentiable
func constructView(_ x: [Float]) -> Array<Float>.DifferentiableView {
return Array<Float>.DifferentiableView(x)
}

let backprop = pullback(at: [5, 6, 7, 8], in: constructView)
expectEqual(
FloatArrayGrad([1, 2, 3, 4]),
backprop(FloatArrayGrad([1, 2, 3, 4])))
}

ArrayAutodiffTests.test("Array.DifferentiableView.base") {
@differentiable
func accessBase(_ x: Array<Float>.DifferentiableView) -> [Float] {
return x.base
}

let backprop = pullback(
at: Array<Float>.DifferentiableView([5, 6, 7, 8]),
in: accessBase)
expectEqual(
FloatArrayGrad([1, 2, 3, 4]),
backprop(FloatArrayGrad([1, 2, 3, 4])))
}

ArrayAutodiffTests.test("Array.DifferentiableView : KeyPathIterable") {
struct Container : KeyPathIterable {
let a: Array<Float>.DifferentiableView
}
let container = Container(a: Array<Float>.DifferentiableView([1, 2, 3]))
expectEqual(
[1, 2, 3],
container.recursivelyAllKeyPaths(to: Float.self).map {
container[keyPath: $0]
})
}

runAllTests()
4 changes: 0 additions & 4 deletions test/Sema/struct_differentiable.swift
@@ -243,10 +243,6 @@ struct GenericConstrained<T> {
extension GenericConstrained : Differentiable
where T : Differentiable {}

// TF-161: Test conditional conformance of `Array`.
// expected-warning @+1 {{stored property '_buffer' has no derivative because it does not conform to 'Differentiable'; add '@noDerivative' to make it explicit}}
extension Array : Differentiable where Element : Differentiable {}

struct TF_260<T : Differentiable> : Differentiable & AdditiveArithmetic {
var x: T.CotangentVector
}