Commit 01c94b5
[AutoDiff] Use @derivative for derivative registration. (#591)
Rewrite all derivative registration from `@differentiable(jvp:vjp:)` to `@derivative`. Keep the original `@differentiable` attributes so that derivative functions remain publicly exposed. When retroactive derivative registration is complete:

- `@differentiable(jvp:vjp:)` will be deprecated.
- The `@derivative` attribute will be the canonical way to register derivatives.

Resolves TF-1076.
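For readers unfamiliar with the two registration styles, here is a minimal sketch of the pattern this commit applies throughout. The `square` and `_vjpSquare` names below are hypothetical illustrations, not code from this repository, and the sketch assumes a Swift toolchain with differentiable programming enabled (`import _Differentiation` on recent toolchains; implicit on Swift for TensorFlow toolchains of this era):

// Before: the VJP was named inline in the attribute.
//   @differentiable(vjp: _vjpSquare)
//   func square(_ x: Float) -> Float { x * x }

// After: `@differentiable` only declares that the function is differentiable...
@differentiable
func square(_ x: Float) -> Float { x * x }

// ...and the derivative is registered by a separate `@derivative(of:)`
// function, whose return type must use the labeled `(value:, pullback:)` form.
@derivative(of: square)
func _vjpSquare(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
  (value: x * x, pullback: { v in 2 * x * v })  // d/dx of x*x is 2x
}

Call sites are unchanged either way: `gradient(at: 3, in: square)` (spelled `gradient(at:of:)` on newer toolchains) should still yield `6.0` through the registered pullback.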
Parent commit: 4bab884

12 files changed: +433 additions, −308 deletions

Sources/TensorFlow/Core/Tensor.swift

Lines changed: 24 additions & 12 deletions
@@ -125,7 +125,7 @@ public extension Tensor {
   /// Reshape to scalar.
   /// - Precondition: The tensor has exactly one scalar.
   @inlinable
-  @differentiable(wrt: self, vjp: _vjpScalarized where Scalar: TensorFlowFloatingPoint)
+  @differentiable(where Scalar: TensorFlowFloatingPoint)
   func scalarized() -> Scalar {
     precondition(shape.contiguousSize == 1,
       "This tensor must have exactly one scalar but contains \(shape.contiguousSize).")
@@ -135,7 +135,8 @@ public extension Tensor {
 
 internal extension Tensor where Scalar: TensorFlowFloatingPoint {
   @inlinable
-  func _vjpScalarized() -> (Scalar, (Scalar) -> Tensor) {
+  @derivative(of: scalarized)
+  func _vjpScalarized() -> (value: Scalar, pullback: (Scalar) -> Tensor) {
     return (scalarized(), { v in Tensor(v) })
   }
 }
@@ -162,14 +163,15 @@ public extension Tensor {
   }
 
   @inlinable
-  @differentiable(vjp: _vjpScalars where Scalar: TensorFlowFloatingPoint)
+  @differentiable(where Scalar: TensorFlowFloatingPoint)
   var scalars: [Scalar] {
     return array.scalars
   }
 }
 
 extension Tensor where Scalar: TensorFlowFloatingPoint {
   @inlinable
+  @derivative(of: scalars)
   func _vjpScalars() -> (value: [Scalar], pullback: (Array<Scalar>.TangentVector) -> Tensor) {
     (value: scalars, pullback: { [shape = self.shape, device = self.device] v in
       Tensor(shape: shape, scalars: v.base, on: device)
@@ -184,24 +186,26 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
 public extension Tensor {
   /// Creates a 0-D tensor from a scalar value.
   @inlinable
-  @differentiable(vjp: _vjpScalarInit where Scalar: TensorFlowFloatingPoint)
+  @differentiable(where Scalar: TensorFlowFloatingPoint)
   init(_ value: Scalar, on device: Device = Device.getDefault) {
     self.init(shape: [], scalars: [value], on: device)
   }
 }
 
 internal extension Tensor where Scalar: TensorFlowFloatingPoint {
   @inlinable
-  static func _vjpScalarInit(_ value: __owned Scalar, on device: Device = Device.getDefault
-  ) -> (Tensor, (Tensor) -> Scalar) {
+  @derivative(of: init(_:on:))
+  static func _vjpScalarInit(_ value: __owned Scalar, on device: Device = Device.getDefault) -> (
+    value: Tensor, pullback: (Tensor) -> Scalar
+  ) {
     return (Tensor(value, on: device), { $0.scalarized() })
   }
 }
 
 public extension Tensor {
   /// Creates a 1D tensor from scalars.
   @inlinable
-  @differentiable(vjp: _vjpInit(_:on:) where Scalar: TensorFlowFloatingPoint)
+  @differentiable(where Scalar: TensorFlowFloatingPoint)
   init(_ scalars: [Scalar], on device: Device = Device.getDefault) {
     self.init(shape: [scalars.count], scalars: scalars, on: device)
   }
@@ -230,7 +234,7 @@ public extension Tensor {
   ///   - scalars: The scalar contents of the tensor.
   /// - Precondition: The product of the dimensions of the shape must equal the number of scalars.
   @inlinable
-  @differentiable(vjp: _vjpInit(shape:scalars:on:) where Scalar: TensorFlowFloatingPoint)
+  @differentiable(where Scalar: TensorFlowFloatingPoint)
   init(shape: TensorShape, scalars: [Scalar], on device: Device = Device.getDefault) {
     precondition(shape.contiguousSize == scalars.count,
       """
@@ -297,6 +301,7 @@ public extension Tensor {
 
 extension Tensor where Scalar: TensorFlowFloatingPoint {
   @inlinable
+  @derivative(of: init(_:on:))
   static func _vjpInit(_ scalars: [Scalar], on device: Device = Device.getDefault) -> (
     value: Tensor, pullback: (Tensor) -> Array<Scalar>.TangentVector
   ) {
@@ -306,6 +311,7 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
   }
 
   @inlinable
+  @derivative(of: init(shape:scalars:on:))
   static func _vjpInit(
     shape: TensorShape, scalars: [Scalar], on device: Device = Device.getDefault
   ) -> (value: Tensor, pullback: (Tensor) -> Array<Scalar>.TangentVector) {
@@ -542,23 +548,26 @@ extension Tensor: AdditiveArithmetic where Scalar: Numeric {
   /// Adds two tensors and produces their sum.
   /// - Note: `+` supports broadcasting.
   @inlinable
-  @differentiable(vjp: _vjpAdd(lhs:rhs:) where Scalar: TensorFlowFloatingPoint)
+  @differentiable(where Scalar: TensorFlowFloatingPoint)
   public static func + (lhs: Tensor, rhs: Tensor) -> Tensor {
     _Raw.addV2(lhs, rhs)
   }
 
   /// Subtracts one tensor from another and produces their difference.
   /// - Note: `-` supports broadcasting.
   @inlinable
-  @differentiable(vjp: _vjpSubtract(lhs:rhs:) where Scalar: TensorFlowFloatingPoint)
+  @differentiable(where Scalar: TensorFlowFloatingPoint)
   public static func - (lhs: Tensor, rhs: Tensor) -> Tensor {
     _Raw.sub(lhs, rhs)
   }
 }
 
 internal extension Tensor where Scalar: TensorFlowFloatingPoint {
   @inlinable
-  static func _vjpAdd(lhs: Tensor, rhs: Tensor) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
+  @derivative(of: +)
+  static func _vjpAdd(lhs: Tensor, rhs: Tensor) -> (
+    value: Tensor, pullback: (Tensor) -> (Tensor, Tensor)
+  ) {
     (lhs + rhs, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in
       let lhsGrad = v
       let rhsGrad = lhsGrad
@@ -569,7 +578,10 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
   }
 
   @inlinable
-  static func _vjpSubtract(lhs: Tensor, rhs: Tensor) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
+  @derivative(of: -)
+  static func _vjpSubtract(lhs: Tensor, rhs: Tensor) -> (
+    value: Tensor, pullback: (Tensor) -> (Tensor, Tensor)
+  ) {
     (lhs - rhs, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in
       let lhsGrad = v
       let rhsGrad = -lhsGrad

Sources/TensorFlow/Freezable.swift

Lines changed: 2 additions & 1 deletion
@@ -30,13 +30,14 @@ public struct _Freezable<Value: Differentiable> {
   }
 
   /// The wrapped differentiable value.
-  @differentiable(vjp: _vjpValue)
+  @differentiable
   public var wrappedValue: Value {
     get { _value }
     set { _value = newValue }
   }
 
   @usableFromInline
+  @derivative(of: wrappedValue)
   func _vjpValue() -> (value: Value, pullback: (Value.TangentVector) -> TangentVector) {
     return (_value, { [isFrozen = self.isFrozen] v in
       isFrozen ? .zero : v

Sources/TensorFlow/Initializers.swift

Lines changed: 15 additions & 11 deletions
@@ -34,7 +34,7 @@ public extension Tensor {
   ///   - repeatedValue: The scalar value to repeat.
   ///   - shape: The dimensions of the tensor.
   @inlinable
-  @differentiable(vjp: _vjpInit(repeating:shape:) where Scalar: TensorFlowFloatingPoint)
+  @differentiable(where Scalar: TensorFlowFloatingPoint)
   init(repeating repeatedValue: Scalar, shape: TensorShape) {
     self = _Raw.fill(
       dims: Tensor<Int32>(shape.dimensions.map(Int32.init)),
@@ -60,10 +60,11 @@
 
 internal extension Tensor where Scalar: TensorFlowFloatingPoint {
   @inlinable
+  @derivative(of: init(repeating:shape:))
   static func _vjpInit(
     repeating repeatedValue: __owned Scalar,
     shape: __owned TensorShape
-  ) -> (Tensor, (Tensor) -> Scalar) {
+  ) -> (value: Tensor, pullback: (Tensor) -> Scalar) {
     return (Tensor(repeating: repeatedValue, shape: shape), {
       $0.sum().scalarized()
     })
@@ -83,18 +84,18 @@ public extension Tensor where Scalar: Numeric {
 
   /// Perform an element-wise conversion from another `Tensor`.
   @inlinable
-  @differentiable(
-    vjp: _vjpCast where Scalar: TensorFlowFloatingPoint, OtherScalar: TensorFlowFloatingPoint)
+  @differentiable(where Scalar: TensorFlowFloatingPoint, OtherScalar: TensorFlowFloatingPoint)
   init<OtherScalar: Numeric>(_ other: Tensor<OtherScalar>) {
     self = _Raw.cast(other)
   }
 }
 
 internal extension Tensor where Scalar: TensorFlowFloatingPoint {
   @inlinable
+  @derivative(of: init(_:))
   static func _vjpCast<OtherScalar: TensorFlowFloatingPoint>(
     _ other: __owned Tensor<OtherScalar>
-  ) -> (Tensor, (Tensor) -> Tensor<OtherScalar>) {
+  ) -> (value: Tensor, pullback: (Tensor) -> Tensor<OtherScalar>) {
     (Tensor(other), { v in Tensor<OtherScalar>(v) })
   }
 }
@@ -106,7 +107,7 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
 public extension Tensor {
   /// Creates a tensor from an array of tensors (which may themselves be scalars).
   @inlinable
-  @differentiable(vjp: _vjpInitElements where Scalar: TensorFlowFloatingPoint)
+  @differentiable(where Scalar: TensorFlowFloatingPoint)
   init(_ elements: [Tensor]) {
     self = _Raw.pack(elements)
   }
@@ -140,7 +141,7 @@
   ///
   /// - Returns: The stacked tensor.
   @inlinable
-  @differentiable(vjp: _vjpStacking where Scalar: TensorFlowFloatingPoint)
+  @differentiable(where Scalar: TensorFlowFloatingPoint)
   init(stacking tensors: [Tensor], alongAxis axis: Int = 0) {
     self = _Raw.pack(tensors, axis: Int64(axis))
   }
@@ -178,7 +179,7 @@
   ///
   /// - Returns: The concatenated tensor.
   @inlinable
-  @differentiable(vjp: _vjpConcatenating where Scalar: TensorFlowFloatingPoint)
+  @differentiable(where Scalar: TensorFlowFloatingPoint)
   init(concatenating tensors: [Tensor], alongAxis axis: Int = 0) {
     precondition(tensors.count > 0)
     self = _Raw.concatV2(tensors, axis: Tensor<Int32>(Int32(axis)))
@@ -187,27 +188,30 @@ public extension Tensor {
 
 internal extension Tensor where Scalar: TensorFlowFloatingPoint {
   @inlinable
+  @derivative(of: init(_:))
   static func _vjpInitElements(
     _ elements: __owned [Tensor]
-  ) -> (Tensor, (Tensor) -> Array<Tensor>.DifferentiableView) {
+  ) -> (value: Tensor, pullback: (Tensor) -> Array<Tensor>.DifferentiableView) {
     _vjpStacking(stacking: elements)
   }
 
   @inlinable
+  @derivative(of: init(stacking:alongAxis:))
   static func _vjpStacking(
     stacking tensors: __owned [Tensor],
     alongAxis axis: __owned Int = 0
-  ) -> (Tensor, (Tensor) -> Array<Tensor>.DifferentiableView) {
+  ) -> (value: Tensor, pullback: (Tensor) -> Array<Tensor>.DifferentiableView) {
     (Tensor(stacking: tensors, alongAxis: axis), { v in
       Array<Tensor>.DifferentiableView(v.unstacked(alongAxis: axis))
     })
   }
 
   @inlinable
+  @derivative(of: init(concatenating:alongAxis:))
   static func _vjpConcatenating(
     concatenating tensors: __owned [Tensor],
     alongAxis axis: __owned Int = 0
-  ) -> (Tensor, (Tensor) -> Array<Tensor>.DifferentiableView) {
+  ) -> (value: Tensor, pullback: (Tensor) -> Array<Tensor>.DifferentiableView) {
     let result = Tensor<Scalar>(concatenating: tensors, alongAxis: axis)
     let posAxis = axis < 0 ? axis + tensors[0].rank : axis
     let sizes = Tensor<Int32>(stacking: tensors.map { $0.shapeTensor[posAxis] })

Sources/TensorFlow/Layers/Recurrent.swift

Lines changed: 7 additions & 4 deletions
@@ -344,7 +344,7 @@ public struct RNN<Cell: RNNCell>: Layer {
     self.cell = cell()
   }
 
-  @differentiable(wrt: (self, inputs), vjp: _vjpCallAsFunction(_:initialState:))
+  @differentiable(wrt: (self, inputs))
   public func callAsFunction(
     _ inputs: [Cell.TimeStepInput],
     initialState: Cell.State
@@ -369,12 +369,15 @@
   }
 
   @usableFromInline
+  @derivative(of: callAsFunction, wrt: (self, inputs))
   internal func _vjpCallAsFunction(
     _ inputs: [Cell.TimeStepInput],
     initialState: Cell.State
-  ) -> ([Cell.TimeStepOutput],
-        (Array<Cell.TimeStepOutput>.TangentVector)
-          -> (TangentVector, Array<Cell.TimeStepInput>.TangentVector)) {
+  ) -> (
+    value: [Cell.TimeStepOutput],
+    pullback: (Array<Cell.TimeStepOutput>.TangentVector)
+      -> (TangentVector, Array<Cell.TimeStepInput>.TangentVector)
+  ) {
     let timeStepCount = inputs.count
     var currentHiddenState = cell.zeroState(for: inputs[0])
     var timeStepOutputs: [Cell.TimeStepOutput] = []
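The `wrt: (self, inputs)` clause above carries over unchanged from `@differentiable` to `@derivative`. As a hedged sketch of that method-derivative pattern (the `Scale` type below is a hypothetical illustration, not code from this commit):

struct Scale: Differentiable {
  var factor: Float

  @differentiable(wrt: (self, x))
  func callAsFunction(_ x: Float) -> Float { factor * x }

  // The pullback returns one tangent per `wrt:` parameter: first for `self`
  // (via the synthesized `TangentVector`), then for `x`.
  @derivative(of: callAsFunction, wrt: (self, x))
  func _vjpCallAsFunction(_ x: Float) -> (
    value: Float, pullback: (Float) -> (TangentVector, Float)
  ) {
    (value: factor * x, pullback: { v in (TangentVector(factor: v * x), v * self.factor) })
  }
}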

Sources/TensorFlow/Layers/Upsampling.swift

Lines changed: 3 additions & 2 deletions
@@ -79,7 +79,7 @@ public struct UpSampling3D<Scalar: TensorFlowFloatingPoint>: ParameterlessLayer
   /// Repeats the elements of a tensor along an axis, like `np.repeat`.
   /// Function adapted from `def repeat_elements`:
   /// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/backend.py
-  @differentiable(vjp: _vjpRepeatingElements)
+  @differentiable
   private func repeatingElements(
     _ input: Tensor<Scalar>, alongAxis axis: Int, count: Int
   ) -> Tensor<Scalar> {
@@ -91,9 +91,10 @@
     return Tensor<Scalar>(concatenating: repeated, alongAxis: axis)
   }
 
+  @derivative(of: repeatingElements)
   private func _vjpRepeatingElements(
     _ input: Tensor<Scalar>, alongAxis axis: Int, count: Int
-  ) -> (Tensor<Scalar>, (Tensor<Scalar>) -> (TangentVector, Tensor<Scalar>)) {
+  ) -> (value: Tensor<Scalar>, pullback: (Tensor<Scalar>) -> (TangentVector, Tensor<Scalar>)) {
     let value = repeatingElements(input, alongAxis: axis, count: count)
     return (value, { v in
       let splits = _Raw.split(

Sources/TensorFlow/Loss.swift

Lines changed: 6 additions & 4 deletions
@@ -210,7 +210,7 @@ public func softmaxCrossEntropy<Scalar: TensorFlowFloatingPoint>(
 }
 
 @inlinable
-@differentiable(wrt: logits, vjp: _vjpSoftmaxCrossEntropyHelper(logits:labels:))
+@differentiable(wrt: logits)
 func softmaxCrossEntropyHelper<Scalar: TensorFlowFloatingPoint>(
   logits: Tensor<Scalar>,
   labels: Tensor<Int32>
@@ -219,10 +219,11 @@ func softmaxCrossEntropyHelper<Scalar: TensorFlowFloatingPoint>(
 }
 
 @inlinable
+@derivative(of: softmaxCrossEntropyHelper(logits:labels:))
 func _vjpSoftmaxCrossEntropyHelper<Scalar: TensorFlowFloatingPoint>(
   logits: Tensor<Scalar>,
   labels: Tensor<Int32>
-) -> (Tensor<Scalar>, (Tensor<Scalar>) -> Tensor<Scalar>) {
+) -> (value: Tensor<Scalar>, pullback: (Tensor<Scalar>) -> Tensor<Scalar>) {
   let (loss, grad) = _Raw.sparseSoftmaxCrossEntropyWithLogits(features: logits, labels: labels)
   return (loss, { $0.expandingShape(at: -1) * grad })
 }
@@ -244,7 +245,7 @@ public func softmaxCrossEntropy<Scalar: TensorFlowFloatingPoint>(
 }
 
 @inlinable
-@differentiable(wrt: logits, vjp: _vjpSoftmaxCrossEntropyHelper(logits:probabilities:))
+@differentiable(wrt: logits)
 func softmaxCrossEntropyHelper<Scalar: TensorFlowFloatingPoint>(
   logits: Tensor<Scalar>,
   probabilities: Tensor<Scalar>
@@ -253,10 +254,11 @@ func softmaxCrossEntropyHelper<Scalar: TensorFlowFloatingPoint>(
 }
 
 @inlinable
+@derivative(of: softmaxCrossEntropyHelper(logits:probabilities:), wrt: logits)
 func _vjpSoftmaxCrossEntropyHelper<Scalar: TensorFlowFloatingPoint>(
   logits: Tensor<Scalar>,
   probabilities: Tensor<Scalar>
-) -> (Tensor<Scalar>, (Tensor<Scalar>) -> Tensor<Scalar>) {
+) -> (value: Tensor<Scalar>, pullback: (Tensor<Scalar>) -> Tensor<Scalar>) {
   let (loss, grad) = _Raw.softmaxCrossEntropyWithLogits(features: logits, labels: probabilities)
   return (loss, { $0.expandingShape(at: -1) * grad })
 }
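Note how the last `@derivative` spells out the full name `softmaxCrossEntropyHelper(logits:probabilities:)` plus `wrt: logits`: the base name alone would be ambiguous between the two helper overloads, and only `logits` is a differentiability parameter. A minimal sketch of the same idiom (the `scale` function below is hypothetical, not from this commit):

@differentiable(wrt: x)
func scale(_ x: Float, by n: Int) -> Float { x * Float(n) }

// `of:` names the exact overload; `wrt:` restricts differentiation to `x`,
// since `Int` is not a differentiable type.
@derivative(of: scale(_:by:), wrt: x)
func _vjpScale(_ x: Float, by n: Int) -> (value: Float, pullback: (Float) -> Float) {
  (value: x * Float(n), pullback: { v in v * Float(n) })
}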
