Skip to content

Commit 2e690e2

Browse files
committed
[AutoDiff] NFC: garden array differentiation.
Use consistent variable naming. Reorganize code.
1 parent 8829a14 commit 2e690e2

File tree

2 files changed

+42
-47
lines changed

2 files changed

+42
-47
lines changed

stdlib/public/Differentiation/ArrayDifferentiation.swift

Lines changed: 42 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -184,36 +184,31 @@ extension Array where Element: Differentiable {
184184
func _vjpSubscript(index: Int) -> (
185185
value: Element, pullback: (Element.TangentVector) -> TangentVector
186186
) {
187-
func pullback(_ gradientIn: Element.TangentVector) -> TangentVector {
188-
var gradientOut = [Element.TangentVector](
187+
func pullback(_ v: Element.TangentVector) -> TangentVector {
188+
var dSelf = [Element.TangentVector](
189189
repeating: .zero,
190190
count: count)
191-
gradientOut[index] = gradientIn
192-
return TangentVector(gradientOut)
191+
dSelf[index] = v
192+
return TangentVector(dSelf)
193193
}
194194
return (self[index], pullback)
195195
}
196196

197197
@usableFromInline
198198
@derivative(of: +)
199-
static func _vjpConcatenate(_ lhs: [Element], _ rhs: [Element]) -> (
200-
value: [Element],
199+
static func _vjpConcatenate(_ lhs: Self, _ rhs: Self) -> (
200+
value: Self,
201201
pullback: (TangentVector) -> (TangentVector, TangentVector)
202202
) {
203-
func pullback(_ gradientIn: TangentVector) -> (TangentVector, TangentVector)
204-
{
203+
func pullback(_ v: TangentVector) -> (TangentVector, TangentVector) {
205204
precondition(
206-
gradientIn.base.count == lhs.count + rhs.count,
205+
v.base.count == lhs.count + rhs.count,
207206
"+ should receive gradient with count equal to sum of operand "
208-
+ "counts, but counts are: gradient \(gradientIn.base.count), "
207+
+ "counts, but counts are: gradient \(v.base.count), "
209208
+ "lhs \(lhs.count), rhs \(rhs.count)")
210209
return (
211-
TangentVector(
212-
[Element.TangentVector](
213-
gradientIn.base[0..<lhs.count])),
214-
TangentVector(
215-
[Element.TangentVector](
216-
gradientIn.base[lhs.count...]))
210+
TangentVector([Element.TangentVector](v.base[0..<lhs.count])),
211+
TangentVector([Element.TangentVector](v.base[lhs.count...]))
217212
)
218213
}
219214
return (lhs + rhs, pullback)
@@ -288,6 +283,37 @@ extension Array where Element: Differentiable {
288283
// Differentiable higher order functions for collections
289284
//===----------------------------------------------------------------------===//
290285

286+
extension Array where Element: Differentiable {
287+
@inlinable
288+
@differentiable(wrt: self)
289+
public func differentiableMap<Result: Differentiable>(
290+
_ body: @differentiable (Element) -> Result
291+
) -> [Result] {
292+
map(body)
293+
}
294+
295+
@inlinable
296+
@derivative(of: differentiableMap)
297+
internal func _vjpDifferentiableMap<Result: Differentiable>(
298+
_ body: @differentiable (Element) -> Result
299+
) -> (
300+
value: [Result],
301+
pullback: (Array<Result>.TangentVector) -> Array.TangentVector
302+
) {
303+
var values: [Result] = []
304+
var pullbacks: [(Result.TangentVector) -> Element.TangentVector] = []
305+
for x in self {
306+
let (y, pb) = valueWithPullback(at: x, in: body)
307+
values.append(y)
308+
pullbacks.append(pb)
309+
}
310+
func pullback(_ tans: Array<Result>.TangentVector) -> Array.TangentVector {
311+
.init(zip(tans.base, pullbacks).map { tan, pb in pb(tan) })
312+
}
313+
return (value: values, pullback: pullback)
314+
}
315+
}
316+
291317
extension Array where Element: Differentiable {
292318
@inlinable
293319
@differentiable(wrt: (self, initialResult))
@@ -336,34 +362,3 @@ extension Array where Element: Differentiable {
336362
)
337363
}
338364
}
339-
340-
extension Array where Element: Differentiable {
341-
@inlinable
342-
@differentiable(wrt: self)
343-
public func differentiableMap<Result: Differentiable>(
344-
_ body: @differentiable (Element) -> Result
345-
) -> [Result] {
346-
map(body)
347-
}
348-
349-
@inlinable
350-
@derivative(of: differentiableMap)
351-
internal func _vjpDifferentiableMap<Result: Differentiable>(
352-
_ body: @differentiable (Element) -> Result
353-
) -> (
354-
value: [Result],
355-
pullback: (Array<Result>.TangentVector) -> Array.TangentVector
356-
) {
357-
var values: [Result] = []
358-
var pullbacks: [(Result.TangentVector) -> Element.TangentVector] = []
359-
for x in self {
360-
let (y, pb) = valueWithPullback(at: x, in: body)
361-
values.append(y)
362-
pullbacks.append(pb)
363-
}
364-
func pullback(_ tans: Array<Result>.TangentVector) -> Array.TangentVector {
365-
.init(zip(tans.base, pullbacks).map { tan, pb in pb(tan) })
366-
}
367-
return (value: values, pullback: pullback)
368-
}
369-
}

0 commit comments

Comments
 (0)