Skip to content

Commit ccfbc38

Browse files
JaapWijnenJaap Wijnen
andauthored
mark multiple autodiff related array methods as inlinable for increased specialization possibilities (#75778)
Co-authored-by: Jaap Wijnen <[email protected]>
1 parent cd79938 commit ccfbc38

File tree

1 file changed

+25
-15
lines changed

1 file changed

+25
-15
lines changed

stdlib/public/Differentiation/ArrayDifferentiation.swift

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,27 +21,29 @@ extension Array where Element: Differentiable {
2121
/// multiplied with itself `count` times.
2222
@frozen
2323
public struct DifferentiableView {
24+
@usableFromInline
2425
var _base: [Element]
2526
}
2627
}
2728

2829
extension Array.DifferentiableView: Differentiable
2930
where Element: Differentiable {
3031
/// The viewed array.
32+
@inlinable
3133
public var base: [Element] {
3234
get { _base }
3335
_modify { yield &_base }
3436
}
3537

36-
@usableFromInline
38+
@inlinable
3739
@derivative(of: base)
3840
func _vjpBase() -> (
3941
value: [Element], pullback: (Array<Element>.TangentVector) -> TangentVector
4042
) {
4143
return (base, { $0 })
4244
}
4345

44-
@usableFromInline
46+
@inlinable
4547
@derivative(of: base)
4648
func _jvpBase() -> (
4749
value: [Element], differential: (Array<Element>.TangentVector) -> TangentVector
@@ -50,17 +52,18 @@ where Element: Differentiable {
5052
}
5153

5254
/// Creates a differentiable view of the given array.
55+
@inlinable
5356
public init(_ base: [Element]) { self._base = base }
5457

55-
@usableFromInline
58+
@inlinable
5659
@derivative(of: init(_:))
5760
static func _vjpInit(_ base: [Element]) -> (
5861
value: Array.DifferentiableView, pullback: (TangentVector) -> TangentVector
5962
) {
6063
return (Array.DifferentiableView(base), { $0 })
6164
}
6265

63-
@usableFromInline
66+
@inlinable
6467
@derivative(of: init(_:))
6568
static func _jvpInit(_ base: [Element]) -> (
6669
value: Array.DifferentiableView, differential: (TangentVector) -> TangentVector
@@ -71,6 +74,7 @@ where Element: Differentiable {
7174
public typealias TangentVector =
7275
Array<Element.TangentVector>.DifferentiableView
7376

77+
@inlinable
7478
public mutating func move(by offset: TangentVector) {
7579
if offset.base.isEmpty {
7680
return
@@ -88,6 +92,7 @@ where Element: Differentiable {
8892

8993
extension Array.DifferentiableView: Equatable
9094
where Element: Differentiable & Equatable {
95+
@inlinable
9196
public static func == (
9297
lhs: Array.DifferentiableView,
9398
rhs: Array.DifferentiableView
@@ -98,6 +103,7 @@ where Element: Differentiable & Equatable {
98103

99104
extension Array.DifferentiableView: ExpressibleByArrayLiteral
100105
where Element: Differentiable {
106+
@inlinable
101107
public init(arrayLiteral elements: Element...) {
102108
self.init(elements)
103109
}
@@ -123,10 +129,12 @@ extension Array.DifferentiableView: CustomReflectable {
123129
extension Array.DifferentiableView: AdditiveArithmetic
124130
where Element: AdditiveArithmetic & Differentiable {
125131

132+
@inlinable
126133
public static var zero: Array.DifferentiableView {
127134
return Array.DifferentiableView([])
128135
}
129-
136+
137+
@inlinable
130138
public static func + (
131139
lhs: Array.DifferentiableView,
132140
rhs: Array.DifferentiableView
@@ -143,6 +151,7 @@ where Element: AdditiveArithmetic & Differentiable {
143151
return Array.DifferentiableView(zip(lhs.base, rhs.base).map(+))
144152
}
145153

154+
@inlinable
146155
public static func - (
147156
lhs: Array.DifferentiableView,
148157
rhs: Array.DifferentiableView
@@ -180,6 +189,7 @@ extension Array: Differentiable where Element: Differentiable {
180189
public typealias TangentVector =
181190
Array<Element.TangentVector>.DifferentiableView
182191

192+
@inlinable
183193
public mutating func move(by offset: TangentVector) {
184194
var view = DifferentiableView(self)
185195
view.move(by: offset)
@@ -192,7 +202,7 @@ extension Array: Differentiable where Element: Differentiable {
192202
//===----------------------------------------------------------------------===//
193203

194204
extension Array where Element: Differentiable {
195-
@usableFromInline
205+
@inlinable
196206
@derivative(of: subscript)
197207
func _vjpSubscript(index: Int) -> (
198208
value: Element, pullback: (Element.TangentVector) -> TangentVector
@@ -207,7 +217,7 @@ extension Array where Element: Differentiable {
207217
return (self[index], pullback)
208218
}
209219

210-
@usableFromInline
220+
@inlinable
211221
@derivative(of: subscript)
212222
func _jvpSubscript(index: Int) -> (
213223
value: Element, differential: (TangentVector) -> Element.TangentVector
@@ -218,7 +228,7 @@ extension Array where Element: Differentiable {
218228
return (self[index], differential)
219229
}
220230

221-
@usableFromInline
231+
@inlinable
222232
@derivative(of: +)
223233
static func _vjpConcatenate(_ lhs: Self, _ rhs: Self) -> (
224234
value: Self,
@@ -241,7 +251,7 @@ extension Array where Element: Differentiable {
241251
return (lhs + rhs, pullback)
242252
}
243253

244-
@usableFromInline
254+
@inlinable
245255
@derivative(of: +)
246256
static func _jvpConcatenate(_ lhs: Self, _ rhs: Self) -> (
247257
value: Self,
@@ -261,7 +271,7 @@ extension Array where Element: Differentiable {
261271

262272

263273
extension Array where Element: Differentiable {
264-
@usableFromInline
274+
@inlinable
265275
@derivative(of: append)
266276
mutating func _vjpAppend(_ element: Element) -> (
267277
value: Void, pullback: (inout TangentVector) -> Element.TangentVector
@@ -274,7 +284,7 @@ extension Array where Element: Differentiable {
274284
})
275285
}
276286

277-
@usableFromInline
287+
@inlinable
278288
@derivative(of: append)
279289
mutating func _jvpAppend(_ element: Element) -> (
280290
value: Void,
@@ -286,7 +296,7 @@ extension Array where Element: Differentiable {
286296
}
287297

288298
extension Array where Element: Differentiable {
289-
@usableFromInline
299+
@inlinable
290300
@derivative(of: +=)
291301
static func _vjpAppend(_ lhs: inout Self, _ rhs: Self) -> (
292302
value: Void, pullback: (inout TangentVector) -> TangentVector
@@ -302,7 +312,7 @@ extension Array where Element: Differentiable {
302312
})
303313
}
304314

305-
@usableFromInline
315+
@inlinable
306316
@derivative(of: +=)
307317
static func _jvpAppend(_ lhs: inout Self, _ rhs: Self) -> (
308318
value: Void, differential: (inout TangentVector, TangentVector) -> Void
@@ -313,7 +323,7 @@ extension Array where Element: Differentiable {
313323
}
314324

315325
extension Array where Element: Differentiable {
316-
@usableFromInline
326+
@inlinable
317327
@derivative(of: init(repeating:count:))
318328
static func _vjpInit(repeating repeatedValue: Element, count: Int) -> (
319329
value: Self, pullback: (TangentVector) -> Element.TangentVector
@@ -326,7 +336,7 @@ extension Array where Element: Differentiable {
326336
)
327337
}
328338

329-
@usableFromInline
339+
@inlinable
330340
@derivative(of: init(repeating:count:))
331341
static func _jvpInit(repeating repeatedValue: Element, count: Int) -> (
332342
value: Self, differential: (Element.TangentVector) -> TangentVector

0 commit comments

Comments
 (0)