@@ -184,36 +184,31 @@ extension Array where Element: Differentiable {
184
184
func _vjpSubscript( index: Int ) -> (
185
185
value: Element , pullback: ( Element . TangentVector ) -> TangentVector
186
186
) {
187
- func pullback( _ gradientIn : Element . TangentVector ) -> TangentVector {
188
- var gradientOut = [ Element . TangentVector] (
187
+ func pullback( _ v : Element . TangentVector ) -> TangentVector {
188
+ var dSelf = [ Element . TangentVector] (
189
189
repeating: . zero,
190
190
count: count)
191
- gradientOut [ index] = gradientIn
192
- return TangentVector ( gradientOut )
191
+ dSelf [ index] = v
192
+ return TangentVector ( dSelf )
193
193
}
194
194
return ( self [ index] , pullback)
195
195
}
196
196
197
197
@usableFromInline
198
198
@derivative ( of: + )
199
- static func _vjpConcatenate( _ lhs: [ Element ] , _ rhs: [ Element ] ) -> (
200
- value: [ Element ] ,
199
+ static func _vjpConcatenate( _ lhs: Self , _ rhs: Self ) -> (
200
+ value: Self ,
201
201
pullback: ( TangentVector ) -> ( TangentVector , TangentVector )
202
202
) {
203
- func pullback( _ gradientIn: TangentVector ) -> ( TangentVector , TangentVector )
204
- {
203
+ func pullback( _ v: TangentVector ) -> ( TangentVector , TangentVector ) {
205
204
precondition (
206
- gradientIn . base. count == lhs. count + rhs. count,
205
+ v . base. count == lhs. count + rhs. count,
207
206
" + should receive gradient with count equal to sum of operand "
208
- + " counts, but counts are: gradient \( gradientIn . base. count) , "
207
+ + " counts, but counts are: gradient \( v . base. count) , "
209
208
+ " lhs \( lhs. count) , rhs \( rhs. count) " )
210
209
return (
211
- TangentVector (
212
- [ Element . TangentVector] (
213
- gradientIn. base [ 0 ..< lhs. count] ) ) ,
214
- TangentVector (
215
- [ Element . TangentVector] (
216
- gradientIn. base [ lhs. count... ] ) )
210
+ TangentVector ( [ Element . TangentVector] ( v. base [ 0 ..< lhs. count] ) ) ,
211
+ TangentVector ( [ Element . TangentVector] ( v. base [ lhs. count... ] ) )
217
212
)
218
213
}
219
214
return ( lhs + rhs, pullback)
@@ -227,8 +222,11 @@ extension Array where Element: Differentiable {
227
222
value: Void , pullback: ( inout TangentVector ) -> Element . TangentVector
228
223
) {
229
224
let appendedElementIndex = count
230
- defer { append ( element) }
231
- return ( ( ) , { dself in dself. base [ appendedElementIndex] } )
225
+ append ( element)
226
+ return ( ( ) , { v in
227
+ defer { v. base. removeLast ( ) }
228
+ return v. base [ appendedElementIndex]
229
+ } )
232
230
}
233
231
234
232
@usableFromInline
@@ -288,6 +286,37 @@ extension Array where Element: Differentiable {
288
286
// Differentiable higher order functions for collections
289
287
//===----------------------------------------------------------------------===//
290
288
289
+ extension Array where Element: Differentiable {
290
+ @inlinable
291
+ @differentiable ( wrt: self )
292
+ public func differentiableMap< Result: Differentiable > (
293
+ _ body: @differentiable ( Element ) -> Result
294
+ ) -> [ Result ] {
295
+ map ( body)
296
+ }
297
+
298
+ @inlinable
299
+ @derivative ( of: differentiableMap)
300
+ internal func _vjpDifferentiableMap< Result: Differentiable > (
301
+ _ body: @differentiable ( Element ) -> Result
302
+ ) -> (
303
+ value: [ Result ] ,
304
+ pullback: ( Array < Result > . TangentVector ) -> Array . TangentVector
305
+ ) {
306
+ var values : [ Result ] = [ ]
307
+ var pullbacks : [ ( Result . TangentVector ) -> Element . TangentVector ] = [ ]
308
+ for x in self {
309
+ let ( y, pb) = valueWithPullback ( at: x, in: body)
310
+ values. append ( y)
311
+ pullbacks. append ( pb)
312
+ }
313
+ func pullback( _ tans: Array < Result > . TangentVector ) -> Array . TangentVector {
314
+ . init( zip ( tans. base, pullbacks) . map { tan, pb in pb ( tan) } )
315
+ }
316
+ return ( value: values, pullback: pullback)
317
+ }
318
+ }
319
+
291
320
extension Array where Element: Differentiable {
292
321
@inlinable
293
322
@differentiable ( wrt: ( self , initialResult) )
@@ -336,34 +365,3 @@ extension Array where Element: Differentiable {
336
365
)
337
366
}
338
367
}
339
-
340
- extension Array where Element: Differentiable {
341
- @inlinable
342
- @differentiable ( wrt: self )
343
- public func differentiableMap< Result: Differentiable > (
344
- _ body: @differentiable ( Element ) -> Result
345
- ) -> [ Result ] {
346
- map ( body)
347
- }
348
-
349
- @inlinable
350
- @derivative ( of: differentiableMap)
351
- internal func _vjpDifferentiableMap< Result: Differentiable > (
352
- _ body: @differentiable ( Element ) -> Result
353
- ) -> (
354
- value: [ Result ] ,
355
- pullback: ( Array < Result > . TangentVector ) -> Array . TangentVector
356
- ) {
357
- var values : [ Result ] = [ ]
358
- var pullbacks : [ ( Result . TangentVector ) -> Element . TangentVector ] = [ ]
359
- for x in self {
360
- let ( y, pb) = valueWithPullback ( at: x, in: body)
361
- values. append ( y)
362
- pullbacks. append ( pb)
363
- }
364
- func pullback( _ tans: Array < Result > . TangentVector ) -> Array . TangentVector {
365
- . init( zip ( tans. base, pullbacks) . map { tan, pb in pb ( tan) } )
366
- }
367
- return ( value: values, pullback: pullback)
368
- }
369
- }
0 commit comments