Skip to content

Commit f28cc66

Browse files
authored
[AutoDiff] Move dumpActivityInfo to DifferentiableActivityInfo class (#28597)
1 parent fdbdf7e commit f28cc66

File tree

3 files changed

+41
-36
lines changed

3 files changed

+41
-36
lines changed

include/swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ class DifferentiableActivityInfo {
131131
SmallVector<SmallDenseSet<SILValue>, 4> usefulValueSets;
132132

133133
/// The original function.
134-
SILFunction &getFunction();
134+
SILFunction &getFunction() const;
135135

136136
/// Returns true if the given SILValue has a tangent space.
137137
bool hasTangentSpace(SILValue value) {
@@ -206,6 +206,14 @@ class DifferentiableActivityInfo {
206206
/// Returns the activity of the given value for the given `SILAutoDiffIndices`
207207
/// (parameter indices and result index).
208208
Activity getActivity(SILValue value, const SILAutoDiffIndices &indices) const;
209+
210+
/// Prints activity information for the `indices` of the given `value`.
211+
void dump(SILValue value, const SILAutoDiffIndices &indices,
212+
llvm::raw_ostream &s = llvm::dbgs()) const;
213+
214+
/// Prints activity information for the given `indices`.
215+
void dump(SILAutoDiffIndices indices,
216+
llvm::raw_ostream &s = llvm::dbgs()) const;
209217
};
210218

211219
class DifferentiableActivityCollection {

lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ DifferentiableActivityInfo::DifferentiableActivityInfo(
5959
analyze(parent.domInfo, parent.postDomInfo);
6060
}
6161

62-
SILFunction &DifferentiableActivityInfo::getFunction() {
62+
SILFunction &DifferentiableActivityInfo::getFunction() const {
6363
return parent.function;
6464
}
6565

@@ -421,3 +421,32 @@ Activity DifferentiableActivityInfo::getActivity(
421421
activity |= ActivityFlags::Useful;
422422
return activity;
423423
}
424+
425+
void DifferentiableActivityInfo::dump(SILValue value,
426+
const SILAutoDiffIndices &indices,
427+
llvm::raw_ostream &s) const {
428+
s << '[';
429+
auto activity = getActivity(value, indices);
430+
switch (activity.toRaw()) {
431+
case 0: s << "NONE"; break;
432+
case (unsigned)ActivityFlags::Varied: s << "VARIED"; break;
433+
case (unsigned)ActivityFlags::Useful: s << "USEFUL"; break;
434+
case (unsigned)ActivityFlags::Active: s << "ACTIVE"; break;
435+
}
436+
s << "] " << value;
437+
}
438+
439+
void DifferentiableActivityInfo::dump(SILAutoDiffIndices indices,
440+
llvm::raw_ostream &s) const {
441+
SILFunction &fn = getFunction();
442+
s << "Activity info for " << fn.getName() << " at " << indices << '\n';
443+
for (auto &bb : fn) {
444+
s << "bb" << bb.getDebugID() << ":\n";
445+
for (auto *arg : bb.getArguments())
446+
dump(arg, indices, s);
447+
for (auto &inst : bb)
448+
for (auto res : inst.getResults())
449+
dump(res, indices, s);
450+
s << '\n';
451+
}
452+
}

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 2 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -263,36 +263,6 @@ class DifferentiationTransformer {
263263

264264
} // end anonymous namespace
265265

266-
static void dumpActivityInfo(SILValue value,
267-
const SILAutoDiffIndices &indices,
268-
const DifferentiableActivityInfo &activityInfo,
269-
llvm::raw_ostream &s = llvm::dbgs()) {
270-
s << '[';
271-
auto activity = activityInfo.getActivity(value, indices);
272-
switch (activity.toRaw()) {
273-
case 0: s << "NONE"; break;
274-
case (unsigned)ActivityFlags::Varied: s << "VARIED"; break;
275-
case (unsigned)ActivityFlags::Useful: s << "USEFUL"; break;
276-
case (unsigned)ActivityFlags::Active: s << "ACTIVE"; break;
277-
}
278-
s << "] " << value;
279-
}
280-
281-
static void dumpActivityInfo(SILFunction &fn, SILAutoDiffIndices indices,
282-
const DifferentiableActivityInfo &activityInfo,
283-
llvm::raw_ostream &s = llvm::dbgs()) {
284-
s << "Activity info for " << fn.getName() << " at " << indices << '\n';
285-
for (auto &bb : fn) {
286-
s << "bb" << bb.getDebugID() << ":\n";
287-
for (auto *arg : bb.getArguments())
288-
dumpActivityInfo(arg, indices, activityInfo, s);
289-
for (auto &inst : bb)
290-
for (auto res : inst.getResults())
291-
dumpActivityInfo(res, indices, activityInfo, s);
292-
s << '\n';
293-
}
294-
}
295-
296266
/// If the original function doesn't have a return, it cannot be differentiated.
297267
/// Returns true if error is emitted.
298268
static bool diagnoseNoReturn(ADContext &context, SILFunction *original,
@@ -1388,8 +1358,7 @@ class VJPEmitter final
13881358
auto &activityInfo = activityCollection.getActivityInfo(
13891359
vjp->getLoweredFunctionType()->getSubstGenericSignature(),
13901360
AutoDiffDerivativeFunctionKind::VJP);
1391-
LLVM_DEBUG(
1392-
dumpActivityInfo(*original, indices, activityInfo, getADDebugStream()));
1361+
LLVM_DEBUG(activityInfo.dump(indices, getADDebugStream()));
13931362
return activityInfo;
13941363
}
13951364

@@ -2141,8 +2110,7 @@ class JVPEmitter final
21412110
auto &activityInfo = activityCollection.getActivityInfo(
21422111
jvp->getLoweredFunctionType()->getSubstGenericSignature(),
21432112
AutoDiffDerivativeFunctionKind::JVP);
2144-
LLVM_DEBUG(
2145-
dumpActivityInfo(*original, indices, activityInfo, getADDebugStream()));
2113+
LLVM_DEBUG(activityInfo.dump(indices, getADDebugStream()));
21462114
return activityInfo;
21472115
}
21482116

0 commit comments

Comments
 (0)