Skip to content

[AutoDiff] Support differentiation of switch_enum. #25509

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 16, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/swift/AST/DiagnosticsSIL.def
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,8 @@ WARNING(autodiff_nonvaried_result_fixit,none,
"result does not depend on differentiation arguments and will always "
"have a zero derivative; do you want to add '.withoutDerivative()'?",
())
NOTE(autodiff_enums_unsupported,none,
"differentiating enum values is not yet supported", ())
NOTE(autodiff_global_let_closure_not_differentiable,none,
"global constant closure is not differentiable", ())
NOTE(autodiff_cannot_differentiate_global_var_closures,none,
Expand Down
90 changes: 80 additions & 10 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1533,6 +1533,15 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di,
setVaried(cbi->getFalseBB()->getArgument(opIdx), i);
}
}
// Handle `switch_enum`.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems necessary to also handle switch_enum_addr, but I didn't find Swift functions containing switch_enum_addr that also don't contain active enum values. A quick searches shows that switch_enum_addr is generated during later SIL passes and also for optional force-unwrapping.

else if (auto *sei = dyn_cast<SwitchEnumInst>(&inst)) {
if (isVaried(sei->getOperand(), i)) {
for (auto *succBB : sei->getSuccessorBlocks())
for (auto *arg : succBB->getArguments())
setVaried(arg, i);
// Default block cannot have arguments.
}
}
// Handle everything else.
else {
for (auto &op : inst.getAllOperands())
Expand Down Expand Up @@ -1767,8 +1776,9 @@ static bool diagnoseUnsupportedControlFlow(ADContext &context,
// Diagnose unsupported branching terminators.
for (auto &bb : *original) {
auto *term = bb.getTerminator();
// Supported terminators are: `br`, `cond_br`.
if (isa<BranchInst>(term) || isa<CondBranchInst>(term))
// Supported terminators are: `br`, `cond_br`, `switch_enum`.
if (isa<BranchInst>(term) || isa<CondBranchInst>(term) ||
isa<SwitchEnumInst>(term))
continue;
// If terminator is an unsupported branching terminator, emit an error.
if (term->isBranch()) {
Expand Down Expand Up @@ -3134,6 +3144,56 @@ class VJPEmitter final
getOpBasicBlock(cbi->getFalseBB()), falseArgs);
}

void visitSwitchEnumInst(SwitchEnumInst *sei) {
// Build pullback struct value for original block.
auto *origBB = sei->getParent();
auto *pbStructVal = buildPullbackValueStructValue(sei);

// Creates a trampoline block for given original successor block. The
// trampoline block has the same arguments as the VJP successor block but
// drops the last predecessor enum argument. The generated `switch_enum`
// instruction branches to the trampoline block, and the trampoline block
// constructs a predecessor enum value and branches to the VJP successor
// block.
auto createTrampolineBasicBlock =
[&](SILBasicBlock *origSuccBB) -> SILBasicBlock * {
auto *vjpSuccBB = getOpBasicBlock(origSuccBB);
// Create the trampoline block.
auto *trampolineBB = vjp->createBasicBlockBefore(vjpSuccBB);
for (auto *arg : vjpSuccBB->getArguments().drop_back())
trampolineBB->createPhiArgument(arg->getType(),
arg->getOwnershipKind());
// Build predecessor enum value for successor block and branch to it.
SILBuilder trampolineBuilder(trampolineBB);
auto *succEnumVal = buildPredecessorEnumValue(
trampolineBuilder, origBB, origSuccBB, pbStructVal);
SmallVector<SILValue, 4> forwardedArguments(
trampolineBB->getArguments().begin(),
trampolineBB->getArguments().end());
forwardedArguments.push_back(succEnumVal);
trampolineBuilder.createBranch(
sei->getLoc(), vjpSuccBB, forwardedArguments);
return trampolineBB;
};

// Create trampoline successor basic blocks.
SmallVector<std::pair<EnumElementDecl *, SILBasicBlock *>, 4> caseBBs;
for (unsigned i : range(sei->getNumCases())) {
auto caseBB = sei->getCase(i);
auto *trampolineBB = createTrampolineBasicBlock(caseBB.second);
caseBBs.push_back({caseBB.first, trampolineBB});
}
// Create trampoline default basic block.
SILBasicBlock *newDefaultBB = nullptr;
if (auto *defaultBB = sei->getDefaultBBOrNull().getPtrOrNull())
newDefaultBB = createTrampolineBasicBlock(defaultBB);

// Create a new `switch_enum` instruction.
getBuilder().createSwitchEnum(
sei->getLoc(), getOpValue(sei->getOperand()),
newDefaultBB, caseBBs);
}

// If an `apply` has active results or active inout parameters, replace it
// with an `apply` of its VJP.
void visitApplyInst(ApplyInst *ai) {
Expand Down Expand Up @@ -4155,6 +4215,13 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
auto addActiveValue = [&](SILValue v) {
if (visited.count(v))
return;
// Diagnose active enum values. Differentiation of enum values is not
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Diagnosing active enum values is necessary because adjoint generation doesn't propagate adjoint values of enum associated values correctly.

Support is non-trivial because switch_enum operand and successor block arguments have different types: the operand has an enum type but successor block arguments have associated values' type. Adjoint value propagation needs to construct enum adjoint value from associated values' adjoint values.

// yet supported; requires special adjoint handling.
if (v->getType().getEnumOrBoundGenericEnum()) {
getContext().emitNondifferentiabilityError(
v, getInvoker(), diag::autodiff_enums_unsupported);
errorOccurred = true;
}
// Skip address projections.
// Address projections do not need their own adjoint buffers; they
// become projections into their adjoint base buffer.
Expand All @@ -4175,8 +4242,12 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
if (getActivityInfo().isActive(result, getIndices()))
addActiveValue(result);
}
if (errorOccurred)
break;
domOrder.pushChildren(bb);
}
if (errorOccurred)
return true;

// Create adjoint blocks and arguments, visiting original blocks in
// post-order.
Expand All @@ -4196,7 +4267,10 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
adjointPullbackStructArguments[origBB] = lastArg;
continue;
}

// Add a pullback struct argument.
auto *pbStructArg = adjointBB->createPhiArgument(
pbStructLoweredType, ValueOwnershipKind::Guaranteed);
adjointPullbackStructArguments[origBB] = pbStructArg;
// Get all active values in the original block.
// If the original block has no active values, continue.
auto &bbActiveValues = activeValues[origBB];
Expand All @@ -4222,10 +4296,6 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
activeValueAdjointBBArgumentMap[{origBB, activeValue}] = adjointArg;
}
}
// Add a pullback struct argument.
auto *pbStructArg = adjointBB->createPhiArgument(
pbStructLoweredType, ValueOwnershipKind::Guaranteed);
adjointPullbackStructArguments[origBB] = pbStructArg;
// - Create adjoint trampoline blocks for each successor block of the
// original block. Adjoint trampoline blocks only have a pullback
// struct argument, and branch from the adjoint successor block to the
Expand Down Expand Up @@ -4373,6 +4443,8 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
assert(adjointSuccBB && adjointSuccBB->getNumArguments() == 1);
SILBuilder adjointTrampolineBBBuilder(adjointSuccBB);
SmallVector<SILValue, 8> trampolineArguments;
// Propagate pullback struct argument.
trampolineArguments.push_back(adjointSuccBB->getArguments().front());
// Propagate adjoint values/buffers of active values/buffers to
// predecessor blocks.
auto &predBBActiveValues = activeValues[predBB];
Expand Down Expand Up @@ -4411,8 +4483,6 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
adjLoc, adjBuf, predAdjBuf, IsNotTake, IsNotInitialization);
}
}
// Propagate pullback struct argument.
trampolineArguments.push_back(adjointSuccBB->getArguments().front());
// Branch from adjoint trampoline block to adjoint block.
adjointTrampolineBBBuilder.createBranch(
adjLoc, adjointBB, trampolineArguments);
Expand All @@ -4421,7 +4491,7 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
getPullbackInfo().lookUpPredecessorEnumElement(predBB, bb);
adjointSuccessorCases.push_back({enumEltDecl, adjointSuccBB});
}
// Emit clenaups for all block-local adjoint values.
// Emit cleanups for all block-local adjoint values.
for (auto adjVal : blockLocalAdjointValues)
emitCleanupForAdjointValue(adjVal);
blockLocalAdjointValues.clear();
Expand Down
2 changes: 1 addition & 1 deletion test/AutoDiff/autodiff_diagnostics.swift
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ _ = gradient(at: NoDerivativeProperty(x: 1, y: 1)) {
func uses_optionals(_ x: Float) -> Float {
var maybe: Float? = 10
maybe = x
// expected-note @+1 {{differentiating control flow is not yet supported}}
// expected-note @+1 {{differentiating enum values is not yet supported}}
return maybe!
}

Expand Down
88 changes: 87 additions & 1 deletion test/AutoDiff/control_flow.swift
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ ControlFlowTests.test("Conditionals") {
}
expectEqual((0, 10), gradient(at: 4, 5, in: guard3))
expectCrash {
gradient(at: -3, -2, in: guard3)
_ = gradient(at: -3, -2, in: guard3)
}

func cond_empty(_ x: Float) -> Float {
Expand Down Expand Up @@ -424,4 +424,90 @@ ControlFlowTests.test("Recursion") {
expectEqual(1, gradient(at: 100, in: { x in product(x, count: 1) }))
}

ControlFlowTests.test("Enums") {
enum Enum {
case a(Float)
case b(Float, Float)

func enum_notactive1(_ x: Float) -> Float {
switch self {
case let .a(a): return x * a
case let .b(b1, b2): return x * b1 * b2
}
}
}

func enum_notactive1(_ e: Enum, _ x: Float) -> Float {
switch e {
case let .a(a): return x * a
case let .b(b1, b2): return x * b1 * b2
}
}
expectEqual(10, gradient(at: 2, in: { x in enum_notactive1(.a(10), x) }))
expectEqual(10, gradient(at: 2, in: { x in Enum.a(10).enum_notactive1(x) }))
expectEqual(20, gradient(at: 2, in: { x in enum_notactive1(.b(4, 5), x) }))
expectEqual(20, gradient(at: 2, in: { x in Enum.b(4, 5).enum_notactive1(x) }))

func enum_notactive2(_ e: Enum, _ x: Float) -> Float {
if x > 0 {
switch e {
case .a: return x * x * x
case .b: return -x
}
} else if case .b = e {
return -x
}
return x * x
}
expectEqual(12, gradient(at: 2, in: { x in enum_notactive2(.a(10), x) }))
expectEqual(-1, gradient(at: 2, in: { x in enum_notactive2(.b(4, 5), x) }))

func optional_notactive1(_ optional: Float?, _ x: Float) -> Float {
if let y = optional {
return x * y
}
return x + x
}
expectEqual(2, gradient(at: 2, in: { x in optional_notactive1(nil, x) }))
expectEqual(10, gradient(at: 2, in: { x in optional_notactive1(10, x) }))

struct Dense : Differentiable {
var w1: Float
@noDerivative var w2: Float?

@differentiable
func callAsFunction(_ input: Float) -> Float {
if let w2 = w2 {
return input * w1 * w2
}
return input * w1
}
}
expectEqual((Dense.AllDifferentiableVariables(w1: 10), 20),
Dense(w1: 4, w2: 5).gradient(at: 2, in: { dense, x in dense(x) }))
expectEqual((Dense.AllDifferentiableVariables(w1: 2), 4),
Dense(w1: 4, w2: nil).gradient(at: 2, in: { dense, x in dense(x) }))

indirect enum Indirect {
case e(Float, Enum)
case indirect(Indirect)
}

func enum_indirect_notactive1(_ indirect: Indirect, _ x: Float) -> Float {
switch indirect {
case let .e(f, e):
switch e {
case .a: return x * f * enum_notactive1(e, x)
case .b: return x * f * enum_notactive1(e, x)
}
case let .indirect(ind): return enum_indirect_notactive1(ind, x)
}
}
do {
let ind: Indirect = .e(10, .a(3))
expectEqual(120, gradient(at: 2, in: { x in enum_indirect_notactive1(ind, x) }))
expectEqual(120, gradient(at: 2, in: { x in enum_indirect_notactive1(.indirect(ind), x) }))
}
}

runAllTests()
75 changes: 68 additions & 7 deletions test/AutoDiff/control_flow_diagnostics.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: %target-swift-frontend -emit-sil -verify %s

// Test supported `br` and `cond_br` terminators.
// Test supported `br`, `cond_br`, and `switch_enum` terminators.

@differentiable
func branch(_ x: Float) -> Float {
Expand All @@ -12,21 +12,82 @@ func branch(_ x: Float) -> Float {
return x
}

// Test currently unsupported `switch_enum` terminator.

enum Enum {
case a(Float)
case b(Float)
}

@differentiable
func enum_nonactive1(_ e: Enum, _ x: Float) -> Float {
switch e {
case .a: return x
case .b: return x
}
}

@differentiable
func enum_nonactive2(_ e: Enum, _ x: Float) -> Float {
switch e {
case let .a(a): return x + a
case let .b(b): return x + b
}
}

// Test unsupported differentiation of active enum values.

// expected-error @+1 {{function is not differentiable}}
@differentiable
// expected-note @+1 {{when differentiating this function definition}}
func switch_enum(_ e: Enum, _ x: Float) -> Float {
// expected-note @+1 {{differentiating control flow is not yet supported}}
func enum_active(_ x: Float) -> Float {
// expected-note @+1 {{differentiating enum values is not yet supported}}
let e: Enum
if x > 0 {
e = .a(x)
} else {
e = .b(x)
}
switch e {
case let .a(a): return a
case let .b(b): return b
case let .a(a): return x + a
case let .b(b): return x + b
}
}

enum Tree : Differentiable & AdditiveArithmetic {
case leaf(Float)
case branch(Float, Float)

typealias TangentVector = Self
typealias AllDifferentiableVariables = Self
static var zero: Self { .leaf(0) }

// expected-error @+1 {{function is not differentiable}}
@differentiable
// expected-note @+2 {{when differentiating this function definition}}
// expected-note @+1 {{differentiating enum values is not yet supported}}
static func +(_ lhs: Self, _ rhs: Self) -> Self {
switch (lhs, rhs) {
case let (.leaf(x), .leaf(y)):
return .leaf(x + y)
case let (.branch(x1, x2), .branch(y1, y2)):
return .branch(x1 + x2, y1 + y2)
default:
fatalError()
}
}

// expected-error @+1 {{function is not differentiable}}
@differentiable
// expected-note @+2 {{when differentiating this function definition}}
// expected-note @+1 {{differentiating enum values is not yet supported}}
static func -(_ lhs: Self, _ rhs: Self) -> Self {
switch (lhs, rhs) {
case let (.leaf(x), .leaf(y)):
return .leaf(x - y)
case let (.branch(x1, x2), .branch(y1, y2)):
return .branch(x1 - x2, y1 - y2)
default:
fatalError()
}
}
}

Expand Down
Loading