Skip to content

Commit 75548c3

Browse files
authored
Merge pull request #30875 from apple/autodiff-upstream-stdlib-differentiation
[AutoDiff upstream] Add stdlib `Differentiable` conformances and derivatives.
2 parents 2257667 + 4495f6b commit 75548c3

12 files changed

+2138
-20
lines changed

lib/SILOptimizer/Utils/Differentiation/PullbackEmitter.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,9 +1132,15 @@ PullbackEmitter::getArrayAdjointElementBuffer(SILValue arrayAdjoint,
11321132
// Apply `Array.TangentVector.subscript.getter` to get array element adjoint
11331133
// buffer.
11341134
auto &ctx = builder.getASTContext();
1135-
// %index_literal = integer_literal $Builtin.Int64, <index>
1136-
auto *eltIndexLiteral = builder.createIntegerLiteral(
1137-
loc, SILType::getBuiltinIntegerType(64, ctx), eltIndex);
1135+
// %index_literal = integer_literal $Builtin.IntXX, <index>
1136+
auto builtinIntType =
1137+
SILType::getPrimitiveObjectType(ctx.getIntDecl()
1138+
->getStoredProperties()
1139+
.front()
1140+
->getInterfaceType()
1141+
->getCanonicalType());
1142+
auto *eltIndexLiteral =
1143+
builder.createIntegerLiteral(loc, builtinIntType, eltIndex);
11381144
auto intType = SILType::getPrimitiveObjectType(
11391145
ctx.getIntDecl()->getDeclaredType()->getCanonicalType());
11401146
// %index_int = struct $Int (%index_literal)
Lines changed: 332 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,332 @@
1+
//===--- ArrayDifferentiation.swift ---------------------------*- swift -*-===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
import Swift
14+
15+
//===----------------------------------------------------------------------===//
16+
// Protocol conformances
17+
//===----------------------------------------------------------------------===//
18+
19+
// TODO(TF-938): Add `Element: Differentiable` requirement.
20+
extension Array {
21+
/// The view of an array as the differentiable product manifold of `Element`
22+
/// multiplied with itself `count` times.
23+
@frozen
24+
public struct DifferentiableView {
25+
var _base: [Element]
26+
}
27+
}
28+
29+
extension Array.DifferentiableView: Differentiable
30+
where Element: Differentiable {
31+
/// The viewed array.
32+
public var base: [Element] {
33+
get { return _base }
34+
_modify { yield &_base }
35+
}
36+
37+
@usableFromInline
38+
@derivative(of: base)
39+
func _vjpBase() -> (
40+
value: [Element], pullback: (Array<Element>.TangentVector) -> TangentVector
41+
) {
42+
return (base, { $0 })
43+
}
44+
45+
/// Creates a differentiable view of the given array.
46+
public init(_ base: [Element]) { self._base = base }
47+
48+
@usableFromInline
49+
@derivative(of: init(_:))
50+
static func _vjpInit(_ base: [Element]) -> (
51+
value: Array.DifferentiableView, pullback: (TangentVector) -> TangentVector
52+
) {
53+
return (Array.DifferentiableView(base), { $0 })
54+
}
55+
56+
public typealias TangentVector =
57+
Array<Element.TangentVector>.DifferentiableView
58+
59+
public mutating func move(along direction: TangentVector) {
60+
precondition(
61+
base.count == direction.base.count,
62+
"cannot move Array.DifferentiableView with count \(base.count) along "
63+
+ "direction with different count \(direction.base.count)")
64+
for i in base.indices {
65+
base[i].move(along: direction.base[i])
66+
}
67+
}
68+
}
69+
70+
extension Array.DifferentiableView: Equatable
71+
where Element: Differentiable & Equatable {
72+
public static func == (
73+
lhs: Array.DifferentiableView,
74+
rhs: Array.DifferentiableView
75+
) -> Bool {
76+
return lhs.base == rhs.base
77+
}
78+
}
79+
80+
extension Array.DifferentiableView: ExpressibleByArrayLiteral
81+
where Element: Differentiable {
82+
public init(arrayLiteral elements: Element...) {
83+
self.init(elements)
84+
}
85+
}
86+
87+
extension Array.DifferentiableView: CustomStringConvertible
88+
where Element: Differentiable {
89+
public var description: String {
90+
return base.description
91+
}
92+
}
93+
94+
/// Makes `Array.DifferentiableView` additive as the product space.
95+
///
96+
/// Note that `Array.DifferentiableView([])` is the zero in the product spaces
97+
/// of all counts.
98+
extension Array.DifferentiableView: AdditiveArithmetic
99+
where Element: AdditiveArithmetic & Differentiable {
100+
101+
public static var zero: Array.DifferentiableView {
102+
return Array.DifferentiableView([])
103+
}
104+
105+
public static func + (
106+
lhs: Array.DifferentiableView,
107+
rhs: Array.DifferentiableView
108+
) -> Array.DifferentiableView {
109+
precondition(
110+
lhs.base.count == 0 || rhs.base.count == 0
111+
|| lhs.base.count == rhs.base.count,
112+
"cannot add Array.DifferentiableViews with different counts: "
113+
+ "\(lhs.base.count) and \(rhs.base.count)")
114+
if lhs.base.count == 0 {
115+
return rhs
116+
}
117+
if rhs.base.count == 0 {
118+
return lhs
119+
}
120+
return Array.DifferentiableView(zip(lhs.base, rhs.base).map(+))
121+
}
122+
123+
public static func - (
124+
lhs: Array.DifferentiableView,
125+
rhs: Array.DifferentiableView
126+
) -> Array.DifferentiableView {
127+
precondition(
128+
lhs.base.count == 0 || rhs.base.count == 0
129+
|| lhs.base.count == rhs.base.count,
130+
"cannot subtract Array.DifferentiableViews with different counts: "
131+
+ "\(lhs.base.count) and \(rhs.base.count)")
132+
if lhs.base.count == 0 {
133+
return rhs
134+
}
135+
if rhs.base.count == 0 {
136+
return lhs
137+
}
138+
return Array.DifferentiableView(zip(lhs.base, rhs.base).map(-))
139+
}
140+
141+
@inlinable
142+
public subscript(_ index: Int) -> Element {
143+
if index < base.count {
144+
return base[index]
145+
} else {
146+
return Element.zero
147+
}
148+
}
149+
}
150+
151+
/// Makes `Array` differentiable as the product manifold of `Element`
152+
/// multiplied with itself `count` times.
153+
extension Array: Differentiable where Element: Differentiable {
154+
// In an ideal world, `TangentVector` would be `[Element.TangentVector]`.
155+
// Unfortunately, we cannot conform `Array` to `AdditiveArithmetic` for
156+
// `TangentVector` because `Array` already has a static `+` method with
157+
// different semantics from `AdditiveArithmetic.+`. So we use
158+
// `Array.DifferentiableView` for all these associated types.
159+
public typealias TangentVector =
160+
Array<Element.TangentVector>.DifferentiableView
161+
162+
public mutating func move(along direction: TangentVector) {
163+
var view = DifferentiableView(self)
164+
view.move(along: direction)
165+
self = view.base
166+
}
167+
}
168+
169+
//===----------------------------------------------------------------------===//
170+
// Derivatives
171+
//===----------------------------------------------------------------------===//
172+
173+
extension Array where Element: Differentiable {
174+
@usableFromInline
175+
@derivative(of: subscript)
176+
func _vjpSubscript(index: Int) -> (
177+
value: Element, pullback: (Element.TangentVector) -> TangentVector
178+
) {
179+
func pullback(_ gradientIn: Element.TangentVector) -> TangentVector {
180+
var gradientOut = [Element.TangentVector](
181+
repeating: .zero,
182+
count: count)
183+
gradientOut[index] = gradientIn
184+
return TangentVector(gradientOut)
185+
}
186+
return (self[index], pullback)
187+
}
188+
189+
@usableFromInline
190+
@derivative(of: +)
191+
static func _vjpConcatenate(_ lhs: [Element], _ rhs: [Element]) -> (
192+
value: [Element],
193+
pullback: (TangentVector) -> (TangentVector, TangentVector)
194+
) {
195+
func pullback(_ gradientIn: TangentVector) -> (TangentVector, TangentVector)
196+
{
197+
precondition(
198+
gradientIn.base.count == lhs.count + rhs.count,
199+
"+ should receive gradient with count equal to sum of operand "
200+
+ "counts, but counts are: gradient \(gradientIn.base.count), "
201+
+ "lhs \(lhs.count), rhs \(rhs.count)")
202+
return (
203+
TangentVector(
204+
[Element.TangentVector](
205+
gradientIn.base[0..<lhs.count])),
206+
TangentVector(
207+
[Element.TangentVector](
208+
gradientIn.base[lhs.count...]))
209+
)
210+
}
211+
return (lhs + rhs, pullback)
212+
}
213+
}
214+
215+
extension Array where Element: Differentiable {
216+
@usableFromInline
217+
@derivative(of: append)
218+
mutating func _vjpAppend(_ element: Element) -> (
219+
value: Void, pullback: (inout TangentVector) -> Element.TangentVector
220+
) {
221+
let appendedElementIndex = count
222+
defer { append(element) }
223+
return ((), { dself in dself.base[appendedElementIndex] })
224+
}
225+
226+
@usableFromInline
227+
@derivative(of: append)
228+
mutating func _jvpAppend(_ element: Element) -> (
229+
value: Void,
230+
differential: (inout TangentVector, Element.TangentVector) -> Void
231+
) {
232+
append(element)
233+
return ((), { $0.base.append($1) })
234+
}
235+
}
236+
237+
extension Array where Element: Differentiable {
238+
@usableFromInline
239+
@derivative(of: init(repeating:count:))
240+
static func _vjpInit(repeating repeatedValue: Element, count: Int) -> (
241+
value: Self, pullback: (TangentVector) -> Element.TangentVector
242+
) {
243+
(
244+
value: Self(repeating: repeatedValue, count: count),
245+
pullback: { v in
246+
v.base.reduce(.zero, +)
247+
}
248+
)
249+
}
250+
}
251+
252+
//===----------------------------------------------------------------------===//
253+
// Differentiable higher order functions for collections
254+
//===----------------------------------------------------------------------===//
255+
256+
extension Array where Element: Differentiable {
257+
@differentiable(wrt: (self, initialResult))
258+
public func differentiableReduce<Result: Differentiable>(
259+
_ initialResult: Result,
260+
_ nextPartialResult: @differentiable (Result, Element) -> Result
261+
) -> Result {
262+
reduce(initialResult, nextPartialResult)
263+
}
264+
265+
@usableFromInline
266+
@derivative(of: differentiableReduce)
267+
internal func _vjpDifferentiableReduce<Result: Differentiable>(
268+
_ initialResult: Result,
269+
_ nextPartialResult: @differentiable (Result, Element) -> Result
270+
) -> (
271+
value: Result,
272+
pullback: (Result.TangentVector)
273+
-> (Array.TangentVector, Result.TangentVector)
274+
) {
275+
var pullbacks:
276+
[(Result.TangentVector) -> (Result.TangentVector, Element.TangentVector)] =
277+
[]
278+
let count = self.count
279+
pullbacks.reserveCapacity(count)
280+
var result = initialResult
281+
for element in self {
282+
let (y, pb) =
283+
valueWithPullback(at: result, element, in: nextPartialResult)
284+
result = y
285+
pullbacks.append(pb)
286+
}
287+
return (
288+
value: result,
289+
pullback: { tangent in
290+
var resultTangent = tangent
291+
var elementTangents = TangentVector([])
292+
elementTangents.base.reserveCapacity(count)
293+
for pullback in pullbacks.reversed() {
294+
let (newResultTangent, elementTangent) = pullback(resultTangent)
295+
resultTangent = newResultTangent
296+
elementTangents.base.append(elementTangent)
297+
}
298+
return (TangentVector(elementTangents.base.reversed()), resultTangent)
299+
}
300+
)
301+
}
302+
}
303+
304+
extension Array where Element: Differentiable {
305+
@differentiable(wrt: self)
306+
public func differentiableMap<Result: Differentiable>(
307+
_ body: @differentiable (Element) -> Result
308+
) -> [Result] {
309+
map(body)
310+
}
311+
312+
@usableFromInline
313+
@derivative(of: differentiableMap)
314+
internal func _vjpDifferentiableMap<Result: Differentiable>(
315+
_ body: @differentiable (Element) -> Result
316+
) -> (
317+
value: [Result],
318+
pullback: (Array<Result>.TangentVector) -> Array.TangentVector
319+
) {
320+
var values: [Result] = []
321+
var pullbacks: [(Result.TangentVector) -> Element.TangentVector] = []
322+
for x in self {
323+
let (y, pb) = valueWithPullback(at: x, in: body)
324+
values.append(y)
325+
pullbacks.append(pb)
326+
}
327+
func pullback(_ tans: Array<Result>.TangentVector) -> Array.TangentVector {
328+
.init(zip(tans.base, pullbacks).map { tan, pb in pb(tan) })
329+
}
330+
return (value: values, pullback: pullback)
331+
}
332+
}

stdlib/public/Differentiation/CMakeLists.txt

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#
33
# This source file is part of the Swift.org open source project
44
#
5-
# Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors
5+
# Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
66
# Licensed under Apache License v2.0 with Runtime Library Exception
77
#
88
# See https://swift.org/LICENSE.txt for license information
@@ -14,6 +14,22 @@ add_swift_target_library(swift_Differentiation ${SWIFT_STDLIB_LIBRARY_BUILD_TYPE
1414
Differentiable.swift
1515
DifferentialOperators.swift
1616
DifferentiationUtilities.swift
17+
ArrayDifferentiation.swift
18+
19+
GYB_SOURCES
20+
FloatingPointDifferentiation.swift.gyb
21+
TgmathDerivatives.swift.gyb
22+
SIMDDifferentiation.swift.gyb
23+
24+
SWIFT_MODULE_DEPENDS_OSX Darwin
25+
SWIFT_MODULE_DEPENDS_IOS Darwin
26+
SWIFT_MODULE_DEPENDS_TVOS Darwin
27+
SWIFT_MODULE_DEPENDS_WATCHOS Darwin
28+
SWIFT_MODULE_DEPENDS_LINUX Glibc
29+
SWIFT_MODULE_DEPENDS_FREEBSD Glibc
30+
SWIFT_MODULE_DEPENDS_CYGWIN Glibc
31+
SWIFT_MODULE_DEPENDS_HAIKU Glibc
32+
SWIFT_MODULE_DEPENDS_WINDOWS MSVCRT
1733

1834
SWIFT_COMPILE_FLAGS
1935
${SWIFT_STANDARD_LIBRARY_SWIFT_FLAGS}

0 commit comments

Comments
 (0)