Skip to content

Commit 2f7b42c

Browse files
authored
[AutoDiff] Handle init_enum_data_addr and inject_enum_addr for Optional (#68300)
Optional's `init_enum_data_addr` and `inject_enum_addr` instructions are generated in presence of non-loadable Optional values. The compiler used to treat these instructions as inactive, and this resulted in silent run-time issues described in #64223. The patch marks `init_enum_data_addr` as "active" if its Optional operand is also active, and in PullbackCloner we differentiate through it and the related `inject_enum_addr`. However, we only determine this relation in simple cases when both instructions are in the same block. There is no def-use relation between them (both take the same Optional operand), so if there is more than one set of instructions operating on the same Optional, or there is some control flow, we currently bail out. In PullbackCloner, we walk over instructions in reverse order and start from `inject_enum_addr` and its `Optional<Wrapped>.TangentVector` operand. Assuming that is is already initialized, we emit an `unchecked_take_enum_data_addr` and set it as the adjoint buffer of `init_enum_data_addr`. The Optional value is invalidated, and we have to destroy the enum data address later when we reach `init_enum_data_addr`.
1 parent 2f3e090 commit 2f7b42c

File tree

6 files changed

+307
-15
lines changed

6 files changed

+307
-15
lines changed

lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ void DifferentiableActivityInfo::propagateUsefulThroughAddress(
408408
SKIP_NODERIVATIVE(RefElementAddr)
409409
#undef SKIP_NODERIVATIVE
410410
if (Projection::isAddressProjection(res) || isa<BeginAccessInst>(res) ||
411-
isa<BeginBorrowInst>(res))
411+
isa<BeginBorrowInst>(res) || isa<InitEnumDataAddrInst>(res))
412412
propagateUsefulThroughAddress(res, dependentVariableIndex);
413413
}
414414
}

lib/SILOptimizer/Differentiation/LinearMapInfo.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,19 @@ bool LinearMapInfo::shouldDifferentiateApplySite(FullApplySite applySite) {
453453
return hasActiveResults && hasActiveArguments;
454454
}
455455

456+
static bool shouldDifferentiateInjectEnumAddr(
457+
const InjectEnumAddrInst &inject,
458+
const DifferentiableActivityInfo &activityInfo,
459+
const AutoDiffConfig &config) {
460+
SILValue en = inject.getOperand();
461+
for (auto use : en->getUses()) {
462+
auto *init = dyn_cast<InitEnumDataAddrInst>(use->getUser());
463+
if (init && activityInfo.isActive(init, config))
464+
return true;
465+
}
466+
return false;
467+
}
468+
456469
/// Returns a flag indicating whether the instruction should be differentiated,
457470
/// given the differentiation indices of the instruction's parent function.
458471
/// Whether the instruction should be differentiated is determined sequentially
@@ -506,6 +519,13 @@ bool LinearMapInfo::shouldDifferentiateInstruction(SILInstruction *inst) {
506519
isa<DestroyValueInst>(inst) || isa<DestroyAddrInst>(inst))
507520
return true;
508521
}
522+
523+
// Should differentiate `inject_enum_addr` if the corresponding
524+
// `init_enum_addr` has an active operand.
525+
if (auto inject = dyn_cast<InjectEnumAddrInst>(inst))
526+
if (shouldDifferentiateInjectEnumAddr(*inject, activityInfo, config))
527+
return true;
528+
509529
return false;
510530
}
511531

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 152 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -744,14 +744,23 @@ class PullbackCloner::Implementation final
744744
// Optional differentiation
745745
//--------------------------------------------------------------------------//
746746

747-
/// Given a `wrappedAdjoint` value of type `T.TangentVector`, creates an
748-
/// `Optional<T>.TangentVector` value from it and adds it to the adjoint value
749-
/// of `optionalValue`.
747+
/// Given a `wrappedAdjoint` value of type `T.TangentVector` and `Optional<T>`
748+
/// type, creates an `Optional<T>.TangentVector` buffer from it.
750749
///
751750
/// `wrappedAdjoint` may be an object or address value, both cases are
752751
/// handled.
753-
void accumulateAdjointForOptional(SILBasicBlock *bb, SILValue optionalValue,
754-
SILValue wrappedAdjoint);
752+
AllocStackInst *createOptionalAdjoint(SILBasicBlock *bb,
753+
SILValue wrappedAdjoint,
754+
SILType optionalTy);
755+
756+
/// Accumulate optional buffer from `wrappedAdjoint`.
757+
void accumulateAdjointForOptionalBuffer(SILBasicBlock *bb,
758+
SILValue optionalBuffer,
759+
SILValue wrappedAdjoint);
760+
761+
/// Set optional value from `wrappedAdjoint`.
762+
void setAdjointValueForOptional(SILBasicBlock *bb, SILValue optionalValue,
763+
SILValue wrappedAdjoint);
755764

756765
//--------------------------------------------------------------------------//
757766
// Array literal initialization differentiation
@@ -1687,6 +1696,104 @@ class PullbackCloner::Implementation final
16871696
builder.emitZeroIntoBuffer(uccai->getLoc(), adjDest, IsInitialization);
16881697
}
16891698

1699+
/// Handle a sequence of `init_enum_data_addr` and `inject_enum_addr`
1700+
/// instructions.
1701+
///
1702+
/// Original: y = init_enum_data_addr x
1703+
/// inject_enum_addr y
1704+
///
1705+
/// Adjoint: adj[x] += unchecked_take_enum_data_addr adj[y]
1706+
void visitInjectEnumAddrInst(InjectEnumAddrInst *inject) {
1707+
SILBasicBlock *bb = inject->getParent();
1708+
SILValue origEnum = inject->getOperand();
1709+
1710+
// Only `Optional`-typed operands are supported for now. Diagnose all other
1711+
// enum operand types.
1712+
auto *optionalEnumDecl = getASTContext().getOptionalDecl();
1713+
if (origEnum->getType().getEnumOrBoundGenericEnum() != optionalEnumDecl) {
1714+
LLVM_DEBUG(getADDebugStream()
1715+
<< "Unsupported enum type in PullbackCloner: " << *inject);
1716+
getContext().emitNondifferentiabilityError(
1717+
inject, getInvoker(),
1718+
diag::autodiff_expression_not_differentiable_note);
1719+
errorOccurred = true;
1720+
return;
1721+
}
1722+
1723+
InitEnumDataAddrInst *origData = nullptr;
1724+
for (auto use : origEnum->getUses()) {
1725+
if (auto *init = dyn_cast<InitEnumDataAddrInst>(use->getUser())) {
1726+
// We need a more complicated analysis when init_enum_data_addr and
1727+
// inject_enum_addr are in different blocks, or there is more than one
1728+
// such instruction. Bail out for now.
1729+
if (origData || init->getParent() != bb) {
1730+
LLVM_DEBUG(getADDebugStream()
1731+
<< "Could not find a matching init_enum_data_addr for: "
1732+
<< *inject);
1733+
getContext().emitNondifferentiabilityError(
1734+
inject, getInvoker(),
1735+
diag::autodiff_expression_not_differentiable_note);
1736+
errorOccurred = true;
1737+
return;
1738+
}
1739+
1740+
origData = init;
1741+
}
1742+
}
1743+
1744+
SILValue adjStruct = getAdjointBuffer(bb, origEnum);
1745+
StructDecl *adjStructDecl =
1746+
adjStruct->getType().getStructOrBoundGenericStruct();
1747+
1748+
VarDecl *adjOptVar = nullptr;
1749+
if (adjStructDecl) {
1750+
ArrayRef<VarDecl *> properties = adjStructDecl->getStoredProperties();
1751+
adjOptVar = properties.size() == 1 ? properties[0] : nullptr;
1752+
}
1753+
1754+
EnumDecl *adjOptDecl =
1755+
adjOptVar ? adjOptVar->getTypeInContext()->getEnumOrBoundGenericEnum()
1756+
: nullptr;
1757+
1758+
// Optional<T>.TangentVector should be a struct with a single
1759+
// Optional<T.TangentVector> property. This is an implementation detail of
1760+
// OptionalDifferentiation.swift
1761+
if (!adjOptDecl || adjOptDecl != optionalEnumDecl)
1762+
llvm_unreachable("Unexpected type of Optional.TangentVector");
1763+
1764+
SILLocation loc = origData->getLoc();
1765+
StructElementAddrInst *adjOpt =
1766+
builder.createStructElementAddr(loc, adjStruct, adjOptVar);
1767+
1768+
// unchecked_take_enum_data_addr is destructive, so copy
1769+
// Optional<T.TangentVector> to a new alloca.
1770+
AllocStackInst *adjOptCopy =
1771+
createFunctionLocalAllocation(adjOpt->getType(), loc);
1772+
builder.createCopyAddr(loc, adjOpt, adjOptCopy, IsNotTake,
1773+
IsInitialization);
1774+
1775+
EnumElementDecl *someElemDecl = getASTContext().getOptionalSomeDecl();
1776+
UncheckedTakeEnumDataAddrInst *adjData =
1777+
builder.createUncheckedTakeEnumDataAddr(loc, adjOptCopy, someElemDecl);
1778+
1779+
setAdjointBuffer(bb, origData, adjData);
1780+
1781+
// The Optional copy is invalidated, do not attempt to destroy it at the end
1782+
// of the pullback. The value returned from unchecked_take_enum_data_addr is
1783+
// destroyed in visitInitEnumDataAddrInst.
1784+
destroyedLocalAllocations.insert(adjOptCopy);
1785+
}
1786+
1787+
/// Handle `init_enum_data_addr` instruction.
1788+
/// Destroy the value returned from `unchecked_take_enum_data_addr`.
1789+
void visitInitEnumDataAddrInst(InitEnumDataAddrInst *init) {
1790+
auto bufIt = bufferMap.find({init->getParent(), SILValue(init)});
1791+
if (bufIt == bufferMap.end())
1792+
return;
1793+
SILValue adjData = bufIt->second;
1794+
builder.emitDestroyAddr(init->getLoc(), adjData);
1795+
}
1796+
16901797
/// Handle `unchecked_ref_cast` instruction.
16911798
/// Original: y = unchecked_ref_cast x
16921799
/// Adjoint: adj[x] += adj[y]
@@ -1758,7 +1865,7 @@ class PullbackCloner::Implementation final
17581865
errorOccurred = true;
17591866
return;
17601867
}
1761-
accumulateAdjointForOptional(bb, utedai->getOperand(), adjDest);
1868+
accumulateAdjointForOptionalBuffer(bb, utedai->getOperand(), adjDest);
17621869
builder.emitZeroIntoBuffer(utedai->getLoc(), adjDest, IsNotInitialization);
17631870
}
17641871

@@ -2342,12 +2449,11 @@ void PullbackCloner::Implementation::emitZeroDerivativesForNonvariedResult(
23422449
<< pullback);
23432450
}
23442451

2345-
void PullbackCloner::Implementation::accumulateAdjointForOptional(
2346-
SILBasicBlock *bb, SILValue optionalValue, SILValue wrappedAdjoint) {
2452+
AllocStackInst *PullbackCloner::Implementation::createOptionalAdjoint(
2453+
SILBasicBlock *bb, SILValue wrappedAdjoint, SILType optionalTy) {
23472454
auto pbLoc = getPullback().getLocation();
2348-
// Handle `switch_enum` on `Optional`.
23492455
// `Optional<T>`
2350-
auto optionalTy = remapType(optionalValue->getType());
2456+
optionalTy = remapType(optionalTy);
23512457
assert(optionalTy.getASTType()->isOptional());
23522458
// `T`
23532459
auto wrappedType = optionalTy.getOptionalObjectType();
@@ -2429,13 +2535,45 @@ void PullbackCloner::Implementation::accumulateAdjointForOptional(
24292535
builder.createApply(pbLoc, initFnRef, subMap,
24302536
{optTanAdjBuf, optArgBuf, metatype});
24312537
builder.createDeallocStack(pbLoc, optArgBuf);
2538+
return optTanAdjBuf;
2539+
}
2540+
2541+
// Accumulate adjoint for the incoming `Optional` buffer.
2542+
void PullbackCloner::Implementation::accumulateAdjointForOptionalBuffer(
2543+
SILBasicBlock *bb, SILValue optionalBuffer, SILValue wrappedAdjoint) {
2544+
assert(getTangentValueCategory(optionalBuffer) == SILValueCategory::Address);
2545+
auto pbLoc = getPullback().getLocation();
24322546

2433-
// Accumulate adjoint for the incoming `Optional` value.
2434-
addToAdjointBuffer(bb, optionalValue, optTanAdjBuf, pbLoc);
2547+
// Allocate and initialize Optional<Wrapped>.TangentVector from
2548+
// Wrapped.TangentVector
2549+
AllocStackInst *optTanAdjBuf =
2550+
createOptionalAdjoint(bb, wrappedAdjoint, optionalBuffer->getType());
2551+
2552+
// Accumulate into optionalBuffer
2553+
addToAdjointBuffer(bb, optionalBuffer, optTanAdjBuf, pbLoc);
24352554
builder.emitDestroyAddr(pbLoc, optTanAdjBuf);
24362555
builder.createDeallocStack(pbLoc, optTanAdjBuf);
24372556
}
24382557

2558+
// Set the adjoint value for the incoming `Optional` value.
2559+
void PullbackCloner::Implementation::setAdjointValueForOptional(
2560+
SILBasicBlock *bb, SILValue optionalValue, SILValue wrappedAdjoint) {
2561+
assert(getTangentValueCategory(optionalValue) == SILValueCategory::Object);
2562+
auto pbLoc = getPullback().getLocation();
2563+
2564+
// Allocate and initialize Optional<Wrapped>.TangentVector from
2565+
// Wrapped.TangentVector
2566+
AllocStackInst *optTanAdjBuf =
2567+
createOptionalAdjoint(bb, wrappedAdjoint, optionalValue->getType());
2568+
2569+
auto optTanAdjVal = builder.emitLoadValueOperation(
2570+
pbLoc, optTanAdjBuf, LoadOwnershipQualifier::Take);
2571+
recordTemporary(optTanAdjVal);
2572+
builder.createDeallocStack(pbLoc, optTanAdjBuf);
2573+
2574+
setAdjointValue(bb, optionalValue, makeConcreteAdjointValue(optTanAdjVal));
2575+
}
2576+
24392577
SILBasicBlock *PullbackCloner::Implementation::buildPullbackSuccessor(
24402578
SILBasicBlock *origBB, SILBasicBlock *origPredBB,
24412579
SmallDenseMap<SILValue, TrampolineBlockSet> &pullbackTrampolineBlockMap) {
@@ -2623,7 +2761,7 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) {
26232761
// Handle `switch_enum` on `Optional`.
26242762
auto termInst = bbArg->getSingleTerminator();
26252763
if (isSwitchEnumInstOnOptional(termInst)) {
2626-
accumulateAdjointForOptional(bb, incomingValue, concreteBBArgAdjCopy);
2764+
setAdjointValueForOptional(bb, incomingValue, concreteBBArgAdjCopy);
26272765
} else {
26282766
blockTemporaries[getPullbackBlock(predBB)].insert(
26292767
concreteBBArgAdjCopy);
@@ -2643,7 +2781,7 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) {
26432781
// Handle `switch_enum` on `Optional`.
26442782
auto termInst = bbArg->getSingleTerminator();
26452783
if (isSwitchEnumInstOnOptional(termInst))
2646-
accumulateAdjointForOptional(bb, incomingValue, bbArgAdjBuf);
2784+
accumulateAdjointForOptionalBuffer(bb, incomingValue, bbArgAdjBuf);
26472785
else
26482786
addToAdjointBuffer(bb, incomingValue, bbArgAdjBuf, pbLoc);
26492787
}

stdlib/public/Differentiation/OptionalDifferentiation.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import Swift
1414

1515
extension Optional: Differentiable where Wrapped: Differentiable {
16+
@frozen
1617
public struct TangentVector: Differentiable, AdditiveArithmetic {
1718
public typealias TangentVector = Self
1819

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// RUN: %target-swift-emit-sil %s | %FileCheck %s
2+
3+
import _Differentiation
4+
5+
// CHECK: sil private{{.*}}@$s17optional_pullback23givesWrongTangentVector1xxSgx_t16_Differentiation14DifferentiableRzlFAeFRzlTJpSpSr
6+
// CHECK-SAME: $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable>
7+
// CHECK-SAME: (@in_guaranteed Optional<τ_0_0>.TangentVector) -> @out τ_0_0.TangentVector
8+
//
9+
// CHECK: bb0(%[[RET_TAN:.+]] : $*τ_0_0.TangentVector, %[[OPT_TAN:.+]] : $*Optional<τ_0_0>.TangentVector):
10+
// CHECK: %[[RET_TAN_BUF:.+]] = alloc_stack $τ_0_0.TangentVector
11+
12+
// CHECK: %[[ZERO1:.+]] = witness_method $τ_0_0.TangentVector, #AdditiveArithmetic.zero!getter
13+
// CHECK: apply %[[ZERO1]]<τ_0_0.TangentVector>(%[[RET_TAN_BUF]], %{{.*}})
14+
//
15+
// CHECK: %[[TAN_VAL_COPY:.+]] = alloc_stack $Optional<τ_0_0.TangentVector>
16+
// CHECK: %[[TAN_BUF:.+]] = alloc_stack $Optional<τ_0_0>.TangentVector
17+
18+
// CHECK: copy_addr %[[OPT_TAN]] to [init] %[[TAN_BUF]] : $*Optional<τ_0_0>.TangentVector
19+
// CHECK: %[[TAN_VAL:.+]] = struct_element_addr %[[TAN_BUF]] : $*Optional<τ_0_0>.TangentVector, #Optional.TangentVector.value
20+
// CHECK: copy_addr %[[TAN_VAL]] to [init] %[[TAN_VAL_COPY]] : $*Optional<τ_0_0.TangentVector>
21+
//
22+
// CHECK: %[[TAN_DATA:.+]] = unchecked_take_enum_data_addr %[[TAN_VAL_COPY]] : $*Optional<τ_0_0.TangentVector>, #Optional.some!enumelt
23+
// CHECK: %[[PLUS_EQUAL:.+]] = witness_method $τ_0_0.TangentVector, #AdditiveArithmetic."+="
24+
// CHECK: apply %[[PLUS_EQUAL]]<τ_0_0.TangentVector>(%[[RET_TAN_BUF]], %[[TAN_DATA]], %{{.*}})
25+
//
26+
// CHECK: destroy_addr %[[TAN_DATA]] : $*τ_0_0.TangentVector
27+
// CHECK: %[[ZERO2:.+]] = witness_method $τ_0_0.TangentVector, #AdditiveArithmetic.zero!getter
28+
// CHECK: apply %[[ZERO2]]<τ_0_0.TangentVector>(%[[TAN_DATA]], %{{.*}})
29+
// CHECK: destroy_addr %[[TAN_DATA]] : $*τ_0_0.TangentVector
30+
//
31+
// CHECK: copy_addr [take] %[[RET_TAN_BUF:.+]] to [init] %[[RET_TAN:.+]]
32+
// CHECK: destroy_addr %[[TAN_BUF]] : $*Optional<τ_0_0>.TangentVector
33+
// CHECK: dealloc_stack %[[TAN_BUF]] : $*Optional<τ_0_0>.TangentVector
34+
// CHECK: dealloc_stack %[[TAN_VAL_COPY]] : $*Optional<τ_0_0.TangentVector>
35+
// CHECK: dealloc_stack %[[RET_TAN_BUF]] : $*τ_0_0.TangentVector
36+
37+
@differentiable(reverse)
38+
func givesWrongTangentVector<Element>(x: Element) -> Element? where Element: Differentiable {
39+
return x
40+
}
41+
42+
@differentiable(reverse)
43+
func f(x: Double) -> Double {
44+
let y = givesWrongTangentVector(x: x)
45+
return y!
46+
}
47+
48+
print(valueWithGradient(at: 0.0, of: f))

0 commit comments

Comments
 (0)