Skip to content

Commit 3cb6940

Browse files
authored
Merge pull request #34887 from gottesmm/pr-a2f43c89854b9be02db9525f4baea04ff3bca6b8
[autodiff] When asserts are enabled, verify all autodiff compiler generated functions.
2 parents 2328132 + 25ebb5d commit 3cb6940

File tree

4 files changed

+38
-5
lines changed

4 files changed

+38
-5
lines changed

include/swift/SILOptimizer/Differentiation/JVPCloner.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ class JVPCloner final {
4747
/// Performs JVP generation on the empty JVP function. Returns true if any
4848
/// error occurs.
4949
bool run();
50+
51+
SILFunction &getJVP() const;
5052
};
5153

5254
} // end namespace autodiff

lib/SILOptimizer/Differentiation/JVPCloner.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
3232
#include "llvm/ADT/DenseMap.h"
3333

34+
using namespace swift;
35+
using namespace autodiff;
36+
3437
namespace swift {
3538
namespace autodiff {
3639

@@ -380,6 +383,8 @@ class JVPCloner::Implementation final
380383
/// Run JVP generation. Returns true on error.
381384
bool run();
382385

386+
SILFunction &getJVP() const { return *jvp; }
387+
383388
void postProcess(SILInstruction *orig, SILInstruction *cloned) {
384389
if (errorOccurred)
385390
return;
@@ -1727,7 +1732,16 @@ bool JVPCloner::Implementation::run() {
17271732
return errorOccurred;
17281733
}
17291734

1730-
bool JVPCloner::run() { return impl.run(); }
1731-
17321735
} // end namespace autodiff
17331736
} // end namespace swift
1737+
1738+
bool JVPCloner::run() {
1739+
bool foundError = impl.run();
1740+
#ifndef NDEBUG
1741+
if (!foundError)
1742+
getJVP().verify();
1743+
#endif
1744+
return foundError;
1745+
}
1746+
1747+
SILFunction &JVPCloner::getJVP() const { return impl.getJVP(); }

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,6 @@ class PullbackCloner::Implementation final
138138
SILModule &getModule() const { return getContext().getModule(); }
139139
ASTContext &getASTContext() const { return getPullback().getASTContext(); }
140140
SILFunction &getOriginal() const { return vjpCloner.getOriginal(); }
141-
SILFunction &getPullback() const { return vjpCloner.getPullback(); }
142141
SILDifferentiabilityWitness *getWitness() const {
143142
return vjpCloner.getWitness();
144143
}
@@ -782,6 +781,10 @@ class PullbackCloner::Implementation final
782781
/// parameters.
783782
void emitZeroDerivativesForNonvariedResult(SILValue origNonvariedResult);
784783

784+
/// Public helper so that our users can get the underlying newly created
785+
/// function.
786+
SILFunction &getPullback() const { return vjpCloner.getPullback(); }
787+
785788
using TrampolineBlockSet = SmallPtrSet<SILBasicBlock *, 4>;
786789

787790
/// Determines the pullback successor block for a given original block and one
@@ -1740,7 +1743,14 @@ PullbackCloner::~PullbackCloner() { delete &impl; }
17401743
// Entry point
17411744
//--------------------------------------------------------------------------//
17421745

1743-
bool PullbackCloner::run() { return impl.run(); }
1746+
bool PullbackCloner::run() {
1747+
bool foundError = impl.run();
1748+
#ifndef NDEBUG
1749+
if (!foundError)
1750+
impl.getPullback().verify();
1751+
#endif
1752+
return foundError;
1753+
}
17441754

17451755
bool PullbackCloner::Implementation::run() {
17461756
PrettyStackTraceSILFunction trace("generating pullback for", &getOriginal());

lib/SILOptimizer/Differentiation/VJPCloner.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1019,7 +1019,14 @@ bool VJPCloner::Implementation::run() {
10191019
return errorOccurred;
10201020
}
10211021

1022-
bool VJPCloner::run() { return impl.run(); }
1022+
bool VJPCloner::run() {
1023+
bool foundError = impl.run();
1024+
#ifndef NDEBUG
1025+
if (!foundError)
1026+
getVJP().verify();
1027+
#endif
1028+
return foundError;
1029+
}
10231030

10241031
} // end namespace autodiff
10251032
} // end namespace swift

0 commit comments

Comments
 (0)