Skip to content

Commit d890f29

Browse files
authored
[AutoDiff] Improve debugging utilities. (#32269)
- Show SIL type when printing `AdjointValue`. - Add utilities to print `PullbackEmitter` adjoint value and buffer mappings. - Print generated VJP before printing generated pullback. - This is useful because pullback generation may crash after VJP generation succeeds.
1 parent b65300f commit d890f29

File tree

4 files changed

+75
-8
lines changed

4 files changed

+75
-8
lines changed

include/swift/SILOptimizer/Differentiation/AdjointValue.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,13 +136,12 @@ class AdjointValue final {
136136
void print(llvm::raw_ostream &s) const {
137137
switch (getKind()) {
138138
case AdjointValueKind::Zero:
139-
s << "Zero";
139+
s << "Zero[" << getType() << ']';
140140
break;
141141
case AdjointValueKind::Aggregate:
142-
s << "Aggregate<";
142+
s << "Aggregate[" << getType() << "](";
143143
if (auto *decl =
144144
getType().getASTType()->getStructOrBoundGenericStruct()) {
145-
s << "Struct>(";
146145
interleave(
147146
llvm::zip(decl->getStoredProperties(), base->value.aggregate),
148147
[&s](std::tuple<VarDecl *, const AdjointValue &> elt) {
@@ -151,7 +150,6 @@ class AdjointValue final {
151150
},
152151
[&s] { s << ", "; });
153152
} else if (getType().is<TupleType>()) {
154-
s << "Tuple>(";
155153
interleave(
156154
base->value.aggregate,
157155
[&s](const AdjointValue &elt) { elt.print(s); },
@@ -162,10 +160,11 @@ class AdjointValue final {
162160
s << ')';
163161
break;
164162
case AdjointValueKind::Concrete:
165-
s << "Concrete(" << base->value.concrete << ')';
163+
s << "Concrete[" << getType() << "](" << base->value.concrete << ')';
166164
break;
167165
}
168166
}
167+
169168
SWIFT_DEBUG_DUMP { print(llvm::dbgs()); };
170169
};
171170

include/swift/SILOptimizer/Differentiation/PullbackEmitter.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,13 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
307307
return pullbackTrampolineBBMap.lookup({originalBlock, successorBlock});
308308
}
309309

310+
//--------------------------------------------------------------------------//
311+
// Debugging utilities
312+
//--------------------------------------------------------------------------//
313+
314+
void printAdjointValueMapping();
315+
void printAdjointBufferMapping();
316+
310317
public:
311318
//--------------------------------------------------------------------------//
312319
// Entry point

lib/SILOptimizer/Differentiation/PullbackEmitter.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,66 @@ void PullbackEmitter::addToAdjointBuffer(SILBasicBlock *origBB,
514514
accumulateIndirect(adjointBuffer, rhsBufferAccess, loc);
515515
}
516516

517+
//--------------------------------------------------------------------------//
518+
// Debugging utilities
519+
//--------------------------------------------------------------------------//
520+
521+
void PullbackEmitter::printAdjointValueMapping() {
522+
// Group original/adjoint values by basic block.
523+
llvm::DenseMap<SILBasicBlock *, llvm::DenseMap<SILValue, AdjointValue>> tmp;
524+
for (auto pair : valueMap) {
525+
auto origPair = pair.first;
526+
auto *origBB = origPair.first;
527+
auto origValue = origPair.second;
528+
auto adjValue = pair.second;
529+
tmp[origBB].insert({origValue, adjValue});
530+
}
531+
// Print original/adjoint values per basic block.
532+
auto &s = getADDebugStream() << "Adjoint value mapping:\n";
533+
for (auto &origBB : getOriginal()) {
534+
if (!pullbackBBMap.count(&origBB))
535+
continue;
536+
auto bbValueMap = tmp[&origBB];
537+
s << "bb" << origBB.getDebugID();
538+
s << " (size " << bbValueMap.size() << "):\n";
539+
for (auto valuePair : bbValueMap) {
540+
auto origValue = valuePair.first;
541+
auto adjValue = valuePair.second;
542+
s << "ORIG: " << origValue;
543+
s << "ADJ: " << adjValue << '\n';
544+
}
545+
s << '\n';
546+
}
547+
}
548+
549+
void PullbackEmitter::printAdjointBufferMapping() {
550+
// Group original/adjoint buffers by basic block.
551+
llvm::DenseMap<SILBasicBlock *, llvm::DenseMap<SILValue, SILValue>> tmp;
552+
for (auto pair : bufferMap) {
553+
auto origPair = pair.first;
554+
auto *origBB = origPair.first;
555+
auto origBuf = origPair.second;
556+
auto adjBuf = pair.second;
557+
tmp[origBB][origBuf] = adjBuf;
558+
}
559+
// Print original/adjoint buffers per basic block.
560+
auto &s = getADDebugStream() << "Adjoint buffer mapping:\n";
561+
for (auto &origBB : getOriginal()) {
562+
if (!pullbackBBMap.count(&origBB))
563+
continue;
564+
auto bbBufferMap = tmp[&origBB];
565+
s << "bb" << origBB.getDebugID();
566+
s << " (size " << bbBufferMap.size() << "):\n";
567+
for (auto valuePair : bbBufferMap) {
568+
auto origBuf = valuePair.first;
569+
auto adjBuf = valuePair.second;
570+
s << "ORIG: " << origBuf;
571+
s << "ADJ: " << adjBuf << '\n';
572+
}
573+
s << '\n';
574+
}
575+
}
576+
517577
//--------------------------------------------------------------------------//
518578
// Member accessor pullback generation
519579
//--------------------------------------------------------------------------//

lib/SILOptimizer/Differentiation/VJPEmitter.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -842,15 +842,16 @@ bool VJPEmitter::run() {
842842
// `-enable-strip-ownership-after-serialization` is true.
843843
mergeBasicBlocks(vjp);
844844

845+
LLVM_DEBUG(getADDebugStream()
846+
<< "Generated VJP for " << original->getName() << ":\n"
847+
<< *vjp);
848+
845849
// Generate pullback code.
846850
PullbackEmitter PullbackEmitter(*this);
847851
if (PullbackEmitter.run()) {
848852
errorOccurred = true;
849853
return true;
850854
}
851-
LLVM_DEBUG(getADDebugStream()
852-
<< "Generated VJP for " << original->getName() << ":\n"
853-
<< *vjp);
854855
return errorOccurred;
855856
}
856857

0 commit comments

Comments
 (0)