@@ -693,8 +693,6 @@ extension Array: RandomAccessCollection, MutableCollection {
693
693
/// bridged `NSArray` instance as its storage, in which case writing is
694
694
/// O(*n*), where *n* is the length of the array.
695
695
@inlinable
696
- // SWIFT_ENABLE_TENSORFLOW
697
- @differentiable ( wrt: self , vjp: _vjpSubscript where Element : Differentiable)
698
696
public subscript( index: Int ) -> Element {
699
697
get {
700
698
// This call may be hoisted or eliminated by the optimizer. If
@@ -873,8 +871,6 @@ extension Array: RangeReplaceableCollection {
873
871
/// `repeating` parameter. `count` must be zero or greater.
874
872
@inlinable
875
873
@_semantics ( " array.init " )
876
- @differentiable ( wrt: repeatedValue, vjp: _vjpInit ( repeating: count: )
877
- where Element: Differentiable)
878
874
public init ( repeating repeatedValue: Element , count: Int ) {
879
875
var p : UnsafeMutablePointer < Element >
880
876
( self , p) = Array . _allocateUninitialized ( count)
@@ -1330,8 +1326,6 @@ extension Array: RangeReplaceableCollection {
1330
1326
// operator in the same expression.
1331
1327
extension Array {
1332
1328
@inlinable
1333
- // SWIFT_ENABLE_TENSORFLOW
1334
- @differentiable ( vjp: _vjpPlus where Element : Differentiable)
1335
1329
public static func + ( lhs: Array , rhs: Array ) -> Array {
1336
1330
var lhs = lhs
1337
1331
lhs. append ( contentsOf: rhs)
@@ -1930,24 +1924,24 @@ extension Array {
1930
1924
extension Array . DifferentiableView : Differentiable where Element : Differentiable {
1931
1925
/// The viewed array.
1932
1926
public var base : [ Element ] {
1933
- @differentiable ( wrt: self , vjp: _vjpBase)
1934
1927
get { return _base }
1935
1928
_modify { yield & _base }
1936
1929
}
1937
1930
1938
1931
@usableFromInline
1932
+ @derivative ( of: base)
1939
1933
func _vjpBase( ) ->
1940
- ( [ Element ] , ( Array < Element > . TangentVector ) -> TangentVector ) {
1934
+ ( value : [ Element ] , pullback : ( Array < Element > . TangentVector ) -> TangentVector ) {
1941
1935
return ( base, { $0 } )
1942
1936
}
1943
1937
1944
1938
/// Creates a differentiable view of the given array.
1945
- @differentiable ( wrt: base, vjp: _vjpInit)
1946
1939
public init ( _ base: [ Element ] ) { self . _base = base }
1947
1940
1948
1941
@usableFromInline
1942
+ @derivative ( of: init ( _: ) )
1949
1943
static func _vjpInit( _ base: [ Element ] ) ->
1950
- ( Array . DifferentiableView , ( TangentVector ) -> TangentVector ) {
1944
+ ( value : Array . DifferentiableView , pullback : ( TangentVector ) -> TangentVector ) {
1951
1945
return ( Array . DifferentiableView ( base) , { $0 } )
1952
1946
}
1953
1947
@@ -2088,8 +2082,9 @@ extension Array : EuclideanDifferentiable
2088
2082
}
2089
2083
2090
2084
extension Array where Element : Differentiable {
2085
+ @derivative ( of: subscript)
2091
2086
public func _vjpSubscript( index: Int ) ->
2092
- ( Element , ( Element . TangentVector ) -> TangentVector )
2087
+ ( value : Element , pullback : ( Element . TangentVector ) -> TangentVector )
2093
2088
{
2094
2089
func pullback( _ gradientIn: Element . TangentVector ) -> TangentVector {
2095
2090
var gradientOut = Array < Element . TangentVector > (
@@ -2101,8 +2096,9 @@ extension Array where Element : Differentiable {
2101
2096
return ( self [ index] , pullback)
2102
2097
}
2103
2098
2099
+ @derivative ( of: + )
2104
2100
public static func _vjpPlus( _ lhs: [ Element ] , _ rhs: [ Element ] ) ->
2105
- ( [ Element ] , ( TangentVector ) -> ( TangentVector , TangentVector ) ) {
2101
+ ( value : [ Element ] , pullback : ( TangentVector ) -> ( TangentVector , TangentVector ) ) {
2106
2102
func pullback( _ gradientIn: TangentVector ) ->
2107
2103
( TangentVector , TangentVector ) {
2108
2104
precondition (
@@ -2122,6 +2118,7 @@ extension Array where Element : Differentiable {
2122
2118
2123
2119
extension Array where Element: Differentiable {
2124
2120
@usableFromInline
2121
+ @derivative ( of: init ( repeating: count: ) )
2125
2122
static func _vjpInit( repeating repeatedValue: Element , count: Int ) -> (
2126
2123
value: Self , pullback: ( TangentVector ) -> Element . TangentVector
2127
2124
) {
0 commit comments