@@ -1551,7 +1551,6 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di,
1551
1551
setVaried (cbi->getFalseBB ()->getArgument (opIdx), i);
1552
1552
}
1553
1553
}
1554
-
1555
1554
// Handle everything else.
1556
1555
else {
1557
1556
for (auto &op : inst.getAllOperands ())
@@ -3863,6 +3862,12 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
3863
3862
void initializeAdjointValue (SILBasicBlock *origBB, SILValue originalValue,
3864
3863
AdjointValue adjointValue) {
3865
3864
assert (origBB->getParent () == &getOriginal ());
3865
+ assert (originalValue->getType ().isObject ());
3866
+ assert (adjointValue.getType ().isObject ());
3867
+ assert (originalValue->getFunction () == &getOriginal ());
3868
+ // The adjoint value must be in the tangent space.
3869
+ assert (adjointValue.getType () ==
3870
+ getRemappedTangentType (originalValue->getType ()));
3866
3871
auto insertion =
3867
3872
valueMap.try_emplace ({origBB, originalValue}, adjointValue);
3868
3873
assert (insertion.second && " Adjoint value inserted before" );
@@ -3892,13 +3897,9 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
3892
3897
assert (newAdjointValue.getType ().isObject ());
3893
3898
assert (originalValue->getFunction () == &getOriginal ());
3894
3899
LLVM_DEBUG (getADDebugStream () << " Adding adjoint for " << originalValue);
3895
- #ifndef NDEBUG
3896
- auto origTy = remapType (originalValue->getType ()).getASTType ();
3897
- auto tangentSpace = getTangentSpace (origTy);
3898
3900
// The adjoint value must be in the tangent space.
3899
- assert (tangentSpace && newAdjointValue.getType ().getASTType ()->isEqual (
3900
- tangentSpace->getCanonicalType ()));
3901
- #endif
3901
+ assert (newAdjointValue.getType () ==
3902
+ getRemappedTangentType (originalValue->getType ()));
3902
3903
auto insertion =
3903
3904
valueMap.try_emplace ({origBB, originalValue}, newAdjointValue);
3904
3905
auto inserted = insertion.second ;
@@ -4493,7 +4494,8 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
4493
4494
}
4494
4495
builder.createReturn (adjLoc, joinElements (retElts, builder, adjLoc));
4495
4496
4496
- LLVM_DEBUG (getADDebugStream () << " Generated adjoint:\n " << adjoint);
4497
+ LLVM_DEBUG (getADDebugStream () << " Generated adjoint for "
4498
+ << original.getName () << " :\n " << adjoint);
4497
4499
return errorOccurred;
4498
4500
}
4499
4501
@@ -5572,7 +5574,7 @@ bool VJPEmitter::run() {
5572
5574
errorOccurred = true ;
5573
5575
return true ;
5574
5576
}
5575
- LLVM_DEBUG (getADDebugStream () << " Finished VJPGen for function "
5577
+ LLVM_DEBUG (getADDebugStream () << " Generated VJP for "
5576
5578
<< original->getName () << " :\n " << *vjp);
5577
5579
return errorOccurred;
5578
5580
}
0 commit comments