Skip to content

Commit 22a6926

Browse files
authored
---
yaml --- r: 340981 b: refs/heads/rxwei-patch-1 c: 73984ad h: refs/heads/master i: 340979: b881a59
1 parent 7d0e9e0 commit 22a6926

File tree

4 files changed

+116
-1
lines changed

4 files changed

+116
-1
lines changed

[refs]

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1015,7 +1015,7 @@ refs/tags/swift-DEVELOPMENT-SNAPSHOT-2018-08-18-a: b10b1fce14385faa6d44f6b933e95
10151015
refs/heads/rdar-43033749-fix-batch-mode-no-diags-swift-5.0-branch: a14e64eaad30de89f0f5f0b2a782eed7ecdcb255
10161016
refs/heads/revert-19006-error-bridging-integer-type: 8a9065a3696535305ea53fe9b71f91cbe6702019
10171017
refs/heads/revert-19050-revert-19006-error-bridging-integer-type: ecf752d54b05dd0a20f510f0bfa54a3fec3bcaca
1018-
refs/heads/rxwei-patch-1: 2ce6f9c40acc4713189fa0a4999ac04ee6f786f8
1018+
refs/heads/rxwei-patch-1: 73984ad36b531fdecebe9d51644e6cdb34afc82c
10191019
refs/heads/shahmishal-patch-1: e58ec0f7488258d42bef51bc3e6d7b3dc74d7b2a
10201020
refs/heads/typelist-existential: 4046359efd541fb5c72d69a92eefc0a784df8f5e
10211021
refs/tags/swift-4.2-DEVELOPMENT-SNAPSHOT-2018-08-20-a: 4319ba09e4fb8650ee86061075c74a016b6baab9

branches/rxwei-patch-1/stdlib/public/core/Array.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1994,6 +1994,12 @@ extension Array.DifferentiableView : Equatable where Element : Equatable {
19941994
}
19951995
}
19961996

1997+
extension Array.DifferentiableView : ExpressibleByArrayLiteral {
1998+
public init(arrayLiteral elements: Element...) {
1999+
self.init(elements)
2000+
}
2001+
}
2002+
19972003
extension Array.DifferentiableView : CustomStringConvertible {
19982004
public var description: String {
19992005
return base.description

branches/rxwei-patch-1/stdlib/public/core/AutoDiff.swift

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -885,3 +885,76 @@ public struct AnyDerivative : Differentiable & AdditiveArithmetic {
885885
_box._move(along: direction._box)
886886
}
887887
}
888+
889+
//===----------------------------------------------------------------------===//
890+
// Differentiable higher order functions for collections
891+
//===----------------------------------------------------------------------===//
892+
893+
public extension Array where Element: Differentiable {
894+
@differentiable(wrt: (self, initialResult), vjp: _vjpDifferentiableReduce)
895+
func differentiableReduce<Result: Differentiable>(
896+
_ initialResult: Result,
897+
_ nextPartialResult: @differentiable (Result, Element) -> Result
898+
) -> Result {
899+
reduce(initialResult, nextPartialResult)
900+
}
901+
902+
@usableFromInline
903+
internal func _vjpDifferentiableReduce<Result: Differentiable>(
904+
_ initialResult: Result,
905+
_ nextPartialResult: @differentiable (Result, Element) -> Result
906+
) -> (value: Result,
907+
pullback: (Result.TangentVector)
908+
-> (Array.TangentVector, Result.TangentVector)) {
909+
var pullbacks:
910+
[(Result.TangentVector) -> (Result.TangentVector, Element.TangentVector)]
911+
= []
912+
let count = self.count
913+
pullbacks.reserveCapacity(count)
914+
var result = initialResult
915+
for element in self {
916+
let (y, pb) =
917+
Swift.valueWithPullback(at: result, element, in: nextPartialResult)
918+
result = y
919+
pullbacks.append(pb)
920+
}
921+
return (value: result, pullback: { tangent in
922+
var resultTangent = tangent
923+
var elementTangents = TangentVector([])
924+
elementTangents.base.reserveCapacity(count)
925+
for pullback in pullbacks.reversed() {
926+
let (newResultTangent, elementTangent) = pullback(resultTangent)
927+
resultTangent = newResultTangent
928+
elementTangents.base.append(elementTangent)
929+
}
930+
return (TangentVector(elementTangents.base.reversed()), resultTangent)
931+
})
932+
}
933+
}
934+
935+
public extension Array where Element: Differentiable {
936+
@differentiable(wrt: self, vjp: _vjpDifferentiableMap)
937+
func differentiableMap<Result: Differentiable>(
938+
_ body: @differentiable (Element) -> Result
939+
) -> [Result] {
940+
map(body)
941+
}
942+
943+
@usableFromInline
944+
internal func _vjpDifferentiableMap<Result: Differentiable>(
945+
_ body: @differentiable (Element) -> Result
946+
) -> (value: [Result],
947+
pullback: (Array<Result>.TangentVector) -> Array.TangentVector) {
948+
var values: [Result] = []
949+
var pullbacks: [(Result.TangentVector) -> Element.TangentVector] = []
950+
for x in self {
951+
let (y, pb) = Swift.valueWithPullback(at: x, in: body)
952+
values.append(y)
953+
pullbacks.append(pb)
954+
}
955+
func pullback(_ tans: Array<Result>.TangentVector) -> Array.TangentVector {
956+
.init(zip(tans.base, pullbacks).map { tan, pb in pb(tan) })
957+
}
958+
return (value: values, pullback: pullback)
959+
}
960+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// RUN: %target-run-simple-swift
2+
// REQUIRES: executable_test
3+
4+
import StdlibUnittest
5+
#if os(macOS)
6+
import Darwin.C
7+
#else
8+
import Glibc
9+
#endif
10+
11+
// Test suite for differentiable higher order functions for collections
12+
// such as `differentiableMap(_:)` and `differentiableReduce(_:)`.
13+
var CollectionHOFTests = TestSuite("CollectionHOF")
14+
15+
let xx: [Float] = [1, 2, 3, 4, 5]
16+
17+
CollectionHOFTests.test("differentiableMap(_:)") {
18+
func double(_ xx: [Float]) -> [Float] {
19+
xx.differentiableMap { $0 * $0 }
20+
}
21+
expectEqual([], pullback(at: xx, in: double)([]))
22+
expectEqual([0], pullback(at: xx, in: double)([0]))
23+
expectEqual([2], pullback(at: xx, in: double)([1]))
24+
expectEqual([2, 4, 6, 8, 10], pullback(at: xx, in: double)([1, 1, 1, 1, 1]))
25+
}
26+
27+
CollectionHOFTests.test("differentiableReduce(_:)") {
28+
func product(_ xx: [Float]) -> Float {
29+
xx.differentiableReduce(1) { $0 * $1 }
30+
}
31+
expectEqual([1], gradient(at: [0], in: product))
32+
expectEqual([1], gradient(at: [1], in: product))
33+
expectEqual([120, 60, 40, 30, 24], gradient(at: xx, in: product))
34+
}
35+
36+
runAllTests()

0 commit comments

Comments
 (0)