@@ -184,36 +184,31 @@ extension Array where Element: Differentiable {
184
184
func _vjpSubscript( index: Int ) -> (
185
185
value: Element , pullback: ( Element . TangentVector ) -> TangentVector
186
186
) {
187
- func pullback( _ gradientIn : Element . TangentVector ) -> TangentVector {
188
- var gradientOut = [ Element . TangentVector] (
187
+ func pullback( _ v : Element . TangentVector ) -> TangentVector {
188
+ var dSelf = [ Element . TangentVector] (
189
189
repeating: . zero,
190
190
count: count)
191
- gradientOut [ index] = gradientIn
192
- return TangentVector ( gradientOut )
191
+ dSelf [ index] = v
192
+ return TangentVector ( dSelf )
193
193
}
194
194
return ( self [ index] , pullback)
195
195
}
196
196
197
197
@usableFromInline
198
198
@derivative ( of: + )
199
- static func _vjpConcatenate( _ lhs: [ Element ] , _ rhs: [ Element ] ) -> (
200
- value: [ Element ] ,
199
+ static func _vjpConcatenate( _ lhs: Self , _ rhs: Self ) -> (
200
+ value: Self ,
201
201
pullback: ( TangentVector ) -> ( TangentVector , TangentVector )
202
202
) {
203
- func pullback( _ gradientIn: TangentVector ) -> ( TangentVector , TangentVector )
204
- {
203
+ func pullback( _ v: TangentVector ) -> ( TangentVector , TangentVector ) {
205
204
precondition (
206
- gradientIn . base. count == lhs. count + rhs. count,
205
+ v . base. count == lhs. count + rhs. count,
207
206
" + should receive gradient with count equal to sum of operand "
208
- + " counts, but counts are: gradient \( gradientIn . base. count) , "
207
+ + " counts, but counts are: gradient \( v . base. count) , "
209
208
+ " lhs \( lhs. count) , rhs \( rhs. count) " )
210
209
return (
211
- TangentVector (
212
- [ Element . TangentVector] (
213
- gradientIn. base [ 0 ..< lhs. count] ) ) ,
214
- TangentVector (
215
- [ Element . TangentVector] (
216
- gradientIn. base [ lhs. count... ] ) )
210
+ TangentVector ( [ Element . TangentVector] ( v. base [ 0 ..< lhs. count] ) ) ,
211
+ TangentVector ( [ Element . TangentVector] ( v. base [ lhs. count... ] ) )
217
212
)
218
213
}
219
214
return ( lhs + rhs, pullback)
@@ -288,6 +283,37 @@ extension Array where Element: Differentiable {
288
283
// Differentiable higher order functions for collections
289
284
//===----------------------------------------------------------------------===//
290
285
286
+ extension Array where Element: Differentiable {
287
+ @inlinable
288
+ @differentiable ( wrt: self )
289
+ public func differentiableMap< Result: Differentiable > (
290
+ _ body: @differentiable ( Element ) -> Result
291
+ ) -> [ Result ] {
292
+ map ( body)
293
+ }
294
+
295
+ @inlinable
296
+ @derivative ( of: differentiableMap)
297
+ internal func _vjpDifferentiableMap< Result: Differentiable > (
298
+ _ body: @differentiable ( Element ) -> Result
299
+ ) -> (
300
+ value: [ Result ] ,
301
+ pullback: ( Array < Result > . TangentVector ) -> Array . TangentVector
302
+ ) {
303
+ var values : [ Result ] = [ ]
304
+ var pullbacks : [ ( Result . TangentVector ) -> Element . TangentVector ] = [ ]
305
+ for x in self {
306
+ let ( y, pb) = valueWithPullback ( at: x, in: body)
307
+ values. append ( y)
308
+ pullbacks. append ( pb)
309
+ }
310
+ func pullback( _ tans: Array < Result > . TangentVector ) -> Array . TangentVector {
311
+ . init( zip ( tans. base, pullbacks) . map { tan, pb in pb ( tan) } )
312
+ }
313
+ return ( value: values, pullback: pullback)
314
+ }
315
+ }
316
+
291
317
extension Array where Element: Differentiable {
292
318
@inlinable
293
319
@differentiable ( wrt: ( self , initialResult) )
@@ -336,34 +362,3 @@ extension Array where Element: Differentiable {
336
362
)
337
363
}
338
364
}
339
-
340
- extension Array where Element: Differentiable {
341
- @inlinable
342
- @differentiable ( wrt: self )
343
- public func differentiableMap< Result: Differentiable > (
344
- _ body: @differentiable ( Element ) -> Result
345
- ) -> [ Result ] {
346
- map ( body)
347
- }
348
-
349
- @inlinable
350
- @derivative ( of: differentiableMap)
351
- internal func _vjpDifferentiableMap< Result: Differentiable > (
352
- _ body: @differentiable ( Element ) -> Result
353
- ) -> (
354
- value: [ Result ] ,
355
- pullback: ( Array < Result > . TangentVector ) -> Array . TangentVector
356
- ) {
357
- var values : [ Result ] = [ ]
358
- var pullbacks : [ ( Result . TangentVector ) -> Element . TangentVector ] = [ ]
359
- for x in self {
360
- let ( y, pb) = valueWithPullback ( at: x, in: body)
361
- values. append ( y)
362
- pullbacks. append ( pb)
363
- }
364
- func pullback( _ tans: Array < Result > . TangentVector ) -> Array . TangentVector {
365
- . init( zip ( tans. base, pullbacks) . map { tan, pb in pb ( tan) } )
366
- }
367
- return ( value: values, pullback: pullback)
368
- }
369
- }
0 commit comments