@@ -1820,6 +1820,12 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di,
1820
1820
if (isVaried (cai->getSrc (), i))
1821
1821
recursivelySetVaried (cai->getDest (), i);
1822
1822
}
1823
+ // Handle `unconditional_checked_cast_addr`.
1824
+ else if (auto *uccai =
1825
+ dyn_cast<UnconditionalCheckedCastAddrInst>(&inst)) {
1826
+ if (isVaried (uccai->getSrc (), i))
1827
+ recursivelySetVaried (uccai->getDest (), i);
1828
+ }
1823
1829
// Handle `tuple_element_addr`.
1824
1830
else if (auto *teai = dyn_cast<TupleElementAddrInst>(&inst)) {
1825
1831
if (isVaried (teai->getOperand (), i)) {
@@ -1941,6 +1947,12 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di,
1941
1947
if (isUseful (cai->getDest (), i))
1942
1948
propagateUsefulThroughBuffer (cai->getSrc (), i);
1943
1949
}
1950
+ // Handle `unconditional_checked_cast_addr`.
1951
+ else if (auto *uccai =
1952
+ dyn_cast<UnconditionalCheckedCastAddrInst>(&inst)) {
1953
+ if (isUseful (uccai->getDest (), i))
1954
+ propagateUsefulThroughBuffer (uccai->getSrc (), i);
1955
+ }
1944
1956
// Handle everything else.
1945
1957
else if (llvm::any_of (inst.getResults (),
1946
1958
[&](SILValue res) { return isUseful (res, i); })) {
@@ -6601,6 +6613,27 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
6601
6613
}
6602
6614
}
6603
6615
6616
+ // / Handle `unconditional_checked_cast_addr` instruction.
6617
+ // / Original: y = unconditional_checked_cast_addr x
6618
+ // / Adjoint: adj[x] += unconditional_checked_cast_addr adj[y]
6619
+ void visitUnconditionalCheckedCastAddrInst (
6620
+ UnconditionalCheckedCastAddrInst *uccai) {
6621
+ auto *bb = uccai->getParent ();
6622
+ auto &adjDest = getAdjointBuffer (bb, uccai->getDest ());
6623
+ auto &adjSrc = getAdjointBuffer (bb, uccai->getSrc ());
6624
+ if (errorOccurred)
6625
+ return ;
6626
+ auto destType = remapType (adjDest->getType ());
6627
+ auto castBuf = builder.createAllocStack (uccai->getLoc (), adjSrc->getType ());
6628
+ builder.createUnconditionalCheckedCastAddr (
6629
+ uccai->getLoc (), adjDest, adjDest->getType ().getASTType (), castBuf,
6630
+ adjSrc->getType ().getASTType ());
6631
+ addToAdjointBuffer (bb, uccai->getSrc (), castBuf, uccai->getLoc ());
6632
+ builder.emitDestroyAddrAndFold (uccai->getLoc (), castBuf);
6633
+ builder.createDeallocStack (uccai->getLoc (), castBuf);
6634
+ emitZeroIndirect (destType.getASTType (), adjDest, uccai->getLoc ());
6635
+ }
6636
+
6604
6637
#define NOT_DIFFERENTIABLE (INST, DIAG ) \
6605
6638
void visit##INST##Inst(INST##Inst *inst) { \
6606
6639
getContext ().emitNondifferentiabilityError ( \
0 commit comments