Skip to content

Commit 8bcee7b

Browse files
authored
[NFC] [AutoDiff] Gardening. (#25253)
1 parent 18d73fa commit 8bcee7b

File tree

1 file changed

+29
-33
lines changed

1 file changed

+29
-33
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1741,6 +1741,7 @@ static void dumpActivityInfo(SILFunction &fn,
17411741
llvm::raw_ostream &s = llvm::dbgs()) {
17421742
s << "Activity info for " << fn.getName() << " at " << indices << '\n';
17431743
for (auto &bb : fn) {
1744+
s << "bb" << bb.getDebugID() << ":\n";
17441745
for (auto *arg : bb.getArguments())
17451746
dumpActivityInfo(arg, indices, activityInfo, s);
17461747
for (auto &inst : bb)
@@ -3192,7 +3193,7 @@ class VJPEmitter final
31923193
// Do standard cloning.
31933194
if (!hasActiveResults || !hasActiveArguments) {
31943195
LLVM_DEBUG(getADDebugStream() << "No active results:\n" << *ai << '\n');
3195-
SILClonerWithScopes::visitApplyInst(ai);
3196+
TypeSubstCloner::visitApplyInst(ai);
31963197
return;
31973198
}
31983199

@@ -3893,11 +3894,10 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
38933894
LLVM_DEBUG(getADDebugStream() << "Adding adjoint for " << originalValue);
38943895
#ifndef NDEBUG
38953896
auto origTy = remapType(originalValue->getType()).getASTType();
3896-
auto tanSpace = origTy->getAutoDiffAssociatedTangentSpace(
3897-
LookUpConformanceInModule(getModule().getSwiftModule()));
3897+
auto tangentSpace = getTangentSpace(origTy);
38983898
// The adjoint value must be in the tangent space.
3899-
assert(tanSpace && newAdjointValue.getType().getASTType()->isEqual(
3900-
tanSpace->getCanonicalType()));
3899+
assert(tangentSpace && newAdjointValue.getType().getASTType()->isEqual(
3900+
tangentSpace->getCanonicalType()));
39013901
#endif
39023902
auto insertion =
39033903
valueMap.try_emplace({origBB, originalValue}, newAdjointValue);
@@ -4662,11 +4662,11 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
46624662
}
46634663

46644664
/// Handle `struct` instruction.
4665-
/// y = struct (x0, x1, x2, ...)
4666-
/// adj[x0] += struct_extract adj[y], #x0
4667-
/// adj[x1] += struct_extract adj[y], #x1
4668-
/// adj[x2] += struct_extract adj[y], #x2
4669-
/// ...
4665+
/// Original: y = struct (x0, x1, x2, ...)
4666+
/// Adjoint: adj[x0] += struct_extract adj[y], #x0
4667+
/// adj[x1] += struct_extract adj[y], #x1
4668+
/// adj[x2] += struct_extract adj[y], #x2
4669+
/// ...
46704670
void visitStructInst(StructInst *si) {
46714671
auto *bb = si->getParent();
46724672
auto loc = si->getLoc();
@@ -4684,9 +4684,8 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
46844684
auto adjStruct = materializeAdjointDirect(std::move(av), loc);
46854685
// Find the struct `TangentVector` type.
46864686
auto structTy = remapType(si->getType()).getASTType();
4687-
auto tangentVectorTy = structTy->getAutoDiffAssociatedTangentSpace(
4688-
LookUpConformanceInModule(getModule().getSwiftModule()))
4689-
->getType()->getCanonicalType();
4687+
auto tangentVectorTy =
4688+
getTangentSpace(structTy)->getType()->getCanonicalType();
46904689
assert(!getModule().Types.getTypeLowering(
46914690
tangentVectorTy, ResilienceExpansion::Minimal)
46924691
.isAddressOnly());
@@ -4737,20 +4736,19 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
47374736
}
47384737
}
47394738

4739+
/// Handle `struct_extract` instruction.
4740+
/// Original: y = struct_extract x, #field
4741+
/// Adjoint: adj[x] += struct (0, ..., #field': adj[y], ..., 0)
4742+
/// ^~~~~~~
4743+
/// field in tangent space corresponding to #field
47404744
void visitStructExtractInst(StructExtractInst *sei) {
47414745
assert(!sei->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() &&
47424746
"`struct_extract` with `@noDerivative` field should not be "
47434747
"differentiated; activity analysis should not marked as varied");
4744-
// Compute adjoint as follows:
4745-
// y = struct_extract x, #key
4746-
// adj[x] += struct (0, ..., #key': adj[y], ..., 0)
4747-
// where `#key'` is the field in the tangent space corresponding to
4748-
// `#key`.
47494748
auto *bb = sei->getParent();
47504749
auto structTy = remapType(sei->getOperand()->getType()).getASTType();
4751-
auto tangentVectorTy = structTy->getAutoDiffAssociatedTangentSpace(
4752-
LookUpConformanceInModule(getModule().getSwiftModule()))
4753-
->getType()->getCanonicalType();
4750+
auto tangentVectorTy =
4751+
getTangentSpace(structTy)->getType()->getCanonicalType();
47544752
assert(!getModule().Types.getTypeLowering(
47554753
tangentVectorTy, ResilienceExpansion::Minimal)
47564754
.isAddressOnly());
@@ -4810,9 +4808,9 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
48104808
}
48114809

48124810
/// Handle `tuple` instruction.
4813-
/// y = tuple (x0, x1, x2, ...)
4814-
/// adj[x0] += tuple_extract adj[y], 0
4815-
/// ...
4811+
/// Original: y = tuple (x0, x1, x2, ...)
4812+
/// Adjoint: adj[x0] += tuple_extract adj[y], 0
4813+
/// ...
48164814
void visitTupleInst(TupleInst *ti) {
48174815
auto *bb = ti->getParent();
48184816
auto av = getAdjointValue(bb, ti);
@@ -4851,9 +4849,11 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
48514849
}
48524850

48534851
/// Handle `tuple_extract` instruction.
4854-
/// y = tuple_extract x, <n>
4855-
/// |--- n-th element
4856-
/// adj[x] += tuple (0, 0, ..., adj[y], ..., 0, 0)
4852+
/// Original: y = tuple_extract x, <n>
4853+
/// Adjoint: adj[x] += tuple (0, 0, ..., adj[y], ..., 0, 0)
4854+
/// ^~~~~~
4855+
/// n'-th element, where n' is tuple tangent space
4856+
/// index corresponding to n
48574857
void visitTupleExtractInst(TupleExtractInst *tei) {
48584858
auto *bb = tei->getParent();
48594859
auto tupleTanTy = getRemappedTangentType(tei->getOperand()->getType());
@@ -5234,9 +5234,7 @@ void AdjointEmitter::materializeAdjointIndirectHelper(
52345234

52355235
void AdjointEmitter::emitZeroIndirect(CanType type, SILValue bufferAccess,
52365236
SILLocation loc) {
5237-
auto *swiftMod = getModule().getSwiftModule();
5238-
auto tangentSpace = type->getAutoDiffAssociatedTangentSpace(
5239-
LookUpConformanceInModule(swiftMod));
5237+
auto tangentSpace = getTangentSpace(type);
52405238
assert(tangentSpace && "No tangent space for this type");
52415239
switch (tangentSpace->getKind()) {
52425240
case VectorSpace::Kind::Vector:
@@ -5371,9 +5369,7 @@ SILValue AdjointEmitter::accumulateDirect(SILValue lhs, SILValue rhs) {
53715369
auto adjointTy = lhs->getType();
53725370
auto adjointASTTy = adjointTy.getASTType();
53735371
auto loc = lhs.getLoc();
5374-
auto *swiftMod = getModule().getSwiftModule();
5375-
auto tangentSpace = adjointASTTy->getAutoDiffAssociatedTangentSpace(
5376-
LookUpConformanceInModule(swiftMod));
5372+
auto tangentSpace = getTangentSpace(adjointASTTy);
53775373
assert(tangentSpace && "No tangent space for this type");
53785374
switch (tangentSpace->getKind()) {
53795375
case VectorSpace::Kind::Vector: {

0 commit comments

Comments
 (0)