-
Notifications
You must be signed in to change notification settings - Fork 10.5k
[AutoDiff] Support differentiation of switch_enum
.
#25509
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1533,6 +1533,15 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di, | |
setVaried(cbi->getFalseBB()->getArgument(opIdx), i); | ||
} | ||
} | ||
// Handle `switch_enum`. | ||
else if (auto *sei = dyn_cast<SwitchEnumInst>(&inst)) { | ||
if (isVaried(sei->getOperand(), i)) { | ||
for (auto *succBB : sei->getSuccessorBlocks()) | ||
for (auto *arg : succBB->getArguments()) | ||
setVaried(arg, i); | ||
// Default block cannot have arguments. | ||
} | ||
} | ||
// Handle everything else. | ||
else { | ||
for (auto &op : inst.getAllOperands()) | ||
|
@@ -1767,8 +1776,9 @@ static bool diagnoseUnsupportedControlFlow(ADContext &context, | |
// Diagnose unsupported branching terminators. | ||
for (auto &bb : *original) { | ||
auto *term = bb.getTerminator(); | ||
// Supported terminators are: `br`, `cond_br`. | ||
if (isa<BranchInst>(term) || isa<CondBranchInst>(term)) | ||
// Supported terminators are: `br`, `cond_br`, `switch_enum`. | ||
if (isa<BranchInst>(term) || isa<CondBranchInst>(term) || | ||
isa<SwitchEnumInst>(term)) | ||
continue; | ||
// If terminator is an unsupported branching terminator, emit an error. | ||
if (term->isBranch()) { | ||
|
@@ -3134,6 +3144,56 @@ class VJPEmitter final | |
getOpBasicBlock(cbi->getFalseBB()), falseArgs); | ||
} | ||
|
||
void visitSwitchEnumInst(SwitchEnumInst *sei) { | ||
// Build pullback struct value for original block. | ||
auto *origBB = sei->getParent(); | ||
auto *pbStructVal = buildPullbackValueStructValue(sei); | ||
|
||
// Creates a trampoline block for given original successor block. The | ||
// trampoline block has the same arguments as the VJP successor block but | ||
// drops the last predecessor enum argument. The generated `switch_enum` | ||
// instruction branches to the trampoline block, and the trampoline block | ||
// constructs a predecessor enum value and branches to the VJP successor | ||
// block. | ||
auto createTrampolineBasicBlock = | ||
[&](SILBasicBlock *origSuccBB) -> SILBasicBlock * { | ||
auto *vjpSuccBB = getOpBasicBlock(origSuccBB); | ||
// Create the trampoline block. | ||
auto *trampolineBB = vjp->createBasicBlockBefore(vjpSuccBB); | ||
for (auto *arg : vjpSuccBB->getArguments().drop_back()) | ||
trampolineBB->createPhiArgument(arg->getType(), | ||
arg->getOwnershipKind()); | ||
// Build predecessor enum value for successor block and branch to it. | ||
SILBuilder trampolineBuilder(trampolineBB); | ||
auto *succEnumVal = buildPredecessorEnumValue( | ||
trampolineBuilder, origBB, origSuccBB, pbStructVal); | ||
SmallVector<SILValue, 4> forwardedArguments( | ||
trampolineBB->getArguments().begin(), | ||
trampolineBB->getArguments().end()); | ||
forwardedArguments.push_back(succEnumVal); | ||
trampolineBuilder.createBranch( | ||
sei->getLoc(), vjpSuccBB, forwardedArguments); | ||
return trampolineBB; | ||
}; | ||
|
||
// Create trampoline successor basic blocks. | ||
SmallVector<std::pair<EnumElementDecl *, SILBasicBlock *>, 4> caseBBs; | ||
for (unsigned i : range(sei->getNumCases())) { | ||
auto caseBB = sei->getCase(i); | ||
auto *trampolineBB = createTrampolineBasicBlock(caseBB.second); | ||
caseBBs.push_back({caseBB.first, trampolineBB}); | ||
} | ||
// Create trampoline default basic block. | ||
SILBasicBlock *newDefaultBB = nullptr; | ||
if (auto *defaultBB = sei->getDefaultBBOrNull().getPtrOrNull()) | ||
newDefaultBB = createTrampolineBasicBlock(defaultBB); | ||
|
||
// Create a new `switch_enum` instruction. | ||
getBuilder().createSwitchEnum( | ||
sei->getLoc(), getOpValue(sei->getOperand()), | ||
newDefaultBB, caseBBs); | ||
} | ||
|
||
// If an `apply` has active results or active inout parameters, replace it | ||
// with an `apply` of its VJP. | ||
void visitApplyInst(ApplyInst *ai) { | ||
|
@@ -4155,6 +4215,13 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> { | |
auto addActiveValue = [&](SILValue v) { | ||
if (visited.count(v)) | ||
return; | ||
// Diagnose active enum values. Differentiation of enum values is not | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Diagnosing active enum values is necessary because adjoint generation doesn't propagate adjoint values of enum associated values correctly. Support is non-trivial because |
||
// yet supported; requires special adjoint handling. | ||
if (v->getType().getEnumOrBoundGenericEnum()) { | ||
getContext().emitNondifferentiabilityError( | ||
v, getInvoker(), diag::autodiff_enums_unsupported); | ||
errorOccurred = true; | ||
} | ||
// Skip address projections. | ||
// Address projections do not need their own adjoint buffers; they | ||
// become projections into their adjoint base buffer. | ||
|
@@ -4175,8 +4242,12 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> { | |
if (getActivityInfo().isActive(result, getIndices())) | ||
addActiveValue(result); | ||
} | ||
if (errorOccurred) | ||
break; | ||
domOrder.pushChildren(bb); | ||
} | ||
if (errorOccurred) | ||
return true; | ||
|
||
// Create adjoint blocks and arguments, visiting original blocks in | ||
// post-order. | ||
|
@@ -4196,7 +4267,10 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> { | |
adjointPullbackStructArguments[origBB] = lastArg; | ||
continue; | ||
} | ||
|
||
// Add a pullback struct argument. | ||
auto *pbStructArg = adjointBB->createPhiArgument( | ||
dan-zheng marked this conversation as resolved.
Show resolved
Hide resolved
|
||
pbStructLoweredType, ValueOwnershipKind::Guaranteed); | ||
adjointPullbackStructArguments[origBB] = pbStructArg; | ||
// Get all active values in the original block. | ||
// If the original block has no active values, continue. | ||
auto &bbActiveValues = activeValues[origBB]; | ||
|
@@ -4222,10 +4296,6 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> { | |
activeValueAdjointBBArgumentMap[{origBB, activeValue}] = adjointArg; | ||
} | ||
} | ||
// Add a pullback struct argument. | ||
auto *pbStructArg = adjointBB->createPhiArgument( | ||
pbStructLoweredType, ValueOwnershipKind::Guaranteed); | ||
adjointPullbackStructArguments[origBB] = pbStructArg; | ||
// - Create adjoint trampoline blocks for each successor block of the | ||
// original block. Adjoint trampoline blocks only have a pullback | ||
// struct argument, and branch from the adjoint successor block to the | ||
|
@@ -4373,6 +4443,8 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> { | |
assert(adjointSuccBB && adjointSuccBB->getNumArguments() == 1); | ||
SILBuilder adjointTrampolineBBBuilder(adjointSuccBB); | ||
SmallVector<SILValue, 8> trampolineArguments; | ||
// Propagate pullback struct argument. | ||
trampolineArguments.push_back(adjointSuccBB->getArguments().front()); | ||
dan-zheng marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// Propagate adjoint values/buffers of active values/buffers to | ||
// predecessor blocks. | ||
auto &predBBActiveValues = activeValues[predBB]; | ||
|
@@ -4411,8 +4483,6 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> { | |
adjLoc, adjBuf, predAdjBuf, IsNotTake, IsNotInitialization); | ||
} | ||
} | ||
// Propagate pullback struct argument. | ||
trampolineArguments.push_back(adjointSuccBB->getArguments().front()); | ||
// Branch from adjoint trampoline block to adjoint block. | ||
adjointTrampolineBBBuilder.createBranch( | ||
adjLoc, adjointBB, trampolineArguments); | ||
|
@@ -4421,7 +4491,7 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> { | |
getPullbackInfo().lookUpPredecessorEnumElement(predBB, bb); | ||
adjointSuccessorCases.push_back({enumEltDecl, adjointSuccBB}); | ||
} | ||
// Emit clenaups for all block-local adjoint values. | ||
// Emit cleanups for all block-local adjoint values. | ||
for (auto adjVal : blockLocalAdjointValues) | ||
emitCleanupForAdjointValue(adjVal); | ||
blockLocalAdjointValues.clear(); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems necessary to also handle
switch_enum_addr
, but I didn't find Swift functions containingswitch_enum_addr
that also don't contain active enum values. A quick searches shows thatswitch_enum_addr
is generated during later SIL passes and also for optional force-unwrapping.