@@ -744,14 +744,23 @@ class PullbackCloner::Implementation final
744
744
// Optional differentiation
745
745
// --------------------------------------------------------------------------//
746
746
747
- // / Given a `wrappedAdjoint` value of type `T.TangentVector`, creates an
748
- // / `Optional<T>.TangentVector` value from it and adds it to the adjoint value
749
- // / of `optionalValue`.
747
+ // / Given a `wrappedAdjoint` value of type `T.TangentVector` and `Optional<T>`
748
+ // / type, creates an `Optional<T>.TangentVector` buffer from it.
750
749
// /
751
750
// / `wrappedAdjoint` may be an object or address value, both cases are
752
751
// / handled.
753
- void accumulateAdjointForOptional (SILBasicBlock *bb, SILValue optionalValue,
754
- SILValue wrappedAdjoint);
752
+ AllocStackInst *createOptionalAdjoint (SILBasicBlock *bb,
753
+ SILValue wrappedAdjoint,
754
+ SILType optionalTy);
755
+
756
+ // / Accumulate optional buffer from `wrappedAdjoint`.
757
+ void accumulateAdjointForOptionalBuffer (SILBasicBlock *bb,
758
+ SILValue optionalBuffer,
759
+ SILValue wrappedAdjoint);
760
+
761
+ // / Set optional value from `wrappedAdjoint`.
762
+ void setAdjointValueForOptional (SILBasicBlock *bb, SILValue optionalValue,
763
+ SILValue wrappedAdjoint);
755
764
756
765
// --------------------------------------------------------------------------//
757
766
// Array literal initialization differentiation
@@ -1687,6 +1696,104 @@ class PullbackCloner::Implementation final
1687
1696
builder.emitZeroIntoBuffer (uccai->getLoc (), adjDest, IsInitialization);
1688
1697
}
1689
1698
1699
+ // / Handle a sequence of `init_enum_data_addr` and `inject_enum_addr`
1700
+ // / instructions.
1701
+ // /
1702
+ // / Original: y = init_enum_data_addr x
1703
+ // / inject_enum_addr y
1704
+ // /
1705
+ // / Adjoint: adj[x] += unchecked_take_enum_data_addr adj[y]
1706
+ void visitInjectEnumAddrInst (InjectEnumAddrInst *inject) {
1707
+ SILBasicBlock *bb = inject->getParent ();
1708
+ SILValue origEnum = inject->getOperand ();
1709
+
1710
+ // Only `Optional`-typed operands are supported for now. Diagnose all other
1711
+ // enum operand types.
1712
+ auto *optionalEnumDecl = getASTContext ().getOptionalDecl ();
1713
+ if (origEnum->getType ().getEnumOrBoundGenericEnum () != optionalEnumDecl) {
1714
+ LLVM_DEBUG (getADDebugStream ()
1715
+ << " Unsupported enum type in PullbackCloner: " << *inject);
1716
+ getContext ().emitNondifferentiabilityError (
1717
+ inject, getInvoker (),
1718
+ diag::autodiff_expression_not_differentiable_note);
1719
+ errorOccurred = true ;
1720
+ return ;
1721
+ }
1722
+
1723
+ InitEnumDataAddrInst *origData = nullptr ;
1724
+ for (auto use : origEnum->getUses ()) {
1725
+ if (auto *init = dyn_cast<InitEnumDataAddrInst>(use->getUser ())) {
1726
+ // We need a more complicated analysis when init_enum_data_addr and
1727
+ // inject_enum_addr are in different blocks, or there is more than one
1728
+ // such instruction. Bail out for now.
1729
+ if (origData || init->getParent () != bb) {
1730
+ LLVM_DEBUG (getADDebugStream ()
1731
+ << " Could not find a matching init_enum_data_addr for: "
1732
+ << *inject);
1733
+ getContext ().emitNondifferentiabilityError (
1734
+ inject, getInvoker (),
1735
+ diag::autodiff_expression_not_differentiable_note);
1736
+ errorOccurred = true ;
1737
+ return ;
1738
+ }
1739
+
1740
+ origData = init;
1741
+ }
1742
+ }
1743
+
1744
+ SILValue adjStruct = getAdjointBuffer (bb, origEnum);
1745
+ StructDecl *adjStructDecl =
1746
+ adjStruct->getType ().getStructOrBoundGenericStruct ();
1747
+
1748
+ VarDecl *adjOptVar = nullptr ;
1749
+ if (adjStructDecl) {
1750
+ ArrayRef<VarDecl *> properties = adjStructDecl->getStoredProperties ();
1751
+ adjOptVar = properties.size () == 1 ? properties[0 ] : nullptr ;
1752
+ }
1753
+
1754
+ EnumDecl *adjOptDecl =
1755
+ adjOptVar ? adjOptVar->getTypeInContext ()->getEnumOrBoundGenericEnum ()
1756
+ : nullptr ;
1757
+
1758
+ // Optional<T>.TangentVector should be a struct with a single
1759
+ // Optional<T.TangentVector> property. This is an implementation detail of
1760
+ // OptionalDifferentiation.swift
1761
+ if (!adjOptDecl || adjOptDecl != optionalEnumDecl)
1762
+ llvm_unreachable (" Unexpected type of Optional.TangentVector" );
1763
+
1764
+ SILLocation loc = origData->getLoc ();
1765
+ StructElementAddrInst *adjOpt =
1766
+ builder.createStructElementAddr (loc, adjStruct, adjOptVar);
1767
+
1768
+ // unchecked_take_enum_data_addr is destructive, so copy
1769
+ // Optional<T.TangentVector> to a new alloca.
1770
+ AllocStackInst *adjOptCopy =
1771
+ createFunctionLocalAllocation (adjOpt->getType (), loc);
1772
+ builder.createCopyAddr (loc, adjOpt, adjOptCopy, IsNotTake,
1773
+ IsInitialization);
1774
+
1775
+ EnumElementDecl *someElemDecl = getASTContext ().getOptionalSomeDecl ();
1776
+ UncheckedTakeEnumDataAddrInst *adjData =
1777
+ builder.createUncheckedTakeEnumDataAddr (loc, adjOptCopy, someElemDecl);
1778
+
1779
+ setAdjointBuffer (bb, origData, adjData);
1780
+
1781
+ // The Optional copy is invalidated, do not attempt to destroy it at the end
1782
+ // of the pullback. The value returned from unchecked_take_enum_data_addr is
1783
+ // destroyed in visitInitEnumDataAddrInst.
1784
+ destroyedLocalAllocations.insert (adjOptCopy);
1785
+ }
1786
+
1787
+ // / Handle `init_enum_data_addr` instruction.
1788
+ // / Destroy the value returned from `unchecked_take_enum_data_addr`.
1789
+ void visitInitEnumDataAddrInst (InitEnumDataAddrInst *init) {
1790
+ auto bufIt = bufferMap.find ({init->getParent (), SILValue (init)});
1791
+ if (bufIt == bufferMap.end ())
1792
+ return ;
1793
+ SILValue adjData = bufIt->second ;
1794
+ builder.emitDestroyAddr (init->getLoc (), adjData);
1795
+ }
1796
+
1690
1797
// / Handle `unchecked_ref_cast` instruction.
1691
1798
// / Original: y = unchecked_ref_cast x
1692
1799
// / Adjoint: adj[x] += adj[y]
@@ -1758,7 +1865,7 @@ class PullbackCloner::Implementation final
1758
1865
errorOccurred = true ;
1759
1866
return ;
1760
1867
}
1761
- accumulateAdjointForOptional (bb, utedai->getOperand (), adjDest);
1868
+ accumulateAdjointForOptionalBuffer (bb, utedai->getOperand (), adjDest);
1762
1869
builder.emitZeroIntoBuffer (utedai->getLoc (), adjDest, IsNotInitialization);
1763
1870
}
1764
1871
@@ -2342,12 +2449,11 @@ void PullbackCloner::Implementation::emitZeroDerivativesForNonvariedResult(
2342
2449
<< pullback);
2343
2450
}
2344
2451
2345
- void PullbackCloner::Implementation::accumulateAdjointForOptional (
2346
- SILBasicBlock *bb, SILValue optionalValue, SILValue wrappedAdjoint ) {
2452
+ AllocStackInst * PullbackCloner::Implementation::createOptionalAdjoint (
2453
+ SILBasicBlock *bb, SILValue wrappedAdjoint, SILType optionalTy ) {
2347
2454
auto pbLoc = getPullback ().getLocation ();
2348
- // Handle `switch_enum` on `Optional`.
2349
2455
// `Optional<T>`
2350
- auto optionalTy = remapType (optionalValue-> getType () );
2456
+ optionalTy = remapType (optionalTy );
2351
2457
assert (optionalTy.getASTType ()->isOptional ());
2352
2458
// `T`
2353
2459
auto wrappedType = optionalTy.getOptionalObjectType ();
@@ -2429,13 +2535,45 @@ void PullbackCloner::Implementation::accumulateAdjointForOptional(
2429
2535
builder.createApply (pbLoc, initFnRef, subMap,
2430
2536
{optTanAdjBuf, optArgBuf, metatype});
2431
2537
builder.createDeallocStack (pbLoc, optArgBuf);
2538
+ return optTanAdjBuf;
2539
+ }
2540
+
2541
+ // Accumulate adjoint for the incoming `Optional` buffer.
2542
+ void PullbackCloner::Implementation::accumulateAdjointForOptionalBuffer (
2543
+ SILBasicBlock *bb, SILValue optionalBuffer, SILValue wrappedAdjoint) {
2544
+ assert (getTangentValueCategory (optionalBuffer) == SILValueCategory::Address);
2545
+ auto pbLoc = getPullback ().getLocation ();
2432
2546
2433
- // Accumulate adjoint for the incoming `Optional` value.
2434
- addToAdjointBuffer (bb, optionalValue, optTanAdjBuf, pbLoc);
2547
+ // Allocate and initialize Optional<Wrapped>.TangentVector from
2548
+ // Wrapped.TangentVector
2549
+ AllocStackInst *optTanAdjBuf =
2550
+ createOptionalAdjoint (bb, wrappedAdjoint, optionalBuffer->getType ());
2551
+
2552
+ // Accumulate into optionalBuffer
2553
+ addToAdjointBuffer (bb, optionalBuffer, optTanAdjBuf, pbLoc);
2435
2554
builder.emitDestroyAddr (pbLoc, optTanAdjBuf);
2436
2555
builder.createDeallocStack (pbLoc, optTanAdjBuf);
2437
2556
}
2438
2557
2558
+ // Set the adjoint value for the incoming `Optional` value.
2559
+ void PullbackCloner::Implementation::setAdjointValueForOptional (
2560
+ SILBasicBlock *bb, SILValue optionalValue, SILValue wrappedAdjoint) {
2561
+ assert (getTangentValueCategory (optionalValue) == SILValueCategory::Object);
2562
+ auto pbLoc = getPullback ().getLocation ();
2563
+
2564
+ // Allocate and initialize Optional<Wrapped>.TangentVector from
2565
+ // Wrapped.TangentVector
2566
+ AllocStackInst *optTanAdjBuf =
2567
+ createOptionalAdjoint (bb, wrappedAdjoint, optionalValue->getType ());
2568
+
2569
+ auto optTanAdjVal = builder.emitLoadValueOperation (
2570
+ pbLoc, optTanAdjBuf, LoadOwnershipQualifier::Take);
2571
+ recordTemporary (optTanAdjVal);
2572
+ builder.createDeallocStack (pbLoc, optTanAdjBuf);
2573
+
2574
+ setAdjointValue (bb, optionalValue, makeConcreteAdjointValue (optTanAdjVal));
2575
+ }
2576
+
2439
2577
SILBasicBlock *PullbackCloner::Implementation::buildPullbackSuccessor (
2440
2578
SILBasicBlock *origBB, SILBasicBlock *origPredBB,
2441
2579
SmallDenseMap<SILValue, TrampolineBlockSet> &pullbackTrampolineBlockMap) {
@@ -2623,7 +2761,7 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) {
2623
2761
// Handle `switch_enum` on `Optional`.
2624
2762
auto termInst = bbArg->getSingleTerminator ();
2625
2763
if (isSwitchEnumInstOnOptional (termInst)) {
2626
- accumulateAdjointForOptional (bb, incomingValue, concreteBBArgAdjCopy);
2764
+ setAdjointValueForOptional (bb, incomingValue, concreteBBArgAdjCopy);
2627
2765
} else {
2628
2766
blockTemporaries[getPullbackBlock (predBB)].insert (
2629
2767
concreteBBArgAdjCopy);
@@ -2643,7 +2781,7 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) {
2643
2781
// Handle `switch_enum` on `Optional`.
2644
2782
auto termInst = bbArg->getSingleTerminator ();
2645
2783
if (isSwitchEnumInstOnOptional (termInst))
2646
- accumulateAdjointForOptional (bb, incomingValue, bbArgAdjBuf);
2784
+ accumulateAdjointForOptionalBuffer (bb, incomingValue, bbArgAdjBuf);
2647
2785
else
2648
2786
addToAdjointBuffer (bb, incomingValue, bbArgAdjBuf, pbLoc);
2649
2787
}
0 commit comments