Skip to content

[AutoDiff] switch array and floating point derivatives to @derivative(of:) #29030

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 9 additions & 12 deletions stdlib/public/core/Array.swift
Original file line number Diff line number Diff line change
Expand Up @@ -693,8 +693,6 @@ extension Array: RandomAccessCollection, MutableCollection {
/// bridged `NSArray` instance as its storage, in which case writing is
/// O(*n*), where *n* is the length of the array.
@inlinable
// SWIFT_ENABLE_TENSORFLOW
@differentiable(wrt: self, vjp: _vjpSubscript where Element : Differentiable)
public subscript(index: Int) -> Element {
get {
// This call may be hoisted or eliminated by the optimizer. If
Expand Down Expand Up @@ -873,8 +871,6 @@ extension Array: RangeReplaceableCollection {
/// `repeating` parameter. `count` must be zero or greater.
@inlinable
@_semantics("array.init")
@differentiable(wrt: repeatedValue, vjp: _vjpInit(repeating:count:)
where Element: Differentiable)
public init(repeating repeatedValue: Element, count: Int) {
var p: UnsafeMutablePointer<Element>
(self, p) = Array._allocateUninitialized(count)
Expand Down Expand Up @@ -1330,8 +1326,6 @@ extension Array: RangeReplaceableCollection {
// operator in the same expression.
extension Array {
@inlinable
// SWIFT_ENABLE_TENSORFLOW
@differentiable(vjp: _vjpPlus where Element : Differentiable)
public static func + (lhs: Array, rhs: Array) -> Array {
var lhs = lhs
lhs.append(contentsOf: rhs)
Expand Down Expand Up @@ -1930,24 +1924,24 @@ extension Array {
extension Array.DifferentiableView : Differentiable where Element : Differentiable {
/// The viewed array.
public var base: [Element] {
@differentiable(wrt: self, vjp: _vjpBase)
get { return _base }
_modify { yield &_base }
}

@usableFromInline
@derivative(of: base)
func _vjpBase() ->
([Element], (Array<Element>.TangentVector) -> TangentVector) {
(value: [Element], pullback: (Array<Element>.TangentVector) -> TangentVector) {
return (base, { $0 })
}

/// Creates a differentiable view of the given array.
@differentiable(wrt: base, vjp: _vjpInit)
public init(_ base: [Element]) { self._base = base }

@usableFromInline
@derivative(of: init(_:))
static func _vjpInit(_ base: [Element]) ->
(Array.DifferentiableView, (TangentVector) -> TangentVector) {
(value: Array.DifferentiableView, pullback: (TangentVector) -> TangentVector) {
return (Array.DifferentiableView(base), { $0 })
}

Expand Down Expand Up @@ -2088,8 +2082,9 @@ extension Array : EuclideanDifferentiable
}

extension Array where Element : Differentiable {
@derivative(of: subscript)
public func _vjpSubscript(index: Int) ->
(Element, (Element.TangentVector) -> TangentVector)
(value: Element, pullback: (Element.TangentVector) -> TangentVector)
{
func pullback(_ gradientIn: Element.TangentVector) -> TangentVector {
var gradientOut = Array<Element.TangentVector>(
Expand All @@ -2101,8 +2096,9 @@ extension Array where Element : Differentiable {
return (self[index], pullback)
}

@derivative(of: +)
public static func _vjpPlus(_ lhs: [Element], _ rhs: [Element]) ->
([Element], (TangentVector) -> (TangentVector, TangentVector)) {
(value: [Element], pullback: (TangentVector) -> (TangentVector, TangentVector)) {
func pullback(_ gradientIn: TangentVector) ->
(TangentVector, TangentVector) {
precondition(
Expand All @@ -2122,6 +2118,7 @@ extension Array where Element : Differentiable {

extension Array where Element: Differentiable {
@usableFromInline
@derivative(of: init(repeating:count:))
static func _vjpInit(repeating repeatedValue: Element, count: Int) -> (
value: Self, pullback: (TangentVector) -> Element.TangentVector
) {
Expand Down
12 changes: 4 additions & 8 deletions stdlib/public/core/FloatingPoint.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1680,19 +1680,13 @@ extension FloatingPoint {
}

@_transparent
// SWIFT_ENABLE_TENSORFLOW
@differentiable(wrt: self, vjp: _vjpSquareRoot
where Self : Differentiable, Self == Self.TangentVector)
public func squareRoot( ) -> Self {
var lhs = self
lhs.formSquareRoot( )
return lhs
}

@_transparent
/// SWIFT_ENABLE_TENSORFLOW
@differentiable(wrt: (self, lhs, rhs), vjp: _vjpAddingProduct
where Self : Differentiable, Self == Self.TangentVector)
public func addingProduct(_ lhs: Self, _ rhs: Self) -> Self {
var addend = self
addend.addProduct(lhs, rhs)
Expand Down Expand Up @@ -1741,16 +1735,18 @@ extension FloatingPoint where Self : Differentiable,
/// original result and pullback of `addingProduct` with respect to `self`,
/// `lhs` and `rhs`.
@inlinable
@derivative(of: addingProduct)
func _vjpAddingProduct(
_ lhs: Self, _ rhs: Self
) -> (Self, (Self) -> (Self, Self, Self)) {
) -> (value: Self, pullback: (Self) -> (Self, Self, Self)) {
return (addingProduct(lhs, rhs), { _ in (1, rhs, lhs) })
}

/// The vector-Jacobian product function of `squareRoot`. Returns the original
/// result and pullback of `squareRoot` with respect to `self`.
@inlinable // FIXME(sil-serialize-all)
func _vjpSquareRoot() -> (Self, (Self) -> Self) {
@derivative(of: squareRoot)
func _vjpSquareRoot() -> (value: Self, pullback: (Self) -> Self) {
let y = squareRoot()
return (y, { v in v / (2 * y) })
}
Expand Down
33 changes: 33 additions & 0 deletions test/AutoDiff/floating_point.swift.gyb
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// RUN: %target-run-simple-swiftgyb
// REQUIRES: executable_test

import StdlibUnittest

var FloatingPointTests = TestSuite("FloatingPoint")

%for Self in ['Float', 'Double', 'Float80']:
% if Self == 'Float80':
#if !os(Windows) && (arch(i386) || arch(x86_64))
% end

FloatingPointTests.test("${Self}.squareRoot") {
expectEqual(${Self}(0.5), gradient(at: ${Self}(1), in: { $0.squareRoot() }))
expectEqual(${Self}(0.25), gradient(at: ${Self}(4), in: { $0.squareRoot() }))
}

FloatingPointTests.test("${Self}.addingProduct") {
expectEqual(
(${Self}(1), ${Self}(2), ${Self}(3)),
gradient(
at: ${Self}(10), ${Self}(3), ${Self}(2),
in: { $0.addingProduct($1, $2) }
)
)
}

% if Self == 'Float80':
#endif
% end
% end

runAllTests()