Skip to content

[AutoDiff upstream] Add stdlib Differentiable conformances and derivatives. #30875

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Apr 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions lib/SILOptimizer/Utils/Differentiation/PullbackEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1132,9 +1132,15 @@ PullbackEmitter::getArrayAdjointElementBuffer(SILValue arrayAdjoint,
// Apply `Array.TangentVector.subscript.getter` to get array element adjoint
// buffer.
auto &ctx = builder.getASTContext();
// %index_literal = integer_literal $Builtin.Int64, <index>
auto *eltIndexLiteral = builder.createIntegerLiteral(
loc, SILType::getBuiltinIntegerType(64, ctx), eltIndex);
// %index_literal = integer_literal $Builtin.IntXX, <index>
auto builtinIntType =
SILType::getPrimitiveObjectType(ctx.getIntDecl()
->getStoredProperties()
.front()
->getInterfaceType()
->getCanonicalType());
auto *eltIndexLiteral =
builder.createIntegerLiteral(loc, builtinIntType, eltIndex);
auto intType = SILType::getPrimitiveObjectType(
ctx.getIntDecl()->getDeclaredType()->getCanonicalType());
// %index_int = struct $Int (%index_literal)
Expand Down
332 changes: 332 additions & 0 deletions stdlib/public/Differentiation/ArrayDifferentiation.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,332 @@
//===--- ArrayDifferentiation.swift ---------------------------*- swift -*-===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//

import Swift

//===----------------------------------------------------------------------===//
// Protocol conformances
//===----------------------------------------------------------------------===//

// TODO(TF-938): Add `Element: Differentiable` requirement.
extension Array {
/// The view of an array as the differentiable product manifold of `Element`
/// multiplied with itself `count` times.
@frozen
public struct DifferentiableView {
var _base: [Element]
}
}

extension Array.DifferentiableView: Differentiable
where Element: Differentiable {
/// The viewed array.
public var base: [Element] {
get { return _base }
_modify { yield &_base }
}

@usableFromInline
@derivative(of: base)
func _vjpBase() -> (
value: [Element], pullback: (Array<Element>.TangentVector) -> TangentVector
) {
return (base, { $0 })
}

/// Creates a differentiable view of the given array.
public init(_ base: [Element]) { self._base = base }

@usableFromInline
@derivative(of: init(_:))
static func _vjpInit(_ base: [Element]) -> (
value: Array.DifferentiableView, pullback: (TangentVector) -> TangentVector
) {
return (Array.DifferentiableView(base), { $0 })
}

public typealias TangentVector =
Array<Element.TangentVector>.DifferentiableView

public mutating func move(along direction: TangentVector) {
precondition(
base.count == direction.base.count,
"cannot move Array.DifferentiableView with count \(base.count) along "
+ "direction with different count \(direction.base.count)")
for i in base.indices {
base[i].move(along: direction.base[i])
}
}
}

extension Array.DifferentiableView: Equatable
where Element: Differentiable & Equatable {
public static func == (
lhs: Array.DifferentiableView,
rhs: Array.DifferentiableView
) -> Bool {
return lhs.base == rhs.base
}
}

extension Array.DifferentiableView: ExpressibleByArrayLiteral
where Element: Differentiable {
public init(arrayLiteral elements: Element...) {
self.init(elements)
}
}

extension Array.DifferentiableView: CustomStringConvertible
where Element: Differentiable {
public var description: String {
return base.description
}
}

/// Makes `Array.DifferentiableView` additive as the product space.
///
/// Note that `Array.DifferentiableView([])` is the zero in the product spaces
/// of all counts.
extension Array.DifferentiableView: AdditiveArithmetic
where Element: AdditiveArithmetic & Differentiable {

public static var zero: Array.DifferentiableView {
return Array.DifferentiableView([])
}

public static func + (
lhs: Array.DifferentiableView,
rhs: Array.DifferentiableView
) -> Array.DifferentiableView {
precondition(
lhs.base.count == 0 || rhs.base.count == 0
|| lhs.base.count == rhs.base.count,
"cannot add Array.DifferentiableViews with different counts: "
+ "\(lhs.base.count) and \(rhs.base.count)")
if lhs.base.count == 0 {
return rhs
}
if rhs.base.count == 0 {
return lhs
}
return Array.DifferentiableView(zip(lhs.base, rhs.base).map(+))
}

public static func - (
lhs: Array.DifferentiableView,
rhs: Array.DifferentiableView
) -> Array.DifferentiableView {
precondition(
lhs.base.count == 0 || rhs.base.count == 0
|| lhs.base.count == rhs.base.count,
"cannot subtract Array.DifferentiableViews with different counts: "
+ "\(lhs.base.count) and \(rhs.base.count)")
if lhs.base.count == 0 {
return rhs
}
if rhs.base.count == 0 {
return lhs
}
return Array.DifferentiableView(zip(lhs.base, rhs.base).map(-))
}

@inlinable
public subscript(_ index: Int) -> Element {
if index < base.count {
return base[index]
} else {
return Element.zero
}
}
}

/// Makes `Array` differentiable as the product manifold of `Element`
/// multiplied with itself `count` times.
extension Array: Differentiable where Element: Differentiable {
// In an ideal world, `TangentVector` would be `[Element.TangentVector]`.
// Unfortunately, we cannot conform `Array` to `AdditiveArithmetic` for
// `TangentVector` because `Array` already has a static `+` method with
// different semantics from `AdditiveArithmetic.+`. So we use
// `Array.DifferentiableView` for all these associated types.
public typealias TangentVector =
Array<Element.TangentVector>.DifferentiableView

public mutating func move(along direction: TangentVector) {
var view = DifferentiableView(self)
view.move(along: direction)
self = view.base
}
}

//===----------------------------------------------------------------------===//
// Derivatives
//===----------------------------------------------------------------------===//

extension Array where Element: Differentiable {
@usableFromInline
@derivative(of: subscript)
func _vjpSubscript(index: Int) -> (
value: Element, pullback: (Element.TangentVector) -> TangentVector
) {
func pullback(_ gradientIn: Element.TangentVector) -> TangentVector {
var gradientOut = [Element.TangentVector](
repeating: .zero,
count: count)
gradientOut[index] = gradientIn
return TangentVector(gradientOut)
}
return (self[index], pullback)
}

@usableFromInline
@derivative(of: +)
static func _vjpConcatenate(_ lhs: [Element], _ rhs: [Element]) -> (
value: [Element],
pullback: (TangentVector) -> (TangentVector, TangentVector)
) {
func pullback(_ gradientIn: TangentVector) -> (TangentVector, TangentVector)
{
precondition(
gradientIn.base.count == lhs.count + rhs.count,
"+ should receive gradient with count equal to sum of operand "
+ "counts, but counts are: gradient \(gradientIn.base.count), "
+ "lhs \(lhs.count), rhs \(rhs.count)")
return (
TangentVector(
[Element.TangentVector](
gradientIn.base[0..<lhs.count])),
TangentVector(
[Element.TangentVector](
gradientIn.base[lhs.count...]))
)
}
return (lhs + rhs, pullback)
}
}

extension Array where Element: Differentiable {
@usableFromInline
@derivative(of: append)
mutating func _vjpAppend(_ element: Element) -> (
value: Void, pullback: (inout TangentVector) -> Element.TangentVector
) {
let appendedElementIndex = count
defer { append(element) }
return ((), { dself in dself.base[appendedElementIndex] })
}

@usableFromInline
@derivative(of: append)
mutating func _jvpAppend(_ element: Element) -> (
value: Void,
differential: (inout TangentVector, Element.TangentVector) -> Void
) {
append(element)
return ((), { $0.base.append($1) })
}
}

extension Array where Element: Differentiable {
@usableFromInline
@derivative(of: init(repeating:count:))
static func _vjpInit(repeating repeatedValue: Element, count: Int) -> (
value: Self, pullback: (TangentVector) -> Element.TangentVector
) {
(
value: Self(repeating: repeatedValue, count: count),
pullback: { v in
v.base.reduce(.zero, +)
}
)
}
}

//===----------------------------------------------------------------------===//
// Differentiable higher order functions for collections
//===----------------------------------------------------------------------===//

extension Array where Element: Differentiable {
@differentiable(wrt: (self, initialResult))
public func differentiableReduce<Result: Differentiable>(
_ initialResult: Result,
_ nextPartialResult: @differentiable (Result, Element) -> Result
) -> Result {
reduce(initialResult, nextPartialResult)
}

@usableFromInline
@derivative(of: differentiableReduce)
internal func _vjpDifferentiableReduce<Result: Differentiable>(
_ initialResult: Result,
_ nextPartialResult: @differentiable (Result, Element) -> Result
) -> (
value: Result,
pullback: (Result.TangentVector)
-> (Array.TangentVector, Result.TangentVector)
) {
var pullbacks:
[(Result.TangentVector) -> (Result.TangentVector, Element.TangentVector)] =
[]
let count = self.count
pullbacks.reserveCapacity(count)
var result = initialResult
for element in self {
let (y, pb) =
valueWithPullback(at: result, element, in: nextPartialResult)
result = y
pullbacks.append(pb)
}
return (
value: result,
pullback: { tangent in
var resultTangent = tangent
var elementTangents = TangentVector([])
elementTangents.base.reserveCapacity(count)
for pullback in pullbacks.reversed() {
let (newResultTangent, elementTangent) = pullback(resultTangent)
resultTangent = newResultTangent
elementTangents.base.append(elementTangent)
}
return (TangentVector(elementTangents.base.reversed()), resultTangent)
}
)
}
}

extension Array where Element: Differentiable {
@differentiable(wrt: self)
public func differentiableMap<Result: Differentiable>(
_ body: @differentiable (Element) -> Result
) -> [Result] {
map(body)
}

@usableFromInline
@derivative(of: differentiableMap)
internal func _vjpDifferentiableMap<Result: Differentiable>(
_ body: @differentiable (Element) -> Result
) -> (
value: [Result],
pullback: (Array<Result>.TangentVector) -> Array.TangentVector
) {
var values: [Result] = []
var pullbacks: [(Result.TangentVector) -> Element.TangentVector] = []
for x in self {
let (y, pb) = valueWithPullback(at: x, in: body)
values.append(y)
pullbacks.append(pb)
}
func pullback(_ tans: Array<Result>.TangentVector) -> Array.TangentVector {
.init(zip(tans.base, pullbacks).map { tan, pb in pb(tan) })
}
return (value: values, pullback: pullback)
}
}
18 changes: 17 additions & 1 deletion stdlib/public/Differentiation/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# This source file is part of the Swift.org open source project
#
# Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors
# Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
# Licensed under Apache License v2.0 with Runtime Library Exception
#
# See https://swift.org/LICENSE.txt for license information
Expand All @@ -14,6 +14,22 @@ add_swift_target_library(swift_Differentiation ${SWIFT_STDLIB_LIBRARY_BUILD_TYPE
Differentiable.swift
DifferentialOperators.swift
DifferentiationUtilities.swift
ArrayDifferentiation.swift

GYB_SOURCES
FloatingPointDifferentiation.swift.gyb
TgmathDerivatives.swift.gyb
SIMDDifferentiation.swift.gyb

SWIFT_MODULE_DEPENDS_OSX Darwin
SWIFT_MODULE_DEPENDS_IOS Darwin
SWIFT_MODULE_DEPENDS_TVOS Darwin
SWIFT_MODULE_DEPENDS_WATCHOS Darwin
SWIFT_MODULE_DEPENDS_LINUX Glibc
SWIFT_MODULE_DEPENDS_FREEBSD Glibc
SWIFT_MODULE_DEPENDS_CYGWIN Glibc
SWIFT_MODULE_DEPENDS_HAIKU Glibc
SWIFT_MODULE_DEPENDS_WINDOWS MSVCRT

SWIFT_COMPILE_FLAGS
${SWIFT_STANDARD_LIBRARY_SWIFT_FLAGS}
Expand Down
Loading