Skip to content

Commit b6cd1d7

Browse files
committed
Add SIL verification.
1 parent 419eea2 commit b6cd1d7

File tree

2 files changed

+61
-5
lines changed

2 files changed

+61
-5
lines changed

include/swift/SIL/SILDifferentiabilityWitness.h

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,12 @@ class SILDifferentiabilityWitness
7575
serialized(isSerialized) {}
7676

7777
public:
78+
static SILDifferentiabilityWitness *create(
79+
SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
80+
AutoDiffIndexSubset *parameterIndices, AutoDiffIndexSubset *resultIndices,
81+
GenericSignature *derivativeGenSig, SILFunction *jvp, SILFunction *vjp,
82+
bool isSerialized);
83+
7884
SILDifferentiabilityWitnessKey getKey() const;
7985
SILModule &getModule() const { return module; }
8086
SILLinkage getLinkage() const { return linkage; }
@@ -92,11 +98,8 @@ class SILDifferentiabilityWitness
9298
SILFunction *getVJP() const { return vjp; }
9399
bool isSerialized() const { return serialized; }
94100

95-
static SILDifferentiabilityWitness *create(
96-
SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
97-
AutoDiffIndexSubset *parameterIndices, AutoDiffIndexSubset *resultIndices,
98-
GenericSignature *derivativeGenSig, SILFunction *jvp, SILFunction *vjp,
99-
bool isSerialized);
101+
/// Verify that the differentiability witness is well-formed.
102+
void verify(const SILModule &M) const;
100103

101104
void print(llvm::raw_ostream &OS, bool verbose = false) const;
102105
void dump() const;

lib/SIL/SILVerifier.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5350,6 +5350,43 @@ void SILGlobalVariable::verify() const {
53505350
}
53515351
}
53525352

5353+
// SWIFT_ENABLE_TENSORFLOW
5354+
/// Verify that a differentiability witness follows invariants.
5355+
void SILDifferentiabilityWitness::verify(const SILModule &M) const {
5356+
#ifdef NDEBUG
5357+
if (!M.getOptions().VerifyAll)
5358+
return;
5359+
#endif
5360+
auto origFnType = originalFunction->getLoweredFunctionType();
5361+
if (jvp) {
5362+
// TODO: Change `SILFunctionType::getAutoDiffDerivativeFunctionType` to
5363+
// accept result indices.
5364+
auto expectedJVPType = origFnType->getAutoDiffDerivativeFunctionType(
5365+
getParameterIndices(), /*resultIndex*/ *resultIndices->begin(),
5366+
AutoDiffDerivativeFunctionKind::JVP, M.Types,
5367+
LookUpConformanceInModule(M.getSwiftModule()),
5368+
getDerivativeGenericSignature()->getCanonicalSignature());
5369+
SILVerifier(*jvp).requireSameType(
5370+
SILType::getPrimitiveObjectType(jvp->getLoweredFunctionType()),
5371+
SILType::getPrimitiveObjectType(expectedJVPType),
5372+
"JVP type does not match expected JVP type");
5373+
}
5374+
if (vjp) {
5375+
// TODO: Change `SILFunctionType::getAutoDiffDerivativeFunctionType` to
5376+
// accept result indices.
5377+
auto expectedVJPType = origFnType->getAutoDiffDerivativeFunctionType(
5378+
getParameterIndices(), /*resultIndex*/ *resultIndices->begin(),
5379+
AutoDiffDerivativeFunctionKind::VJP, M.Types,
5380+
LookUpConformanceInModule(M.getSwiftModule()),
5381+
getDerivativeGenericSignature()->getCanonicalSignature());
5382+
SILVerifier(*jvp).requireSameType(
5383+
SILType::getPrimitiveObjectType(vjp->getLoweredFunctionType()),
5384+
SILType::getPrimitiveObjectType(expectedVJPType),
5385+
"VJP type does not match expected VJP type");
5386+
}
5387+
}
5388+
// SWIFT_ENABLE_TENSORFLOW END
5389+
53535390
/// Verify the module.
53545391
void SILModule::verify() const {
53555392
#ifdef NDEBUG
@@ -5433,6 +5470,22 @@ void SILModule::verify() const {
54335470
}
54345471
wt.verify(*this);
54355472
}
5473+
5474+
// SWIFT_ENABLE_TENSORFLOW
5475+
// Check all differentiability witnesses.
5476+
LLVM_DEBUG(llvm::dbgs() <<
5477+
"*** Checking differentiability witnesses for duplicates ***\n");
5478+
llvm::DenseSet<SILDifferentiabilityWitnessKey> diffWitnesses;
5479+
for (auto &dw : getDifferentiabilityWitnesses()) {
5480+
LLVM_DEBUG(llvm::dbgs() << "Differentiability Witness:\n"; dw.dump());
5481+
if (!diffWitnesses.insert(dw.getKey()).second) {
5482+
llvm::errs() << "Differentiability witness redefined: ";
5483+
dw.dump();
5484+
assert(false && "triggering standard assertion failure routine");
5485+
}
5486+
dw.verify(*this);
5487+
}
5488+
// SWIFT_ENABLE_TENSORFLOW END
54365489

54375490
// Check property descriptors.
54385491
LLVM_DEBUG(llvm::dbgs() << "*** Checking property descriptors ***\n");

0 commit comments

Comments
 (0)