Skip to content

[AutoDiff] Support differentiation of loops. #25558

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 19, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
287 changes: 161 additions & 126 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp

Large diffs are not rendered by default.

93 changes: 93 additions & 0 deletions lib/SILOptimizer/Mandatory/Differentiation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
//===--- Differentiation.h - SIL Automatic Differentiation ----*- C++ -*---===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// SWIFT_ENABLE_TENSORFLOW
//
// Reverse-mode automatic differentiation utilities.
//
// NOTE: Although the AD feature is developed as part of the Swift for
// TensorFlow project, it is completely independent from TensorFlow support.
//
// TODO: Move definitions here from Differentiation.cpp.
//
//===----------------------------------------------------------------------===//

#ifndef SWIFT_SILOPTIMIZER_MANDATORY_DIFFERENTIATION_H
#define SWIFT_SILOPTIMIZER_MANDATORY_DIFFERENTIATION_H

#include "swift/SILOptimizer/Analysis/DominanceAnalysis.h"
#include "swift/SILOptimizer/Utils/Local.h"

namespace swift {

using llvm::DenseMap;
using llvm::SmallDenseMap;
using llvm::SmallDenseSet;
using llvm::SmallMapVector;
using llvm::SmallSet;

/// Helper class for visiting basic blocks in post-order post-dominance order,
/// based on a worklist algorithm.
class PostOrderPostDominanceOrder {
SmallVector<DominanceInfoNode *, 16> buffer;
PostOrderFunctionInfo *postOrderInfo;
size_t srcIdx = 0;

public:
/// Constructor.
/// \p root The root of the post-dominator tree.
/// \p postOrderInfo The post-order info of the function.
/// \p capacity Should be the number of basic blocks in the dominator tree to
/// reduce memory allocation.
PostOrderPostDominanceOrder(DominanceInfoNode *root,
PostOrderFunctionInfo *postOrderInfo,
int capacity = 0)
: postOrderInfo(postOrderInfo) {
buffer.reserve(capacity);
buffer.push_back(root);
}

/// Get the next block from the worklist.
DominanceInfoNode *getNext() {
if (srcIdx == buffer.size())
return nullptr;
return buffer[srcIdx++];
}

/// Pushes the dominator children of a block onto the worklist in post-order.
void pushChildren(DominanceInfoNode *node) {
pushChildrenIf(node, [](SILBasicBlock *) { return true; });
}

/// Conditionally pushes the dominator children of a block onto the worklist
/// in post-order.
template <typename Pred>
void pushChildrenIf(DominanceInfoNode *node, Pred pred) {
SmallVector<DominanceInfoNode *, 4> children;
for (auto *child : *node)
children.push_back(child);
llvm::sort(children.begin(), children.end(),
[&](DominanceInfoNode *n1, DominanceInfoNode *n2) {
return postOrderInfo->getPONumber(n1->getBlock()) <
postOrderInfo->getPONumber(n2->getBlock());
});
for (auto *child : children) {
SILBasicBlock *childBB = child->getBlock();
if (pred(childBB))
buffer.push_back(child);
}
}
};

} // end namespace swift

#endif // SWIFT_SILOPTIMIZER_MANDATORY_DIFFERENTIATION_H
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,17 @@ extension Tracked : SignedNumeric & Numeric where T : SignedNumeric, T == T.Magn
}

/// Multiplies `lhs` by `rhs` in place.
///
/// Forwards to the binary `*` operator and assigns the product back exactly
/// once, mirroring how `/=` forwards to `/`. A second, direct
/// `Tracked(lhs.value * rhs.value)` assignment would apply `rhs` twice.
public static func *= (lhs: inout Tracked, rhs: Tracked) {
  lhs = lhs * rhs
}
}

extension Tracked where T : FloatingPoint {
  /// Returns a new `Tracked` wrapping the quotient of the two wrapped values.
  public static func / (lhs: Tracked, rhs: Tracked) -> Tracked {
    let quotient = lhs.value / rhs.value
    return Tracked(quotient)
  }

  /// Divides `lhs` by `rhs` in place by forwarding to the binary `/` operator.
  public static func /= (lhs: inout Tracked, rhs: Tracked) {
    let quotient = lhs / rhs
    lhs = quotient
  }
}

Expand Down Expand Up @@ -181,6 +191,16 @@ extension Tracked where T : Differentiable & SignedNumeric, T == T.Magnitude,
}
}

extension Tracked where T : Differentiable & FloatingPoint,
                        T == T.AllDifferentiableVariables,
                        T == T.TangentVector {
  /// Registered derivative of `/`.
  ///
  /// Returns the quotient together with a pullback that maps an incoming
  /// value `v` to `(v / rhs, -lhs / (rhs * rhs) * v)`, i.e. the contributions
  /// with respect to `lhs` and `rhs` respectively.
  @usableFromInline
  @differentiating(/)
  internal static func _vjpDivide(lhs: Self, rhs: Self)
    -> (value: Self, pullback: (Self) -> (Self, Self)) {
    let quotient = lhs / rhs
    return (
      value: quotient,
      pullback: { seed in (seed / rhs, -lhs / (rhs * rhs) * seed) }
    )
  }
}

// Differential operators for `Tracked<Float>`.
public extension Differentiable {
@inlinable
Expand Down
108 changes: 108 additions & 0 deletions test/AutoDiff/control_flow.swift
Original file line number Diff line number Diff line change
Expand Up @@ -519,4 +519,112 @@ ControlFlowTests.test("Enums") {
}
}

ControlFlowTests.test("Loops") {
// Starts from `x` and multiplies by `x` once per iteration of `1..<3` (two
// iterations), so the result is x^3 and its derivative is 3x^2.
func for_loop(_ x: Float) -> Float {
var result = x
for _ in 1..<3 {
result = result * x
}
return result
}
// x = 2: value 8, gradient 12. x = 3: value 27, gradient 27.
expectEqual((8, 12), valueWithGradient(at: 2, in: for_loop))
expectEqual((27, 27), valueWithGradient(at: 3, in: for_loop))

func while_loop(_ x: Float) -> Float {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please also test a repeat-while loop.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added tests for repeat-while loops and continue/break statements in 32f6629.
Exposed TF-584: derivative computation bug for repeat-while loops. Thanks for pointing this out!

var result = x
var i = 1
while i < 3 {
result = result * x
i += 1
}
return result
}
expectEqual((8, 12), valueWithGradient(at: 2, in: while_loop))
expectEqual((27, 27), valueWithGradient(at: 3, in: while_loop))

// `repeat`-`while` variant: the body runs for i = 1 and i = 2 (the condition is
// checked after each pass), so the result is x^3, the same as `for_loop`.
func repeat_while_loop(_ x: Float) -> Float {
var result = x
var i = 1
repeat {
result = result * x
i += 1
} while i < 3
return result
}
// FIXME(TF-584): Investigate incorrect (too big) gradient values
// for repeat-while loops.
// expectEqual((8, 12), valueWithGradient(at: 2, in: repeat_while_loop))
// expectEqual((27, 27), valueWithGradient(at: 3, in: repeat_while_loop))
// The active expectations below pin the currently observed (known-wrong)
// gradients so that a behavior change is noticed.
expectEqual((8, 18), valueWithGradient(at: 2, in: repeat_while_loop))
expectEqual((27, 36), valueWithGradient(at: 3, in: repeat_while_loop))

// Skips even `i` with `continue`, so the multiplication runs for
// i = 1, 3, 5, 7, 9 (five times): the result is x^6 with derivative 6x^5.
func loop_continue(_ x: Float) -> Float {
var result = x
for i in 1..<10 {
if i.isMultiple(of: 2) {
continue
}
result = result * x
}
return result
}
// x = 2: value 64, gradient 192. x = 3: value 729, gradient 1458.
expectEqual((64, 192), valueWithGradient(at: 2, in: loop_continue))
expectEqual((729, 1458), valueWithGradient(at: 3, in: loop_continue))

// Exits the loop with `break` at the first even `i`. Only i = 1 reaches the
// multiplication before i = 2 terminates the loop, so the result is x^2 with
// derivative 2x. (This body previously used `continue`, which made it a
// duplicate of `loop_continue` and left `break` untested.)
func loop_break(_ x: Float) -> Float {
  var result = x
  for i in 1..<10 {
    if i.isMultiple(of: 2) {
      break
    }
    result = result * x
  }
  return result
}
// x = 2: value 4, gradient 4. x = 3: value 9, gradient 6.
expectEqual((4, 4), valueWithGradient(at: 2, in: loop_break))
expectEqual((9, 6), valueWithGradient(at: 3, in: loop_break))

// A `for` loop (two iterations) containing a `while` loop (two iterations).
// Each outer pass maps `outer` to `outer * x + 2x`, so starting from `x` the
// result is x^3 + 2x^2 + 2x with derivative 3x^2 + 4x + 2.
func nested_loop1(_ x: Float) -> Float {
var outer = x
for _ in 1..<3 {
outer = outer * x

var inner = outer
var i = 1
while i < 3 {
inner = inner + x
i += 1
}
outer = inner
}
return outer
}
// x = 2: value 20, gradient 22. x = 4: value 104, gradient 66.
expectEqual((20, 22), valueWithGradient(at: 2, in: nested_loop1))
expectEqual((104, 66), valueWithGradient(at: 4, in: nested_loop1))

// Nested loops with a labeled outer loop and a `switch` inside the inner loop.
// After each `inner = inner + x`, the truncated remainder of `inner` modulo 7
// decides the control flow: 0 exits `outerLoop` (skipping `outer = inner`),
// 1 leaves only the `switch`, and anything else `continue`s the inner `while`.
func nested_loop2(_ x: Float, count: Int) -> Float {
var outer = x
outerLoop: for _ in 1..<count {
outer = outer * x

var inner = outer
var i = 1
while i < count {
inner = inner + x
i += 1

switch Int(inner.truncatingRemainder(dividingBy: 7)) {
case 0: break outerLoop
case 1: break
default: continue
}
}
outer = inner
}
return outer
}
// x = 4, count = 5: the first outer pass hits remainder 0 (inner == 28), so the
// result is x^2 = 16 with derivative 2x = 8.
// NOTE(review): hand-tracing x = 2, count = 5 gives the path
// (x * x + 4x) * x = x^3 + 4x^2 (value 24), whose derivative 3x^2 + 8x is 28
// rather than the 12 expected below -- confirm whether 12 is intended.
expectEqual((24, 12), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 5) }))
expectEqual((16, 8), valueWithGradient(at: 4, in: { x in nested_loop2(x, count: 5) }))
}

runAllTests()
63 changes: 56 additions & 7 deletions test/AutoDiff/control_flow_diagnostics.swift
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,57 @@ func enum_nonactive2(_ e: Enum, _ x: Float) -> Float {
}
}

// Test loops.

// A `for`-`in` loop over an integer range. No `expected-error` annotations are
// attached, so this declaration is presumably expected to be accepted without
// diagnostics.
@differentiable
func for_loop(_ x: Float) -> Float {
var result: Float = x
for _ in 0..<3 {
result = result * x
}
return result
}

// A `while` loop driven by an integer counter. No `expected-error` annotations
// are attached, so this declaration is presumably expected to be accepted
// without diagnostics.
@differentiable
func while_loop(_ x: Float) -> Float {
var result = x
var i = 1
while i < 3 {
result = result * x
i += 1
}
return result
}

// A `while` loop nested inside a `for`-`in` loop; the inner loop divides by
// `x`. No `expected-error` annotations are attached, so this declaration is
// presumably expected to be accepted without diagnostics.
@differentiable
func nested_loop(_ x: Float) -> Float {
var outer = x
for _ in 1..<3 {
outer = outer * x

var inner = outer
var i = 1
while i < 3 {
inner = inner / x
i += 1
}
outer = inner
}
return outer
}

// Test `try_apply`.

// A rethrowing function whose body calls a throwing closure with `try`. The
// annotations below expect it to be rejected with a "differentiating control
// flow is not yet supported" note on the `try body(x)` line.
// Keep each `@+1` annotation immediately above the line it refers to.
// expected-error @+1 {{function is not differentiable}}
@differentiable
// expected-note @+1 {{when differentiating this function definition}}
func withoutDerivative<T : Differentiable, R: Differentiable>(
at x: T, in body: (T) throws -> R
) rethrows -> R {
// expected-note @+1 {{differentiating control flow is not yet supported}}
try body(x)
}

// Test unsupported differentiation of active enum values.

// expected-error @+1 {{function is not differentiable}}
Expand Down Expand Up @@ -91,16 +142,14 @@ enum Tree : Differentiable & AdditiveArithmetic {
}
}

// Test a `for`-`in` loop over an array. Unlike loops over integer ranges, this
// is still rejected; the note is reported on the `for` line as an unsupported
// enum value (presumably the optional produced by the array's iterator --
// TODO confirm).
// Keep each `@+1` annotation immediately above the line it refers to.

// expected-error @+1 {{function is not differentiable}}
@differentiable
// expected-note @+1 {{when differentiating this function definition}}
func loop_array(_ array: [Float]) -> Float {
  var result: Float = 1
  // expected-note @+1 {{differentiating enum values is not yet supported}}
  for x in array {
    result = result * x
  }
  return result
}
36 changes: 25 additions & 11 deletions test/AutoDiff/control_flow_sil.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,31 @@ func cond(_ x: Float) -> Float {
return x - x
}

// CHECK-DATA-STRUCTURES: enum _AD__cond_bb0__Pred__src_0_wrt_0 {
// CHECK-DATA-STRUCTURES: }
// CHECK-DATA-STRUCTURES: struct _AD__cond_bb0__PB__src_0_wrt_0 {
// CHECK-DATA-STRUCTURES: }
// CHECK-DATA-STRUCTURES: enum _AD__cond_bb1__Pred__src_0_wrt_0 {
// CHECK-DATA-STRUCTURES: case bb0(_AD__cond_bb0__PB__src_0_wrt_0)
// CHECK-DATA-STRUCTURES: }
// CHECK-DATA-STRUCTURES: struct _AD__cond_bb1__PB__src_0_wrt_0 {
// CHECK-DATA-STRUCTURES: @_hasStorage var predecessor: _AD__cond_bb1__Pred__src_0_wrt_0 { get set }
// CHECK-DATA-STRUCTURES: @_hasStorage var pullback_0: (Float) -> (Float, Float) { get set }
// CHECK-DATA-STRUCTURES: }
// CHECK-DATA-STRUCTURES: enum _AD__cond_bb2__Pred__src_0_wrt_0 {
// CHECK-DATA-STRUCTURES: case bb0(_AD__cond_bb0__PB__src_0_wrt_0)
// CHECK-DATA-STRUCTURES: }
// CHECK-DATA-STRUCTURES: struct _AD__cond_bb2__PB__src_0_wrt_0 {
// CHECK-DATA-STRUCTURES: @_hasStorage var predecessor: _AD__cond_bb2__Pred__src_0_wrt_0 { get set }
// CHECK-DATA-STRUCTURES: @_hasStorage var pullback_1: (Float) -> (Float, Float) { get set }
// CHECK-DATA-STRUCTURES: }
// CHECK-DATA-STRUCTURES: struct _AD__cond_bb3__PB__src_0_wrt_0 {
// CHECK-DATA-STRUCTURES: @_hasStorage var predecessor: _AD__cond_bb3__Pred__src_0_wrt_0 { get set }
// CHECK-DATA-STRUCTURES: }
// CHECK-DATA-STRUCTURES: enum _AD__cond_bb0__Pred__src_0_wrt_0 {
// CHECK-DATA-STRUCTURES: }
// CHECK-DATA-STRUCTURES: enum _AD__cond_bb1__Pred__src_0_wrt_0 {
// CHECK-DATA-STRUCTURES: case bb0(_AD__cond_bb0__PB__src_0_wrt_0)
// CHECK-DATA-STRUCTURES: }
// CHECK-DATA-STRUCTURES: enum _AD__cond_bb2__Pred__src_0_wrt_0 {
// CHECK-DATA-STRUCTURES: case bb0(_AD__cond_bb0__PB__src_0_wrt_0)
// CHECK-DATA-STRUCTURES: }
// CHECK-DATA-STRUCTURES: enum _AD__cond_bb3__Pred__src_0_wrt_0 {
// CHECK-DATA-STRUCTURES: case bb2(_AD__cond_bb2__PB__src_0_wrt_0)
// CHECK-DATA-STRUCTURES: case bb1(_AD__cond_bb1__PB__src_0_wrt_0)
// CHECK-DATA-STRUCTURES: }
// CHECK-DATA-STRUCTURES: struct _AD__cond_bb3__PB__src_0_wrt_0 {
// CHECK-DATA-STRUCTURES: @_hasStorage var predecessor: _AD__cond_bb3__Pred__src_0_wrt_0 { get set }
// CHECK-DATA-STRUCTURES: }

// CHECK-SIL-LABEL: sil hidden @AD__cond__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
// CHECK-SIL: bb0([[INPUT_ARG:%.*]] : $Float):
Expand Down Expand Up @@ -137,6 +137,20 @@ func nested_cond_generic<T : Differentiable & FloatingPoint>(_ x: T, _ y: T) ->
return y
}

// Generic function with two nested `for`-`in` loops over integer ranges. The
// loop bodies only copy values between `result` and `y`, both of which start
// as `x`, so the function returns `x`. It is exported under the fixed symbol
// name "loop_generic" via `@_silgen_name`.
@differentiable
@_silgen_name("loop_generic")
func loop_generic<T : Differentiable & FloatingPoint>(_ x: T) -> T {
var result = x
for _ in 1..<3 {
var y = x
for _ in 1..<3 {
result = y
y = result
}
}
return result
}

// Test control flow + tuple buffer.
// Verify that adjoint buffers are not allocated for address projections.

Expand Down
Loading