Skip to content

Commit 9c34970

Browse files
eaplataniosrxwei
authored andcommitted
[AutoDiff] [stdlib] Made arrays differentiable (#23183)
1 parent 3ff86f2 commit 9c34970

File tree

3 files changed

+320
-4
lines changed

3 files changed

+320
-4
lines changed

stdlib/public/core/Array.swift

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,8 @@ extension Array: RandomAccessCollection, MutableCollection {
693693
/// bridged `NSArray` instance as its storage, in which case writing is
694694
/// O(*n*), where *n* is the length of the array.
695695
@inlinable
696+
// SWIFT_ENABLE_TENSORFLOW
697+
@differentiable(wrt: self, vjp: _vjpSubscript where Element : Differentiable)
696698
public subscript(index: Int) -> Element {
697699
get {
698700
// This call may be hoisted or eliminated by the optimizer. If
@@ -1301,6 +1303,8 @@ extension Array: RangeReplaceableCollection {
13011303
// operator in the same expression.
13021304
extension Array {
13031305
@inlinable
1306+
// SWIFT_ENABLE_TENSORFLOW
1307+
@differentiable(vjp: _vjpPlus where Element : Differentiable)
13041308
public static func + (lhs: Array, rhs: Array) -> Array {
13051309
var lhs = lhs
13061310
lhs.append(contentsOf: rhs)
@@ -1865,3 +1869,217 @@ internal struct _ArrayAnyHashableBox<Element: Hashable>
18651869
return true
18661870
}
18671871
}
1872+
1873+
// SWIFT_ENABLE_TENSORFLOW
1874+
extension Array where Element : Differentiable {
1875+
/// The view of an array as the differentiable product manifold of `Element`
1876+
/// multiplied with itself `count` times.
1877+
@_fixed_layout
1878+
public struct DifferentiableView : Differentiable & KeyPathIterable {
1879+
private var _base: [Element]
1880+
1881+
/// The viewed array.
1882+
// I'm implementing this as a computed property instead of directly
1883+
// exposing `_base` because the `@differentiable` annotation does not make
1884+
// the stored property actually differentiable. I think this is a bug.
1885+
// Maybe it's related to `@_fixed_layout`?
1886+
// TODO: Determine if that is a bug, and fix.
1887+
public var base: [Element] {
1888+
@differentiable(wrt: self, vjp: _vjpBase)
1889+
get { return _base }
1890+
_modify { yield &_base }
1891+
}
1892+
1893+
@usableFromInline
1894+
func _vjpBase() ->
1895+
([Element], (Array<Element>.CotangentVector) -> CotangentVector) {
1896+
return (base, { $0 })
1897+
}
1898+
1899+
/// Creates a differentiable view of the given array.
1900+
@differentiable(wrt: base, vjp: _vjpInit)
1901+
public init(_ base: [Element]) { self._base = base }
1902+
1903+
@usableFromInline
1904+
static func _vjpInit(_ base: [Element]) ->
1905+
(Array.DifferentiableView, (CotangentVector) -> CotangentVector) {
1906+
return (Array.DifferentiableView(base), { $0 })
1907+
}
1908+
1909+
// MARK: - Differentiable conformance.
1910+
1911+
public typealias TangentVector =
1912+
Array<Element.TangentVector>.DifferentiableView
1913+
public typealias CotangentVector =
1914+
Array<Element.CotangentVector>.DifferentiableView
1915+
public typealias AllDifferentiableVariables =
1916+
Array<Element.AllDifferentiableVariables>.DifferentiableView
1917+
1918+
public var allDifferentiableVariables: AllDifferentiableVariables {
1919+
get {
1920+
return AllDifferentiableVariables(
1921+
base.map { $0.allDifferentiableVariables })
1922+
}
1923+
set {
1924+
precondition(
1925+
base.count == newValue.base.count,
1926+
"cannot set Array.DifferentiableView.AllDifferentiableVariables " +
1927+
"with count \(base.count) to " +
1928+
"Array.DifferentiableView.AllDifferentiableVariables with " +
1929+
"different count \(newValue.base.count)")
1930+
for i in base.indices {
1931+
base[i].allDifferentiableVariables = newValue.base[i]
1932+
}
1933+
}
1934+
}
1935+
1936+
public func moved(along direction: TangentVector) -> DifferentiableView {
1937+
precondition(
1938+
base.count == direction.base.count,
1939+
"cannot move Array.DifferentiableView with count \(base.count) along " +
1940+
"direction with different count \(direction.base.count)")
1941+
return DifferentiableView(
1942+
zip(base, direction.base).map { $0.moved(along: $1) })
1943+
}
1944+
1945+
public func tangentVector(from cotangentVector: CotangentVector) ->
1946+
TangentVector {
1947+
precondition(
1948+
base.count == cotangentVector.base.count,
1949+
"cannot use Array.DifferentiableView with count \(base.count) to " +
1950+
"get tangentVector from cotangentVector with different count " +
1951+
"\(cotangentVector.base.count)")
1952+
return TangentVector(zip(base, cotangentVector.base).map {
1953+
(selfElement, cotangentVectorElement) in
1954+
selfElement.tangentVector(from: cotangentVectorElement)
1955+
})
1956+
}
1957+
}
1958+
}
1959+
1960+
extension Array.DifferentiableView : Equatable where Element : Equatable {
1961+
public static func == (
1962+
lhs: Array.DifferentiableView,
1963+
rhs: Array.DifferentiableView
1964+
) -> Bool {
1965+
return lhs.base == rhs.base
1966+
}
1967+
}
1968+
1969+
/// Makes `Array.DifferentiableView` additive as the product space.
1970+
///
1971+
/// Note that `Array.DifferentiableView([])` is the zero in the product spaces
1972+
/// of all counts.
1973+
extension Array.DifferentiableView : AdditiveArithmetic
1974+
where Element : AdditiveArithmetic {
1975+
1976+
public static var zero: Array.DifferentiableView {
1977+
return Array.DifferentiableView([])
1978+
}
1979+
1980+
public static func + (
1981+
lhs: Array.DifferentiableView,
1982+
rhs: Array.DifferentiableView
1983+
) -> Array.DifferentiableView {
1984+
precondition(
1985+
lhs.base.count == 0 || rhs.base.count == 0 ||
1986+
lhs.base.count == rhs.base.count,
1987+
"cannot add Array.DifferentiableViews with different counts: " +
1988+
"\(lhs.base.count) and \(rhs.base.count)")
1989+
if lhs.base.count == 0 {
1990+
return rhs
1991+
}
1992+
if rhs.base.count == 0 {
1993+
return lhs
1994+
}
1995+
return Array.DifferentiableView(zip(lhs.base, rhs.base).map(+))
1996+
}
1997+
1998+
public static func - (
1999+
lhs: Array.DifferentiableView,
2000+
rhs: Array.DifferentiableView
2001+
) -> Array.DifferentiableView {
2002+
precondition(
2003+
lhs.base.count == 0 || rhs.base.count == 0 ||
2004+
lhs.base.count == rhs.base.count,
2005+
"cannot subtract Array.DifferentiableViews with different counts: " +
2006+
"\(lhs.base.count) and \(rhs.base.count)")
2007+
if lhs.base.count == 0 {
2008+
return rhs
2009+
}
2010+
if rhs.base.count == 0 {
2011+
return lhs
2012+
}
2013+
return Array.DifferentiableView(zip(lhs.base, rhs.base).map(-))
2014+
}
2015+
}
2016+
2017+
/// Makes `Array` differentiable as the product manifold of `Element`
2018+
/// multiplied with itself `count` times.
2019+
extension Array : Differentiable where Element : Differentiable {
2020+
// In an ideal world, `TangentVector`, `CotangentVector`, and
2021+
// `AllDifferentiableVariables` would all be `Array`s. Unfortunately, we
2022+
// can't conform `Array` to `AdditiveArithmetic` for `TangentVector` and
2023+
// `CotangentVector`, because `Array` already has a static `+` method with
2024+
// different semantics from `AdditiveArithmetic` `+`. So we use
2025+
// `Array.DifferentiableView` for all these associated types.
2026+
public typealias TangentVector =
2027+
Array<Element.TangentVector>.DifferentiableView
2028+
public typealias CotangentVector =
2029+
Array<Element.CotangentVector>.DifferentiableView
2030+
public typealias AllDifferentiableVariables =
2031+
Array<Element.AllDifferentiableVariables>.DifferentiableView
2032+
2033+
public var allDifferentiableVariables: AllDifferentiableVariables {
2034+
get {
2035+
return DifferentiableView(self).allDifferentiableVariables
2036+
}
2037+
set {
2038+
var view = DifferentiableView(self)
2039+
view.allDifferentiableVariables = newValue
2040+
self = view.base
2041+
}
2042+
}
2043+
2044+
public func moved(along direction: TangentVector) -> Array {
2045+
return DifferentiableView(self).moved(along: direction).base
2046+
}
2047+
2048+
public func tangentVector(from cotangentVector: CotangentVector) ->
2049+
TangentVector {
2050+
return DifferentiableView(self).tangentVector(from: cotangentVector)
2051+
}
2052+
}
2053+
2054+
extension Array where Element : Differentiable {
2055+
public func _vjpSubscript(index: Int) ->
2056+
(Element, (Element.CotangentVector) -> CotangentVector)
2057+
{
2058+
func pullback(_ gradientIn: Element.CotangentVector) -> CotangentVector {
2059+
var gradientOut = Array<Element.CotangentVector>(
2060+
repeating: .zero,
2061+
count: count)
2062+
gradientOut[index] = gradientIn
2063+
return CotangentVector(gradientOut)
2064+
}
2065+
return (self[index], pullback)
2066+
}
2067+
2068+
public static func _vjpPlus(_ lhs: [Element], _ rhs: [Element]) ->
2069+
([Element], (CotangentVector) -> (CotangentVector, CotangentVector)) {
2070+
func pullback(_ gradientIn: CotangentVector) ->
2071+
(CotangentVector, CotangentVector) {
2072+
precondition(
2073+
gradientIn.base.count == lhs.count + rhs.count,
2074+
"+ should receive gradient with count equal to sum of operand " +
2075+
"counts, but counts are: gradient \(gradientIn.base.count), " +
2076+
"lhs \(lhs.count), rhs \(rhs.count)")
2077+
return (
2078+
CotangentVector(Array<Element.CotangentVector>(
2079+
gradientIn.base[0..<lhs.count])),
2080+
CotangentVector(Array<Element.CotangentVector>(
2081+
gradientIn.base[lhs.count...])))
2082+
}
2083+
return (lhs + rhs, pullback)
2084+
}
2085+
}

test/AutoDiff/array.swift

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
// RUN: %target-run-simple-swift
2+
3+
import StdlibUnittest
4+
5+
var ArrayAutodiffTests = TestSuite("ArrayAutodiff")
6+
7+
typealias FloatArrayGrad = Array<Float>.CotangentVector
8+
9+
ArrayAutodiffTests.test("ArrayIdentity") {
10+
func arrayIdentity(_ x: [Float]) -> [Float] {
11+
return x
12+
}
13+
14+
let backprop = pullback(at: [5, 6, 7, 8], in: arrayIdentity)
15+
expectEqual(
16+
FloatArrayGrad([1, 2, 3, 4]),
17+
backprop(FloatArrayGrad([1, 2, 3, 4])))
18+
}
19+
20+
ArrayAutodiffTests.test("ArraySubscript") {
21+
func sumFirstThree(_ array: [Float]) -> Float {
22+
return array[0] + array[1] + array[2]
23+
}
24+
25+
expectEqual(
26+
FloatArrayGrad([1, 1, 1, 0, 0, 0]),
27+
gradient(at: [2, 3, 4, 5, 6, 7], in: sumFirstThree))
28+
}
29+
30+
ArrayAutodiffTests.test("ArrayConcat") {
31+
struct TwoArrays : Differentiable {
32+
let a: [Float]
33+
let b: [Float]
34+
}
35+
36+
func sumFirstThreeConcatted(_ arrs: TwoArrays) -> Float {
37+
let c = arrs.a + arrs.b
38+
return c[0] + c[1] + c[2]
39+
}
40+
41+
expectEqual(
42+
TwoArrays.CotangentVector(
43+
a: FloatArrayGrad([1, 1]),
44+
b: FloatArrayGrad([1, 0])),
45+
gradient(
46+
at: TwoArrays(a: [0, 0], b: [0, 0]),
47+
in: sumFirstThreeConcatted))
48+
expectEqual(
49+
TwoArrays.CotangentVector(
50+
a: FloatArrayGrad([1, 1, 1, 0]),
51+
b: FloatArrayGrad([0, 0])),
52+
gradient(
53+
at: TwoArrays(a: [0, 0, 0, 0], b: [0, 0]),
54+
in: sumFirstThreeConcatted))
55+
expectEqual(
56+
TwoArrays.CotangentVector(
57+
a: FloatArrayGrad([]),
58+
b: FloatArrayGrad([1, 1, 1, 0])),
59+
gradient(
60+
at: TwoArrays(a: [], b: [0, 0, 0, 0]),
61+
in: sumFirstThreeConcatted))
62+
}
63+
64+
ArrayAutodiffTests.test("Array.DifferentiableView.init") {
65+
@differentiable
66+
func constructView(_ x: [Float]) -> Array<Float>.DifferentiableView {
67+
return Array<Float>.DifferentiableView(x)
68+
}
69+
70+
let backprop = pullback(at: [5, 6, 7, 8], in: constructView)
71+
expectEqual(
72+
FloatArrayGrad([1, 2, 3, 4]),
73+
backprop(FloatArrayGrad([1, 2, 3, 4])))
74+
}
75+
76+
ArrayAutodiffTests.test("Array.DifferentiableView.base") {
77+
@differentiable
78+
func accessBase(_ x: Array<Float>.DifferentiableView) -> [Float] {
79+
return x.base
80+
}
81+
82+
let backprop = pullback(
83+
at: Array<Float>.DifferentiableView([5, 6, 7, 8]),
84+
in: accessBase)
85+
expectEqual(
86+
FloatArrayGrad([1, 2, 3, 4]),
87+
backprop(FloatArrayGrad([1, 2, 3, 4])))
88+
}
89+
90+
ArrayAutodiffTests.test("Array.DifferentiableView : KeyPathIterable") {
91+
struct Container : KeyPathIterable {
92+
let a: Array<Float>.DifferentiableView
93+
}
94+
let container = Container(a: Array<Float>.DifferentiableView([1, 2, 3]))
95+
expectEqual(
96+
[1, 2, 3],
97+
container.recursivelyAllKeyPaths(to: Float.self).map {
98+
container[keyPath: $0]
99+
})
100+
}
101+
102+
runAllTests()

test/Sema/struct_differentiable.swift

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -243,10 +243,6 @@ struct GenericConstrained<T> {
243243
extension GenericConstrained : Differentiable
244244
where T : Differentiable {}
245245

246-
// TF-161: Test conditional conformance of `Array`.
247-
// expected-warning @+1 {{stored property '_buffer' has no derivative because it does not conform to 'Differentiable'; add '@noDerivative' to make it explicit}}
248-
extension Array : Differentiable where Element : Differentiable {}
249-
250246
struct TF_260<T : Differentiable> : Differentiable & AdditiveArithmetic {
251247
var x: T.CotangentVector
252248
}

0 commit comments

Comments
 (0)