Skip to content

Commit 8d39256

Browse files
authored
[AutoDiff] Support differentiation of switch_enum. (#25509)
Handle `switch_enum` terminator during VJP and adjoint generation. Necessary step for differentiating `for-in` loops, which contain optional iterator `next()` values. Diagnose differentiation of active enum values, which requires further adjoint generation support.
1 parent 95b85a9 commit 8d39256

File tree

7 files changed

+304
-13
lines changed

7 files changed

+304
-13
lines changed

include/swift/AST/DiagnosticsSIL.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,8 @@ WARNING(autodiff_nonvaried_result_fixit,none,
424424
"result does not depend on differentiation arguments and will always "
425425
"have a zero derivative; do you want to add '.withoutDerivative()'?",
426426
())
427+
NOTE(autodiff_enums_unsupported,none,
428+
"differentiating enum values is not yet supported", ())
427429
NOTE(autodiff_global_let_closure_not_differentiable,none,
428430
"global constant closure is not differentiable", ())
429431
NOTE(autodiff_cannot_differentiate_global_var_closures,none,

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 74 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1533,6 +1533,15 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di,
15331533
setVaried(cbi->getFalseBB()->getArgument(opIdx), i);
15341534
}
15351535
}
1536+
// Handle `switch_enum`.
1537+
else if (auto *sei = dyn_cast<SwitchEnumInst>(&inst)) {
1538+
if (isVaried(sei->getOperand(), i)) {
1539+
for (auto *succBB : sei->getSuccessorBlocks())
1540+
for (auto *arg : succBB->getArguments())
1541+
setVaried(arg, i);
1542+
// Default block cannot have arguments.
1543+
}
1544+
}
15361545
// Handle everything else.
15371546
else {
15381547
for (auto &op : inst.getAllOperands())
@@ -1767,8 +1776,9 @@ static bool diagnoseUnsupportedControlFlow(ADContext &context,
17671776
// Diagnose unsupported branching terminators.
17681777
for (auto &bb : *original) {
17691778
auto *term = bb.getTerminator();
1770-
// Supported terminators are: `br`, `cond_br`.
1771-
if (isa<BranchInst>(term) || isa<CondBranchInst>(term))
1779+
// Supported terminators are: `br`, `cond_br`, `switch_enum`.
1780+
if (isa<BranchInst>(term) || isa<CondBranchInst>(term) ||
1781+
isa<SwitchEnumInst>(term))
17721782
continue;
17731783
// If terminator is an unsupported branching terminator, emit an error.
17741784
if (term->isBranch()) {
@@ -3134,6 +3144,56 @@ class VJPEmitter final
31343144
getOpBasicBlock(cbi->getFalseBB()), falseArgs);
31353145
}
31363146

3147+
void visitSwitchEnumInst(SwitchEnumInst *sei) {
3148+
// Build pullback struct value for original block.
3149+
auto *origBB = sei->getParent();
3150+
auto *pbStructVal = buildPullbackValueStructValue(sei);
3151+
3152+
// Creates a trampoline block for given original successor block. The
3153+
// trampoline block has the same arguments as the VJP successor block but
3154+
// drops the last predecessor enum argument. The generated `switch_enum`
3155+
// instruction branches to the trampoline block, and the trampoline block
3156+
// constructs a predecessor enum value and branches to the VJP successor
3157+
// block.
3158+
auto createTrampolineBasicBlock =
3159+
[&](SILBasicBlock *origSuccBB) -> SILBasicBlock * {
3160+
auto *vjpSuccBB = getOpBasicBlock(origSuccBB);
3161+
// Create the trampoline block.
3162+
auto *trampolineBB = vjp->createBasicBlockBefore(vjpSuccBB);
3163+
for (auto *arg : vjpSuccBB->getArguments().drop_back())
3164+
trampolineBB->createPhiArgument(arg->getType(),
3165+
arg->getOwnershipKind());
3166+
// Build predecessor enum value for successor block and branch to it.
3167+
SILBuilder trampolineBuilder(trampolineBB);
3168+
auto *succEnumVal = buildPredecessorEnumValue(
3169+
trampolineBuilder, origBB, origSuccBB, pbStructVal);
3170+
SmallVector<SILValue, 4> forwardedArguments(
3171+
trampolineBB->getArguments().begin(),
3172+
trampolineBB->getArguments().end());
3173+
forwardedArguments.push_back(succEnumVal);
3174+
trampolineBuilder.createBranch(
3175+
sei->getLoc(), vjpSuccBB, forwardedArguments);
3176+
return trampolineBB;
3177+
};
3178+
3179+
// Create trampoline successor basic blocks.
3180+
SmallVector<std::pair<EnumElementDecl *, SILBasicBlock *>, 4> caseBBs;
3181+
for (unsigned i : range(sei->getNumCases())) {
3182+
auto caseBB = sei->getCase(i);
3183+
auto *trampolineBB = createTrampolineBasicBlock(caseBB.second);
3184+
caseBBs.push_back({caseBB.first, trampolineBB});
3185+
}
3186+
// Create trampoline default basic block.
3187+
SILBasicBlock *newDefaultBB = nullptr;
3188+
if (auto *defaultBB = sei->getDefaultBBOrNull().getPtrOrNull())
3189+
newDefaultBB = createTrampolineBasicBlock(defaultBB);
3190+
3191+
// Create a new `switch_enum` instruction.
3192+
getBuilder().createSwitchEnum(
3193+
sei->getLoc(), getOpValue(sei->getOperand()),
3194+
newDefaultBB, caseBBs);
3195+
}
3196+
31373197
// If an `apply` has active results or active inout parameters, replace it
31383198
// with an `apply` of its VJP.
31393199
void visitApplyInst(ApplyInst *ai) {
@@ -4155,6 +4215,13 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
41554215
auto addActiveValue = [&](SILValue v) {
41564216
if (visited.count(v))
41574217
return;
4218+
// Diagnose active enum values. Differentiation of enum values is not
4219+
// yet supported; requires special adjoint handling.
4220+
if (v->getType().getEnumOrBoundGenericEnum()) {
4221+
getContext().emitNondifferentiabilityError(
4222+
v, getInvoker(), diag::autodiff_enums_unsupported);
4223+
errorOccurred = true;
4224+
}
41584225
// Skip address projections.
41594226
// Address projections do not need their own adjoint buffers; they
41604227
// become projections into their adjoint base buffer.
@@ -4175,8 +4242,12 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
41754242
if (getActivityInfo().isActive(result, getIndices()))
41764243
addActiveValue(result);
41774244
}
4245+
if (errorOccurred)
4246+
break;
41784247
domOrder.pushChildren(bb);
41794248
}
4249+
if (errorOccurred)
4250+
return true;
41804251

41814252
// Create adjoint blocks and arguments, visiting original blocks in
41824253
// post-order.
@@ -4196,7 +4267,6 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
41964267
adjointPullbackStructArguments[origBB] = lastArg;
41974268
continue;
41984269
}
4199-
42004270
// Get all active values in the original block.
42014271
// If the original block has no active values, continue.
42024272
auto &bbActiveValues = activeValues[origBB];
@@ -4421,7 +4491,7 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
44214491
getPullbackInfo().lookUpPredecessorEnumElement(predBB, bb);
44224492
adjointSuccessorCases.push_back({enumEltDecl, adjointSuccBB});
44234493
}
4424-
// Emit clenaups for all block-local adjoint values.
4494+
// Emit cleanups for all block-local adjoint values.
44254495
for (auto adjVal : blockLocalAdjointValues)
44264496
emitCleanupForAdjointValue(adjVal);
44274497
blockLocalAdjointValues.clear();

stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,16 @@ extension Tracked where T : Differentiable, T == T.AllDifferentiableVariables,
171171
}
172172
}
173173

174+
extension Tracked where T : Differentiable & SignedNumeric, T == T.Magnitude,
175+
T == T.AllDifferentiableVariables, T == T.TangentVector {
176+
@usableFromInline
177+
@differentiating(*)
178+
internal static func _vjpMultiply(lhs: Self, rhs: Self)
179+
-> (value: Self, pullback: (Self) -> (Self, Self)) {
180+
return (lhs * rhs, { v in (v * rhs, v * lhs) })
181+
}
182+
}
183+
174184
// Differential operators for `Tracked<Float>`.
175185
public extension Differentiable {
176186
@inlinable

test/AutoDiff/autodiff_diagnostics.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ _ = gradient(at: NoDerivativeProperty(x: 1, y: 1)) {
7575
func uses_optionals(_ x: Float) -> Float {
7676
var maybe: Float? = 10
7777
maybe = x
78-
// expected-note @+1 {{differentiating control flow is not yet supported}}
78+
// expected-note @+1 {{differentiating enum values is not yet supported}}
7979
return maybe!
8080
}
8181

test/AutoDiff/control_flow.swift

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ ControlFlowTests.test("Conditionals") {
258258
}
259259
expectEqual((0, 10), gradient(at: 4, 5, in: guard3))
260260
expectCrash {
261-
gradient(at: -3, -2, in: guard3)
261+
_ = gradient(at: -3, -2, in: guard3)
262262
}
263263

264264
func cond_empty(_ x: Float) -> Float {
@@ -424,4 +424,99 @@ ControlFlowTests.test("Recursion") {
424424
expectEqual(1, gradient(at: 100, in: { x in product(x, count: 1) }))
425425
}
426426

427+
ControlFlowTests.test("Enums") {
428+
enum Enum {
429+
case a(Float)
430+
case b(Float, Float)
431+
432+
func enum_notactive1(_ x: Float) -> Float {
433+
switch self {
434+
case let .a(a): return x * a
435+
case let .b(b1, b2): return x * b1 * b2
436+
}
437+
}
438+
}
439+
440+
func enum_notactive1(_ e: Enum, _ x: Float) -> Float {
441+
switch e {
442+
case let .a(a): return x * a
443+
case let .b(b1, b2): return x * b1 * b2
444+
}
445+
}
446+
expectEqual(10, gradient(at: 2, in: { x in enum_notactive1(.a(10), x) }))
447+
expectEqual(10, gradient(at: 2, in: { x in Enum.a(10).enum_notactive1(x) }))
448+
expectEqual(20, gradient(at: 2, in: { x in enum_notactive1(.b(4, 5), x) }))
449+
expectEqual(20, gradient(at: 2, in: { x in Enum.b(4, 5).enum_notactive1(x) }))
450+
451+
func enum_notactive2(_ e: Enum, _ x: Float) -> Float {
452+
var y = x
453+
if x > 0 {
454+
var z = y + y
455+
switch e {
456+
case .a: z = z - y
457+
case .b: y = y + x
458+
}
459+
var w = y
460+
if case .a = e {
461+
w = w + z
462+
}
463+
return w
464+
} else if case .b = e {
465+
return y + y
466+
}
467+
return x + y
468+
}
469+
expectEqual((8, 2), valueWithGradient(at: 4, in: { x in enum_notactive2(.a(10), x) }))
470+
expectEqual((20, 2), valueWithGradient(at: 10, in: { x in enum_notactive2(.b(4, 5), x) }))
471+
expectEqual((-20, 2), valueWithGradient(at: -10, in: { x in enum_notactive2(.a(10), x) }))
472+
expectEqual((-2674, 2), valueWithGradient(at: -1337, in: { x in enum_notactive2(.b(4, 5), x) }))
473+
474+
func optional_notactive1(_ optional: Float?, _ x: Float) -> Float {
475+
if let y = optional {
476+
return x * y
477+
}
478+
return x + x
479+
}
480+
expectEqual(2, gradient(at: 2, in: { x in optional_notactive1(nil, x) }))
481+
expectEqual(10, gradient(at: 2, in: { x in optional_notactive1(10, x) }))
482+
483+
struct Dense : Differentiable {
484+
var w1: Float
485+
@noDerivative var w2: Float?
486+
487+
@differentiable
488+
func callAsFunction(_ input: Float) -> Float {
489+
if let w2 = w2 {
490+
return input * w1 * w2
491+
}
492+
return input * w1
493+
}
494+
}
495+
expectEqual((Dense.AllDifferentiableVariables(w1: 10), 20),
496+
Dense(w1: 4, w2: 5).gradient(at: 2, in: { dense, x in dense(x) }))
497+
expectEqual((Dense.AllDifferentiableVariables(w1: 2), 4),
498+
Dense(w1: 4, w2: nil).gradient(at: 2, in: { dense, x in dense(x) }))
499+
500+
indirect enum Indirect {
501+
case e(Float, Enum)
502+
case indirect(Indirect)
503+
}
504+
505+
func enum_indirect_notactive1(_ indirect: Indirect, _ x: Float) -> Float {
506+
switch indirect {
507+
case let .e(f, e):
508+
switch e {
509+
case .a: return x * f * enum_notactive1(e, x)
510+
case .b: return x * f * enum_notactive1(e, x)
511+
}
512+
case let .indirect(ind): return enum_indirect_notactive1(ind, x)
513+
}
514+
}
515+
do {
516+
let ind: Indirect = .e(10, .a(3))
517+
expectEqual(120, gradient(at: 2, in: { x in enum_indirect_notactive1(ind, x) }))
518+
expectEqual(120, gradient(at: 2, in: { x in enum_indirect_notactive1(.indirect(ind), x) }))
519+
}
520+
}
521+
427522
runAllTests()

test/AutoDiff/control_flow_diagnostics.swift

Lines changed: 68 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// RUN: %target-swift-frontend -emit-sil -verify %s
22

3-
// Test supported `br` and `cond_br` terminators.
3+
// Test supported `br`, `cond_br`, and `switch_enum` terminators.
44

55
@differentiable
66
func branch(_ x: Float) -> Float {
@@ -12,21 +12,82 @@ func branch(_ x: Float) -> Float {
1212
return x
1313
}
1414

15-
// Test currently unsupported `switch_enum` terminator.
16-
1715
enum Enum {
1816
case a(Float)
1917
case b(Float)
2018
}
2119

20+
@differentiable
21+
func enum_nonactive1(_ e: Enum, _ x: Float) -> Float {
22+
switch e {
23+
case .a: return x
24+
case .b: return x
25+
}
26+
}
27+
28+
@differentiable
29+
func enum_nonactive2(_ e: Enum, _ x: Float) -> Float {
30+
switch e {
31+
case let .a(a): return x + a
32+
case let .b(b): return x + b
33+
}
34+
}
35+
36+
// Test unsupported differentiation of active enum values.
37+
2238
// expected-error @+1 {{function is not differentiable}}
2339
@differentiable
2440
// expected-note @+1 {{when differentiating this function definition}}
25-
func switch_enum(_ e: Enum, _ x: Float) -> Float {
26-
// expected-note @+1 {{differentiating control flow is not yet supported}}
41+
func enum_active(_ x: Float) -> Float {
42+
// expected-note @+1 {{differentiating enum values is not yet supported}}
43+
let e: Enum
44+
if x > 0 {
45+
e = .a(x)
46+
} else {
47+
e = .b(x)
48+
}
2749
switch e {
28-
case let .a(a): return a
29-
case let .b(b): return b
50+
case let .a(a): return x + a
51+
case let .b(b): return x + b
52+
}
53+
}
54+
55+
enum Tree : Differentiable & AdditiveArithmetic {
56+
case leaf(Float)
57+
case branch(Float, Float)
58+
59+
typealias TangentVector = Self
60+
typealias AllDifferentiableVariables = Self
61+
static var zero: Self { .leaf(0) }
62+
63+
// expected-error @+1 {{function is not differentiable}}
64+
@differentiable
65+
// expected-note @+2 {{when differentiating this function definition}}
66+
// expected-note @+1 {{differentiating enum values is not yet supported}}
67+
static func +(_ lhs: Self, _ rhs: Self) -> Self {
68+
switch (lhs, rhs) {
69+
case let (.leaf(x), .leaf(y)):
70+
return .leaf(x + y)
71+
case let (.branch(x1, x2), .branch(y1, y2)):
72+
return .branch(x1 + x2, y1 + y2)
73+
default:
74+
fatalError()
75+
}
76+
}
77+
78+
// expected-error @+1 {{function is not differentiable}}
79+
@differentiable
80+
// expected-note @+2 {{when differentiating this function definition}}
81+
// expected-note @+1 {{differentiating enum values is not yet supported}}
82+
static func -(_ lhs: Self, _ rhs: Self) -> Self {
83+
switch (lhs, rhs) {
84+
case let (.leaf(x), .leaf(y)):
85+
return .leaf(x - y)
86+
case let (.branch(x1, x2), .branch(y1, y2)):
87+
return .branch(x1 - x2, y1 - y2)
88+
default:
89+
fatalError()
90+
}
3091
}
3192
}
3293

0 commit comments

Comments
 (0)