Skip to content

Commit 8c548ae

Browse files
committed
[AutoDiff] Improve debugging utilities.
- Show SIL type when printing `AdjointValue`. - Add utilities to print `PullbackEmitter` adjoint value and buffer mappings. - Print generated VJP before printing generated pullback. - This is useful because pullback generation may crash after VJP generation succeeds.
1 parent cc84c7b commit 8c548ae

File tree

4 files changed

+75
-8
lines changed

4 files changed

+75
-8
lines changed

include/swift/SILOptimizer/Differentiation/AdjointValue.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,13 +136,12 @@ class AdjointValue final {
136136
void print(llvm::raw_ostream &s) const {
137137
switch (getKind()) {
138138
case AdjointValueKind::Zero:
139-
s << "Zero";
139+
s << "Zero[" << getType() << ']';
140140
break;
141141
case AdjointValueKind::Aggregate:
142-
s << "Aggregate<";
142+
s << "Aggregate[" << getType() << "](";
143143
if (auto *decl =
144144
getType().getASTType()->getStructOrBoundGenericStruct()) {
145-
s << "Struct>(";
146145
interleave(
147146
llvm::zip(decl->getStoredProperties(), base->value.aggregate),
148147
[&s](std::tuple<VarDecl *, const AdjointValue &> elt) {
@@ -151,7 +150,6 @@ class AdjointValue final {
151150
},
152151
[&s] { s << ", "; });
153152
} else if (getType().is<TupleType>()) {
154-
s << "Tuple>(";
155153
interleave(
156154
base->value.aggregate,
157155
[&s](const AdjointValue &elt) { elt.print(s); },
@@ -162,10 +160,11 @@ class AdjointValue final {
162160
s << ')';
163161
break;
164162
case AdjointValueKind::Concrete:
165-
s << "Concrete(" << base->value.concrete << ')';
163+
s << "Concrete[" << getType() << "](" << base->value.concrete << ')';
166164
break;
167165
}
168166
}
167+
169168
SWIFT_DEBUG_DUMP { print(llvm::dbgs()); };
170169
};
171170

include/swift/SILOptimizer/Differentiation/PullbackEmitter.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,13 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
307307
return pullbackTrampolineBBMap.lookup({originalBlock, successorBlock});
308308
}
309309

310+
//--------------------------------------------------------------------------//
311+
// Debugging utilities
312+
//--------------------------------------------------------------------------//
313+
314+
void printAdjointValueMapping();
315+
void printAdjointBufferMapping();
316+
310317
public:
311318
//--------------------------------------------------------------------------//
312319
// Entry point

lib/SILOptimizer/Differentiation/PullbackEmitter.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,66 @@ void PullbackEmitter::addToAdjointBuffer(SILBasicBlock *origBB,
513513
accumulateIndirect(adjointBuffer, rhsBufferAccess, loc);
514514
}
515515

516+
//--------------------------------------------------------------------------//
517+
// Debugging utilities
518+
//--------------------------------------------------------------------------//
519+
520+
void PullbackEmitter::printAdjointValueMapping() {
521+
// Group original/adjoint values by basic block.
522+
llvm::DenseMap<SILBasicBlock *, llvm::DenseMap<SILValue, AdjointValue>> tmp;
523+
for (auto pair : valueMap) {
524+
auto origPair = pair.first;
525+
auto *origBB = origPair.first;
526+
auto origValue = origPair.second;
527+
auto adjValue = pair.second;
528+
tmp[origBB].insert({origValue, adjValue});
529+
}
530+
// Print original/adjoint values per basic block.
531+
auto &s = getADDebugStream() << "Adjoint value mapping:\n";
532+
for (auto &origBB : getOriginal()) {
533+
if (!pullbackBBMap.count(&origBB))
534+
continue;
535+
auto bbValueMap = tmp[&origBB];
536+
s << "bb" << origBB.getDebugID();
537+
s << " (size " << bbValueMap.size() << "):\n";
538+
for (auto valuePair : bbValueMap) {
539+
auto origValue = valuePair.first;
540+
auto adjValue = valuePair.second;
541+
s << "ORIG: " << origValue;
542+
s << "ADJ: " << adjValue << '\n';
543+
}
544+
s << '\n';
545+
}
546+
}
547+
548+
void PullbackEmitter::printAdjointBufferMapping() {
549+
// Group original/adjoint buffers by basic block.
550+
llvm::DenseMap<SILBasicBlock *, llvm::DenseMap<SILValue, SILValue>> tmp;
551+
for (auto pair : bufferMap) {
552+
auto origPair = pair.first;
553+
auto *origBB = origPair.first;
554+
auto origBuf = origPair.second;
555+
auto adjBuf = pair.second;
556+
tmp[origBB][origBuf] = adjBuf;
557+
}
558+
// Print original/adjoint buffers per basic block.
559+
auto &s = getADDebugStream() << "Adjoint buffer mapping:\n";
560+
for (auto &origBB : getOriginal()) {
561+
if (!pullbackBBMap.count(&origBB))
562+
continue;
563+
auto bbBufferMap = tmp[&origBB];
564+
s << "bb" << origBB.getDebugID();
565+
s << " (size " << bbBufferMap.size() << "):\n";
566+
for (auto valuePair : bbBufferMap) {
567+
auto origBuf = valuePair.first;
568+
auto adjBuf = valuePair.second;
569+
s << "ORIG: " << origBuf;
570+
s << "ADJ: " << adjBuf << '\n';
571+
}
572+
s << '\n';
573+
}
574+
}
575+
516576
//--------------------------------------------------------------------------//
517577
// Member accessor pullback generation
518578
//--------------------------------------------------------------------------//

lib/SILOptimizer/Differentiation/VJPEmitter.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -832,15 +832,16 @@ bool VJPEmitter::run() {
832832
// `-enable-strip-ownership-after-serialization` is true.
833833
mergeBasicBlocks(vjp);
834834

835+
LLVM_DEBUG(getADDebugStream()
836+
<< "Generated VJP for " << original->getName() << ":\n"
837+
<< *vjp);
838+
835839
// Generate pullback code.
836840
PullbackEmitter PullbackEmitter(*this);
837841
if (PullbackEmitter.run()) {
838842
errorOccurred = true;
839843
return true;
840844
}
841-
LLVM_DEBUG(getADDebugStream()
842-
<< "Generated VJP for " << original->getName() << ":\n"
843-
<< *vjp);
844845
return errorOccurred;
845846
}
846847

0 commit comments

Comments
 (0)