@@ -693,6 +693,8 @@ extension Array: RandomAccessCollection, MutableCollection {
693
693
/// bridged `NSArray` instance as its storage, in which case writing is
694
694
/// O(*n*), where *n* is the length of the array.
695
695
@inlinable
696
+ // SWIFT_ENABLE_TENSORFLOW
697
+ @differentiable ( wrt: self , vjp: _vjpSubscript where Element : Differentiable)
696
698
public subscript( index: Int ) -> Element {
697
699
get {
698
700
// This call may be hoisted or eliminated by the optimizer. If
@@ -1301,6 +1303,8 @@ extension Array: RangeReplaceableCollection {
1301
1303
// operator in the same expression.
1302
1304
extension Array {
1303
1305
@inlinable
1306
+ // SWIFT_ENABLE_TENSORFLOW
1307
+ @differentiable ( vjp: _vjpPlus where Element : Differentiable)
1304
1308
public static func + ( lhs: Array , rhs: Array ) -> Array {
1305
1309
var lhs = lhs
1306
1310
lhs. append ( contentsOf: rhs)
@@ -1865,3 +1869,217 @@ internal struct _ArrayAnyHashableBox<Element: Hashable>
1865
1869
return true
1866
1870
}
1867
1871
}
1872
+
1873
+ // SWIFT_ENABLE_TENSORFLOW
1874
+ extension Array where Element : Differentiable {
1875
+ /// The view of an array as the differentiable product manifold of `Element`
1876
+ /// multiplied with itself `count` times.
1877
+ @_fixed_layout
1878
+ public struct DifferentiableView : Differentiable & KeyPathIterable {
1879
+ private var _base : [ Element ]
1880
+
1881
+ /// The viewed array.
1882
+ // I'm implementing this as a computed property instead of directly
1883
+ // exposing `_base` because the `@differentiable` annotation does not make
1884
+ // the stored property actually differentiable. I think this is a bug.
1885
+ // Maybe it's related to `@_fixed_layout`?
1886
+ // TODO: Determine if that is a bug, and fix.
1887
+ public var base : [ Element ] {
1888
+ @differentiable ( wrt: self , vjp: _vjpBase)
1889
+ get { return _base }
1890
+ _modify { yield & _base }
1891
+ }
1892
+
1893
+ @usableFromInline
1894
+ func _vjpBase( ) ->
1895
+ ( [ Element ] , ( Array < Element > . CotangentVector ) -> CotangentVector ) {
1896
+ return ( base, { $0 } )
1897
+ }
1898
+
1899
+ /// Creates a differentiable view of the given array.
1900
+ @differentiable ( wrt: base, vjp: _vjpInit)
1901
+ public init ( _ base: [ Element ] ) { self . _base = base }
1902
+
1903
+ @usableFromInline
1904
+ static func _vjpInit( _ base: [ Element ] ) ->
1905
+ ( Array . DifferentiableView , ( CotangentVector ) -> CotangentVector ) {
1906
+ return ( Array . DifferentiableView ( base) , { $0 } )
1907
+ }
1908
+
1909
+ // MARK: - Differentiable conformance.
1910
+
1911
+ public typealias TangentVector =
1912
+ Array < Element . TangentVector > . DifferentiableView
1913
+ public typealias CotangentVector =
1914
+ Array < Element . CotangentVector > . DifferentiableView
1915
+ public typealias AllDifferentiableVariables =
1916
+ Array < Element . AllDifferentiableVariables > . DifferentiableView
1917
+
1918
+ public var allDifferentiableVariables : AllDifferentiableVariables {
1919
+ get {
1920
+ return AllDifferentiableVariables (
1921
+ base. map { $0. allDifferentiableVariables } )
1922
+ }
1923
+ set {
1924
+ precondition (
1925
+ base. count == newValue. base. count,
1926
+ " cannot set Array.DifferentiableView.AllDifferentiableVariables " +
1927
+ " with count \( base. count) to " +
1928
+ " Array.DifferentiableView.AllDifferentiableVariables with " +
1929
+ " different count \( newValue. base. count) " )
1930
+ for i in base. indices {
1931
+ base [ i] . allDifferentiableVariables = newValue. base [ i]
1932
+ }
1933
+ }
1934
+ }
1935
+
1936
+ public func moved( along direction: TangentVector ) -> DifferentiableView {
1937
+ precondition (
1938
+ base. count == direction. base. count,
1939
+ " cannot move Array.DifferentiableView with count \( base. count) along " +
1940
+ " direction with different count \( direction. base. count) " )
1941
+ return DifferentiableView (
1942
+ zip ( base, direction. base) . map { $0. moved ( along: $1) } )
1943
+ }
1944
+
1945
+ public func tangentVector( from cotangentVector: CotangentVector ) ->
1946
+ TangentVector {
1947
+ precondition (
1948
+ base. count == cotangentVector. base. count,
1949
+ " cannot use Array.DifferentiableView with count \( base. count) to " +
1950
+ " get tangentVector from cotangentVector with different count " +
1951
+ " \( cotangentVector. base. count) " )
1952
+ return TangentVector ( zip ( base, cotangentVector. base) . map {
1953
+ ( selfElement, cotangentVectorElement) in
1954
+ selfElement. tangentVector ( from: cotangentVectorElement)
1955
+ } )
1956
+ }
1957
+ }
1958
+ }
1959
+
1960
+ extension Array . DifferentiableView : Equatable where Element : Equatable {
1961
+ public static func == (
1962
+ lhs: Array . DifferentiableView ,
1963
+ rhs: Array . DifferentiableView
1964
+ ) -> Bool {
1965
+ return lhs. base == rhs. base
1966
+ }
1967
+ }
1968
+
1969
+ /// Makes `Array.DifferentiableView` additive as the product space.
1970
+ ///
1971
+ /// Note that `Array.DifferentiableView([])` is the zero in the product spaces
1972
+ /// of all counts.
1973
+ extension Array . DifferentiableView : AdditiveArithmetic
1974
+ where Element : AdditiveArithmetic {
1975
+
1976
+ public static var zero : Array . DifferentiableView {
1977
+ return Array . DifferentiableView ( [ ] )
1978
+ }
1979
+
1980
+ public static func + (
1981
+ lhs: Array . DifferentiableView ,
1982
+ rhs: Array . DifferentiableView
1983
+ ) -> Array . DifferentiableView {
1984
+ precondition (
1985
+ lhs. base. count == 0 || rhs. base. count == 0 ||
1986
+ lhs. base. count == rhs. base. count,
1987
+ " cannot add Array.DifferentiableViews with different counts: " +
1988
+ " \( lhs. base. count) and \( rhs. base. count) " )
1989
+ if lhs. base. count == 0 {
1990
+ return rhs
1991
+ }
1992
+ if rhs. base. count == 0 {
1993
+ return lhs
1994
+ }
1995
+ return Array . DifferentiableView ( zip ( lhs. base, rhs. base) . map ( + ) )
1996
+ }
1997
+
1998
+ public static func - (
1999
+ lhs: Array . DifferentiableView ,
2000
+ rhs: Array . DifferentiableView
2001
+ ) -> Array . DifferentiableView {
2002
+ precondition (
2003
+ lhs. base. count == 0 || rhs. base. count == 0 ||
2004
+ lhs. base. count == rhs. base. count,
2005
+ " cannot subtract Array.DifferentiableViews with different counts: " +
2006
+ " \( lhs. base. count) and \( rhs. base. count) " )
2007
+ if lhs. base. count == 0 {
2008
+ return rhs
2009
+ }
2010
+ if rhs. base. count == 0 {
2011
+ return lhs
2012
+ }
2013
+ return Array . DifferentiableView ( zip ( lhs. base, rhs. base) . map ( - ) )
2014
+ }
2015
+ }
2016
+
2017
+ /// Makes `Array` differentiable as the product manifold of `Element`
2018
+ /// multiplied with itself `count` times.
2019
+ extension Array : Differentiable where Element : Differentiable {
2020
+ // In an ideal world, `TangentVector`, `CotangentVector`, and
2021
+ // `AllDifferentiableVariables` would all be `Array`s. Unfortunately, we
2022
+ // can't conform `Array` to `AdditiveArithmetic` for `TangentVector` and
2023
+ // `CotangentVector`, because `Array` already has a static `+` method with
2024
+ // different semantics from `AdditiveArithmetic` `+`. So we use
2025
+ // `Array.DifferentiableView` for all these associated types.
2026
+ public typealias TangentVector =
2027
+ Array < Element . TangentVector > . DifferentiableView
2028
+ public typealias CotangentVector =
2029
+ Array < Element . CotangentVector > . DifferentiableView
2030
+ public typealias AllDifferentiableVariables =
2031
+ Array < Element . AllDifferentiableVariables > . DifferentiableView
2032
+
2033
+ public var allDifferentiableVariables : AllDifferentiableVariables {
2034
+ get {
2035
+ return DifferentiableView ( self ) . allDifferentiableVariables
2036
+ }
2037
+ set {
2038
+ var view = DifferentiableView ( self )
2039
+ view. allDifferentiableVariables = newValue
2040
+ self = view. base
2041
+ }
2042
+ }
2043
+
2044
+ public func moved( along direction: TangentVector ) -> Array {
2045
+ return DifferentiableView ( self ) . moved ( along: direction) . base
2046
+ }
2047
+
2048
+ public func tangentVector( from cotangentVector: CotangentVector ) ->
2049
+ TangentVector {
2050
+ return DifferentiableView ( self ) . tangentVector ( from: cotangentVector)
2051
+ }
2052
+ }
2053
+
2054
+ extension Array where Element : Differentiable {
2055
+ public func _vjpSubscript( index: Int ) ->
2056
+ ( Element , ( Element . CotangentVector ) -> CotangentVector )
2057
+ {
2058
+ func pullback( _ gradientIn: Element . CotangentVector ) -> CotangentVector {
2059
+ var gradientOut = Array < Element . CotangentVector > (
2060
+ repeating: . zero,
2061
+ count: count)
2062
+ gradientOut [ index] = gradientIn
2063
+ return CotangentVector ( gradientOut)
2064
+ }
2065
+ return ( self [ index] , pullback)
2066
+ }
2067
+
2068
+ public static func _vjpPlus( _ lhs: [ Element ] , _ rhs: [ Element ] ) ->
2069
+ ( [ Element ] , ( CotangentVector ) -> ( CotangentVector , CotangentVector ) ) {
2070
+ func pullback( _ gradientIn: CotangentVector ) ->
2071
+ ( CotangentVector , CotangentVector ) {
2072
+ precondition (
2073
+ gradientIn. base. count == lhs. count + rhs. count,
2074
+ " + should receive gradient with count equal to sum of operand " +
2075
+ " counts, but counts are: gradient \( gradientIn. base. count) , " +
2076
+ " lhs \( lhs. count) , rhs \( rhs. count) " )
2077
+ return (
2078
+ CotangentVector ( Array < Element . CotangentVector > (
2079
+ gradientIn. base [ 0 ..< lhs. count] ) ) ,
2080
+ CotangentVector ( Array < Element . CotangentVector > (
2081
+ gradientIn. base [ lhs. count... ] ) ) )
2082
+ }
2083
+ return ( lhs + rhs, pullback)
2084
+ }
2085
+ }
0 commit comments