@@ -5350,6 +5350,43 @@ void SILGlobalVariable::verify() const {
5350
5350
}
5351
5351
}
5352
5352
5353
+ // SWIFT_ENABLE_TENSORFLOW
5354
+ // / Verify that a differentiability witness follows invariants.
5355
+ void SILDifferentiabilityWitness::verify (const SILModule &M) const {
5356
+ #ifdef NDEBUG
5357
+ if (!M.getOptions ().VerifyAll )
5358
+ return ;
5359
+ #endif
5360
+ auto origFnType = originalFunction->getLoweredFunctionType ();
5361
+ if (jvp) {
5362
+ // TODO: Change `SILFunctionType::getAutoDiffDerivativeFunctionType` to
5363
+ // accept result indices.
5364
+ auto expectedJVPType = origFnType->getAutoDiffDerivativeFunctionType (
5365
+ getParameterIndices (), /* resultIndex*/ *resultIndices->begin (),
5366
+ AutoDiffDerivativeFunctionKind::JVP, M.Types ,
5367
+ LookUpConformanceInModule (M.getSwiftModule ()),
5368
+ getDerivativeGenericSignature ()->getCanonicalSignature ());
5369
+ SILVerifier (*jvp).requireSameType (
5370
+ SILType::getPrimitiveObjectType (jvp->getLoweredFunctionType ()),
5371
+ SILType::getPrimitiveObjectType (expectedJVPType),
5372
+ " JVP type does not match expected JVP type" );
5373
+ }
5374
+ if (vjp) {
5375
+ // TODO: Change `SILFunctionType::getAutoDiffDerivativeFunctionType` to
5376
+ // accept result indices.
5377
+ auto expectedVJPType = origFnType->getAutoDiffDerivativeFunctionType (
5378
+ getParameterIndices (), /* resultIndex*/ *resultIndices->begin (),
5379
+ AutoDiffDerivativeFunctionKind::VJP, M.Types ,
5380
+ LookUpConformanceInModule (M.getSwiftModule ()),
5381
+ getDerivativeGenericSignature ()->getCanonicalSignature ());
5382
+ SILVerifier (*jvp).requireSameType (
5383
+ SILType::getPrimitiveObjectType (vjp->getLoweredFunctionType ()),
5384
+ SILType::getPrimitiveObjectType (expectedVJPType),
5385
+ " VJP type does not match expected VJP type" );
5386
+ }
5387
+ }
5388
+ // SWIFT_ENABLE_TENSORFLOW END
5389
+
5353
5390
// / Verify the module.
5354
5391
void SILModule::verify () const {
5355
5392
#ifdef NDEBUG
@@ -5433,6 +5470,22 @@ void SILModule::verify() const {
5433
5470
}
5434
5471
wt.verify (*this );
5435
5472
}
5473
+
5474
+ // SWIFT_ENABLE_TENSORFLOW
5475
+ // Check all differentiability witnesses.
5476
+ LLVM_DEBUG (llvm::dbgs () <<
5477
+ " *** Checking differentiability witnesses for duplicates ***\n " );
5478
+ llvm::DenseSet<SILDifferentiabilityWitnessKey> diffWitnesses;
5479
+ for (auto &dw : getDifferentiabilityWitnesses ()) {
5480
+ LLVM_DEBUG (llvm::dbgs () << " Differentiability Witness:\n " ; dw.dump ());
5481
+ if (!diffWitnesses.insert (dw.getKey ()).second ) {
5482
+ llvm::errs () << " Differentiability witness redefined: " ;
5483
+ dw.dump ();
5484
+ assert (false && " triggering standard assertion failure routine" );
5485
+ }
5486
+ dw.verify (*this );
5487
+ }
5488
+ // SWIFT_ENABLE_TENSORFLOW END
5436
5489
5437
5490
// Check property descriptors.
5438
5491
LLVM_DEBUG (llvm::dbgs () << " *** Checking property descriptors ***\n " );
0 commit comments