Skip to content

Commit 1c1305c

Browse files
authored
[NFC] [AutoDiff] Gardening. (#25264)
1 parent cec88da commit 1c1305c

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1551,7 +1551,6 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di,
15511551
setVaried(cbi->getFalseBB()->getArgument(opIdx), i);
15521552
}
15531553
}
1554-
15551554
// Handle everything else.
15561555
else {
15571556
for (auto &op : inst.getAllOperands())
@@ -3863,6 +3862,12 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
38633862
void initializeAdjointValue(SILBasicBlock *origBB, SILValue originalValue,
38643863
AdjointValue adjointValue) {
38653864
assert(origBB->getParent() == &getOriginal());
3865+
assert(originalValue->getType().isObject());
3866+
assert(adjointValue.getType().isObject());
3867+
assert(originalValue->getFunction() == &getOriginal());
3868+
// The adjoint value must be in the tangent space.
3869+
assert(adjointValue.getType() ==
3870+
getRemappedTangentType(originalValue->getType()));
38663871
auto insertion =
38673872
valueMap.try_emplace({origBB, originalValue}, adjointValue);
38683873
assert(insertion.second && "Adjoint value inserted before");
@@ -3892,13 +3897,9 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
38923897
assert(newAdjointValue.getType().isObject());
38933898
assert(originalValue->getFunction() == &getOriginal());
38943899
LLVM_DEBUG(getADDebugStream() << "Adding adjoint for " << originalValue);
3895-
#ifndef NDEBUG
3896-
auto origTy = remapType(originalValue->getType()).getASTType();
3897-
auto tangentSpace = getTangentSpace(origTy);
38983900
// The adjoint value must be in the tangent space.
3899-
assert(tangentSpace && newAdjointValue.getType().getASTType()->isEqual(
3900-
tangentSpace->getCanonicalType()));
3901-
#endif
3901+
assert(newAdjointValue.getType() ==
3902+
getRemappedTangentType(originalValue->getType()));
39023903
auto insertion =
39033904
valueMap.try_emplace({origBB, originalValue}, newAdjointValue);
39043905
auto inserted = insertion.second;
@@ -4493,7 +4494,8 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
44934494
}
44944495
builder.createReturn(adjLoc, joinElements(retElts, builder, adjLoc));
44954496

4496-
LLVM_DEBUG(getADDebugStream() << "Generated adjoint:\n" << adjoint);
4497+
LLVM_DEBUG(getADDebugStream() << "Generated adjoint for "
4498+
<< original.getName() << ":\n" << adjoint);
44974499
return errorOccurred;
44984500
}
44994501

@@ -5572,7 +5574,7 @@ bool VJPEmitter::run() {
55725574
errorOccurred = true;
55735575
return true;
55745576
}
5575-
LLVM_DEBUG(getADDebugStream() << "Finished VJPGen for function "
5577+
LLVM_DEBUG(getADDebugStream() << "Generated VJP for "
55765578
<< original->getName() << ":\n" << *vjp);
55775579
return errorOccurred;
55785580
}

0 commit comments

Comments
 (0)