This repository was archived by the owner on Jul 1, 2023. It is now read-only.

[AutoDiff] Use @derivative for derivative registration. #591

Merged: 1 commit, merged on Dec 22, 2019
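For orientation before the per-file hunks: every change below follows the same pattern, replacing the old vjp: argument of @differentiable with a standalone @derivative(of:) attribute whose result uses a labeled (value:, pullback:) tuple. The sketch below is not taken from this diff; it uses a hypothetical square function and assumes a Swift for TensorFlow toolchain that accepts @derivative and exposes gradient(at:in:).

```swift
import TensorFlow

// Old style (removed by this PR), shown only for contrast:
//
//   @differentiable(vjp: _vjpSquare)
//   func square(_ x: Tensor<Float>) -> Tensor<Float> { x * x }
//
//   func _vjpSquare(_ x: Tensor<Float>)
//     -> (Tensor<Float>, (Tensor<Float>) -> Tensor<Float>) { ... }

// New style: the original function only declares differentiability; the derivative
// registers itself against the original with @derivative(of:) and returns a tuple
// with labeled `value` and `pullback` elements.
@differentiable
func square(_ x: Tensor<Float>) -> Tensor<Float> {
  x * x
}

@derivative(of: square)
func _vjpSquare(
  _ x: Tensor<Float>
) -> (value: Tensor<Float>, pullback: (Tensor<Float>) -> Tensor<Float>) {
  (square(x), { v in v * 2 * x })
}

// The registered derivative is what the differentiation APIs pick up:
let dydx = gradient(at: Tensor<Float>(3)) { square($0) }  // expected: 6.0
```

The hunks below apply this same mechanical rewrite to Tensor members, initializers, and operators, constrained to TensorFlowFloatingPoint where needed.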
36 changes: 24 additions & 12 deletions Sources/TensorFlow/Core/Tensor.swift
@@ -125,7 +125,7 @@ public extension Tensor {
/// Reshape to scalar.
/// - Precondition: The tensor has exactly one scalar.
@inlinable
@differentiable(wrt: self, vjp: _vjpScalarized where Scalar: TensorFlowFloatingPoint)
@differentiable(where Scalar: TensorFlowFloatingPoint)
func scalarized() -> Scalar {
precondition(shape.contiguousSize == 1,
"This tensor must have exactly one scalar but contains \(shape.contiguousSize).")
@@ -135,7 +135,8 @@ public extension Tensor {

internal extension Tensor where Scalar: TensorFlowFloatingPoint {
@inlinable
func _vjpScalarized() -> (Scalar, (Scalar) -> Tensor) {
@derivative(of: scalarized)
func _vjpScalarized() -> (value: Scalar, pullback: (Scalar) -> Tensor) {
return (scalarized(), { v in Tensor(v) })
}
}
@@ -162,14 +163,15 @@ public extension Tensor {
}

@inlinable
@differentiable(vjp: _vjpScalars where Scalar: TensorFlowFloatingPoint)
@differentiable(where Scalar: TensorFlowFloatingPoint)
var scalars: [Scalar] {
return array.scalars
}
}

extension Tensor where Scalar: TensorFlowFloatingPoint {
@inlinable
@derivative(of: scalars)
func _vjpScalars() -> (value: [Scalar], pullback: (Array<Scalar>.TangentVector) -> Tensor) {
(value: scalars, pullback: { [shape = self.shape, device = self.device] v in
Tensor(shape: shape, scalars: v.base, on: device)
@@ -184,24 +186,26 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
public extension Tensor {
/// Creates a 0-D tensor from a scalar value.
@inlinable
@differentiable(vjp: _vjpScalarInit where Scalar: TensorFlowFloatingPoint)
@differentiable(where Scalar: TensorFlowFloatingPoint)
init(_ value: Scalar, on device: Device = Device.getDefault) {
self.init(shape: [], scalars: [value], on: device)
}
}

internal extension Tensor where Scalar: TensorFlowFloatingPoint {
@inlinable
static func _vjpScalarInit(_ value: __owned Scalar, on device: Device = Device.getDefault
) -> (Tensor, (Tensor) -> Scalar) {
@derivative(of: init(_:on:))
static func _vjpScalarInit(_ value: __owned Scalar, on device: Device = Device.getDefault) -> (
value: Tensor, pullback: (Tensor) -> Scalar
) {
return (Tensor(value, on: device), { $0.scalarized() })
}
}

Inline review thread on this hunk:

Reviewer: @dan-zheng Perhaps it would make sense to rename _vjpXXX functions to _derivativeXXX?

Member (Author): I don't feel strongly! The best name is func _ (a private anonymous function), which is used in the differentiable programming manifesto but not yet implemented.

When linear functions and transposition are done, users will register only primitive JVP functions returning a @differentiable(linear) differential function, and primitive VJP functions will be removed. At that point, @derivative functions will unambiguously be JVP functions, so derivativeXXX names become appropriate and vjpXXX names become obsolete.
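To make the VJP/JVP distinction in the reply above concrete, here is a hand-written sketch (nothing is registered, and the names are illustrative) of the two derivative shapes for f(x) = x * x. The pullback form is what this PR registers today; the differential form corresponds to the future JVP-only direction the author describes, which depends on @differentiable(linear) and transposition that are not yet implemented.

```swift
import TensorFlow

// VJP (reverse mode), the form this PR registers everywhere:
// the pullback maps an output cotangent v back to an input cotangent v * f'(x).
func squareVJP(_ x: Tensor<Float>)
  -> (value: Tensor<Float>, pullback: (Tensor<Float>) -> Tensor<Float>) {
  (x * x, { v in v * 2 * x })
}

// JVP (forward mode), the future-facing form mentioned in the reply:
// the differential maps an input tangent dx forward to an output tangent f'(x) * dx.
func squareJVP(_ x: Tensor<Float>)
  -> (value: Tensor<Float>, differential: (Tensor<Float>) -> Tensor<Float>) {
  (x * x, { dx in 2 * x * dx })
}
```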

public extension Tensor {
/// Creates a 1D tensor from scalars.
@inlinable
@differentiable(vjp: _vjpInit(_:on:) where Scalar: TensorFlowFloatingPoint)
@differentiable(where Scalar: TensorFlowFloatingPoint)
init(_ scalars: [Scalar], on device: Device = Device.getDefault) {
self.init(shape: [scalars.count], scalars: scalars, on: device)
}
@@ -230,7 +234,7 @@ public extension Tensor {
/// - scalars: The scalar contents of the tensor.
/// - Precondition: The product of the dimensions of the shape must equal the number of scalars.
@inlinable
@differentiable(vjp: _vjpInit(shape:scalars:on:) where Scalar: TensorFlowFloatingPoint)
@differentiable(where Scalar: TensorFlowFloatingPoint)
init(shape: TensorShape, scalars: [Scalar], on device: Device = Device.getDefault) {
precondition(shape.contiguousSize == scalars.count,
"""
@@ -297,6 +301,7 @@ public extension Tensor {

extension Tensor where Scalar: TensorFlowFloatingPoint {
@inlinable
@derivative(of: init(_:on:))
static func _vjpInit(_ scalars: [Scalar], on device: Device = Device.getDefault) -> (
value: Tensor, pullback: (Tensor) -> Array<Scalar>.TangentVector
) {
@@ -306,6 +311,7 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
}

@inlinable
@derivative(of: init(shape:scalars:on:))
static func _vjpInit(
shape: TensorShape, scalars: [Scalar], on device: Device = Device.getDefault
) -> (value: Tensor, pullback: (Tensor) -> Array<Scalar>.TangentVector) {
@@ -542,23 +548,26 @@ extension Tensor: AdditiveArithmetic where Scalar: Numeric {
/// Adds two tensors and produces their sum.
/// - Note: `+` supports broadcasting.
@inlinable
@differentiable(vjp: _vjpAdd(lhs:rhs:) where Scalar: TensorFlowFloatingPoint)
@differentiable(where Scalar: TensorFlowFloatingPoint)
public static func + (lhs: Tensor, rhs: Tensor) -> Tensor {
_Raw.addV2(lhs, rhs)
}

/// Subtracts one tensor from another and produces their difference.
/// - Note: `-` supports broadcasting.
@inlinable
@differentiable(vjp: _vjpSubtract(lhs:rhs:) where Scalar: TensorFlowFloatingPoint)
@differentiable(where Scalar: TensorFlowFloatingPoint)
public static func - (lhs: Tensor, rhs: Tensor) -> Tensor {
_Raw.sub(lhs, rhs)
}
}

internal extension Tensor where Scalar: TensorFlowFloatingPoint {
@inlinable
static func _vjpAdd(lhs: Tensor, rhs: Tensor) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
@derivative(of: +)
static func _vjpAdd(lhs: Tensor, rhs: Tensor) -> (
value: Tensor, pullback: (Tensor) -> (Tensor, Tensor)
) {
(lhs + rhs, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in
let lhsGrad = v
let rhsGrad = lhsGrad
@@ -569,7 +578,10 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
}

@inlinable
static func _vjpSubtract(lhs: Tensor, rhs: Tensor) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
@derivative(of: -)
static func _vjpSubtract(lhs: Tensor, rhs: Tensor) -> (
value: Tensor, pullback: (Tensor) -> (Tensor, Tensor)
) {
(lhs - rhs, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in
let lhsGrad = v
let rhsGrad = -lhsGrad
3 changes: 2 additions & 1 deletion Sources/TensorFlow/Freezable.swift
@@ -30,13 +30,14 @@ public struct _Freezable<Value: Differentiable> {
}

/// The wrapped differentiable value.
@differentiable(vjp: _vjpValue)
@differentiable
public var wrappedValue: Value {
get { _value }
set { _value = newValue }
}

@usableFromInline
@derivative(of: wrappedValue)
func _vjpValue() -> (value: Value, pullback: (Value.TangentVector) -> TangentVector) {
return (_value, { [isFrozen = self.isFrozen] v in
isFrozen ? .zero : v
26 changes: 15 additions & 11 deletions Sources/TensorFlow/Initializers.swift
@@ -34,7 +34,7 @@ public extension Tensor {
/// - repeatedValue: The scalar value to repeat.
/// - shape: The dimensions of the tensor.
@inlinable
@differentiable(vjp: _vjpInit(repeating:shape:) where Scalar: TensorFlowFloatingPoint)
@differentiable(where Scalar: TensorFlowFloatingPoint)
init(repeating repeatedValue: Scalar, shape: TensorShape) {
self = _Raw.fill(
dims: Tensor<Int32>(shape.dimensions.map(Int32.init)),
@@ -60,10 +60,11 @@ public extension Tensor {

internal extension Tensor where Scalar: TensorFlowFloatingPoint {
@inlinable
@derivative(of: init(repeating:shape:))
static func _vjpInit(
repeating repeatedValue: __owned Scalar,
shape: __owned TensorShape
) -> (Tensor, (Tensor) -> Scalar) {
) -> (value: Tensor, pullback: (Tensor) -> Scalar) {
return (Tensor(repeating: repeatedValue, shape: shape), {
$0.sum().scalarized()
})
@@ -83,18 +84,18 @@ public extension Tensor where Scalar: Numeric {

/// Perform an element-wise conversion from another `Tensor`.
@inlinable
@differentiable(
vjp: _vjpCast where Scalar: TensorFlowFloatingPoint, OtherScalar: TensorFlowFloatingPoint)
@differentiable(where Scalar: TensorFlowFloatingPoint, OtherScalar: TensorFlowFloatingPoint)
init<OtherScalar: Numeric>(_ other: Tensor<OtherScalar>) {
self = _Raw.cast(other)
}
}

internal extension Tensor where Scalar: TensorFlowFloatingPoint {
@inlinable
@derivative(of: init(_:))
static func _vjpCast<OtherScalar: TensorFlowFloatingPoint>(
_ other: __owned Tensor<OtherScalar>
) -> (Tensor, (Tensor) -> Tensor<OtherScalar>) {
) -> (value: Tensor, pullback: (Tensor) -> Tensor<OtherScalar>) {
(Tensor(other), { v in Tensor<OtherScalar>(v) })
}
}
@@ -106,7 +107,7 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
public extension Tensor {
/// Creates a tensor from an array of tensors (which may themselves be scalars).
@inlinable
@differentiable(vjp: _vjpInitElements where Scalar: TensorFlowFloatingPoint)
@differentiable(where Scalar: TensorFlowFloatingPoint)
init(_ elements: [Tensor]) {
self = _Raw.pack(elements)
}
@@ -140,7 +141,7 @@ public extension Tensor {
///
/// - Returns: The stacked tensor.
@inlinable
@differentiable(vjp: _vjpStacking where Scalar: TensorFlowFloatingPoint)
@differentiable(where Scalar: TensorFlowFloatingPoint)
init(stacking tensors: [Tensor], alongAxis axis: Int = 0) {
self = _Raw.pack(tensors, axis: Int64(axis))
}
@@ -178,7 +179,7 @@ public extension Tensor {
///
/// - Returns: The concatenated tensor.
@inlinable
@differentiable(vjp: _vjpConcatenating where Scalar: TensorFlowFloatingPoint)
@differentiable(where Scalar: TensorFlowFloatingPoint)
init(concatenating tensors: [Tensor], alongAxis axis: Int = 0) {
precondition(tensors.count > 0)
self = _Raw.concatV2(tensors, axis: Tensor<Int32>(Int32(axis)))
@@ -187,27 +188,30 @@ public extension Tensor {

internal extension Tensor where Scalar: TensorFlowFloatingPoint {
@inlinable
@derivative(of: init(_:))
static func _vjpInitElements(
_ elements: __owned [Tensor]
) -> (Tensor, (Tensor) -> Array<Tensor>.DifferentiableView) {
) -> (value: Tensor, pullback: (Tensor) -> Array<Tensor>.DifferentiableView) {
_vjpStacking(stacking: elements)
}

@inlinable
@derivative(of: init(stacking:alongAxis:))
static func _vjpStacking(
stacking tensors: __owned [Tensor],
alongAxis axis: __owned Int = 0
) -> (Tensor, (Tensor) -> Array<Tensor>.DifferentiableView) {
) -> (value: Tensor, pullback: (Tensor) -> Array<Tensor>.DifferentiableView) {
(Tensor(stacking: tensors, alongAxis: axis), { v in
Array<Tensor>.DifferentiableView(v.unstacked(alongAxis: axis))
})
}

@inlinable
@derivative(of: init(concatenating:alongAxis:))
static func _vjpConcatenating(
concatenating tensors: __owned [Tensor],
alongAxis axis: __owned Int = 0
) -> (Tensor, (Tensor) -> Array<Tensor>.DifferentiableView) {
) -> (value: Tensor, pullback: (Tensor) -> Array<Tensor>.DifferentiableView) {
let result = Tensor<Scalar>(concatenating: tensors, alongAxis: axis)
let posAxis = axis < 0 ? axis + tensors[0].rank : axis
let sizes = Tensor<Int32>(stacking: tensors.map { $0.shapeTensor[posAxis] })
11 changes: 7 additions & 4 deletions Sources/TensorFlow/Layers/Recurrent.swift
@@ -344,7 +344,7 @@ public struct RNN<Cell: RNNCell>: Layer {
self.cell = cell()
}

@differentiable(wrt: (self, inputs), vjp: _vjpCallAsFunction(_:initialState:))
@differentiable(wrt: (self, inputs))
public func callAsFunction(
_ inputs: [Cell.TimeStepInput],
initialState: Cell.State
@@ -369,12 +369,15 @@ public struct RNN<Cell: RNNCell>: Layer {
}

@usableFromInline
@derivative(of: callAsFunction, wrt: (self, inputs))
internal func _vjpCallAsFunction(
_ inputs: [Cell.TimeStepInput],
initialState: Cell.State
) -> ([Cell.TimeStepOutput],
(Array<Cell.TimeStepOutput>.TangentVector)
-> (TangentVector, Array<Cell.TimeStepInput>.TangentVector)) {
) -> (
value: [Cell.TimeStepOutput],
pullback: (Array<Cell.TimeStepOutput>.TangentVector)
-> (TangentVector, Array<Cell.TimeStepInput>.TangentVector)
) {
let timeStepCount = inputs.count
var currentHiddenState = cell.zeroState(for: inputs[0])
var timeStepOutputs: [Cell.TimeStepOutput] = []
5 changes: 3 additions & 2 deletions Sources/TensorFlow/Layers/Upsampling.swift
@@ -79,7 +79,7 @@ public struct UpSampling3D<Scalar: TensorFlowFloatingPoint>: ParameterlessLayer
/// Repeats the elements of a tensor along an axis, like `np.repeat`.
/// Function adapted from `def repeat_elements`:
/// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/backend.py
@differentiable(vjp: _vjpRepeatingElements)
@differentiable
private func repeatingElements(
_ input: Tensor<Scalar>, alongAxis axis: Int, count: Int
) -> Tensor<Scalar> {
@@ -91,9 +91,10 @@ public struct UpSampling3D<Scalar: TensorFlowFloatingPoint>: ParameterlessLayer
return Tensor<Scalar>(concatenating: repeated, alongAxis: axis)
}

@derivative(of: repeatingElements)
private func _vjpRepeatingElements(
_ input: Tensor<Scalar>, alongAxis axis: Int, count: Int
) -> (Tensor<Scalar>, (Tensor<Scalar>) -> (TangentVector, Tensor<Scalar>)) {
) -> (value: Tensor<Scalar>, pullback: (Tensor<Scalar>) -> (TangentVector, Tensor<Scalar>)) {
let value = repeatingElements(input, alongAxis: axis, count: count)
return (value, { v in
let splits = _Raw.split(
10 changes: 6 additions & 4 deletions Sources/TensorFlow/Loss.swift
@@ -210,7 +210,7 @@ public func softmaxCrossEntropy<Scalar: TensorFlowFloatingPoint>(
}

@inlinable
@differentiable(wrt: logits, vjp: _vjpSoftmaxCrossEntropyHelper(logits:labels:))
@differentiable(wrt: logits)
func softmaxCrossEntropyHelper<Scalar: TensorFlowFloatingPoint>(
logits: Tensor<Scalar>,
labels: Tensor<Int32>
@@ -219,10 +219,11 @@ func softmaxCrossEntropyHelper<Scalar: TensorFlowFloatingPoint>(
}

@inlinable
@derivative(of: softmaxCrossEntropyHelper(logits:labels:))
func _vjpSoftmaxCrossEntropyHelper<Scalar: TensorFlowFloatingPoint>(
logits: Tensor<Scalar>,
labels: Tensor<Int32>
) -> (Tensor<Scalar>, (Tensor<Scalar>) -> Tensor<Scalar>) {
) -> (value: Tensor<Scalar>, pullback: (Tensor<Scalar>) -> Tensor<Scalar>) {
let (loss, grad) = _Raw.sparseSoftmaxCrossEntropyWithLogits(features: logits, labels: labels)
return (loss, { $0.expandingShape(at: -1) * grad })
}
@@ -244,7 +245,7 @@ public func softmaxCrossEntropy<Scalar: TensorFlowFloatingPoint>(
}

@inlinable
@differentiable(wrt: logits, vjp: _vjpSoftmaxCrossEntropyHelper(logits:probabilities:))
@differentiable(wrt: logits)
func softmaxCrossEntropyHelper<Scalar: TensorFlowFloatingPoint>(
logits: Tensor<Scalar>,
probabilities: Tensor<Scalar>
@@ -253,10 +254,11 @@ func softmaxCrossEntropyHelper<Scalar: TensorFlowFloatingPoint>(
}

@inlinable
@derivative(of: softmaxCrossEntropyHelper(logits:probabilities:), wrt: logits)
func _vjpSoftmaxCrossEntropyHelper<Scalar: TensorFlowFloatingPoint>(
logits: Tensor<Scalar>,
probabilities: Tensor<Scalar>
) -> (Tensor<Scalar>, (Tensor<Scalar>) -> Tensor<Scalar>) {
) -> (value: Tensor<Scalar>, pullback: (Tensor<Scalar>) -> Tensor<Scalar>) {
let (loss, grad) = _Raw.softmaxCrossEntropyWithLogits(features: logits, labels: probabilities)
return (loss, { $0.expandingShape(at: -1) * grad })
}
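One detail worth noting in the Loss.swift hunks: the first helper's derivative needs no wrt: clause because labels is a Tensor<Int32> and cannot be a differentiability parameter, while the second helper says wrt: logits because probabilities is itself differentiable. Below is a minimal standalone sketch of that wrt-subset pattern, with hypothetical names and the same toolchain assumptions as the earlier example.

```swift
import TensorFlow

// Sketch (hypothetical names, not from this diff): registering a derivative with
// respect to a subset of parameters. `scale` participates in the computation but
// receives no gradient, mirroring how `probabilities` is handled above.
@differentiable(wrt: x)
func scaledSquare(_ x: Tensor<Float>, scale: Tensor<Float>) -> Tensor<Float> {
  x * x * scale
}

@derivative(of: scaledSquare, wrt: x)
func _vjpScaledSquare(
  _ x: Tensor<Float>, scale: Tensor<Float>
) -> (value: Tensor<Float>, pullback: (Tensor<Float>) -> Tensor<Float>) {
  (scaledSquare(x, scale: scale), { v in v * 2 * x * scale })
}

// Only `x` gets a gradient; the pullback's result type reflects that.
let dx = gradient(at: Tensor<Float>([1, 2, 3])) {
  scaledSquare($0, scale: Tensor<Float>(0.5)).sum()
}
// expected: [1.0, 2.0, 3.0], i.e. 2 * x * scale
```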