@@ -1491,7 +1491,9 @@ class SILVerifier : public SILVerifierBase<SILVerifier> {
1491
1491
adfi->getParameterIndices (), /* resultIndex*/ 0 , order,
1492
1492
AutoDiffAssociatedFunctionKind::JVP, F.getModule (),
1493
1493
LookUpConformanceInModule (F.getModule ().getSwiftModule ()));
1494
- require (expectedJVPType == jvpType, " Unexpected JVP function type" );
1494
+ requireSameType (SILType::getPrimitiveObjectType (jvpType),
1495
+ SILType::getPrimitiveObjectType (expectedJVPType),
1496
+ " JVP type does not match expected JVP type" );
1495
1497
auto vjpType = pair.second ->getType ().getAs <SILFunctionType>();
1496
1498
require (vjpType, " The VJP function must have a function type" );
1497
1499
require (!vjpType->isDifferentiable (),
@@ -1500,7 +1502,9 @@ class SILVerifier : public SILVerifierBase<SILVerifier> {
1500
1502
adfi->getParameterIndices (), /* resultIndex*/ 0 , order,
1501
1503
AutoDiffAssociatedFunctionKind::VJP, F.getModule (),
1502
1504
LookUpConformanceInModule (F.getModule ().getSwiftModule ()));
1503
- require (expectedVJPType == vjpType, " Unexpected VJP function type" );
1505
+ requireSameType (SILType::getPrimitiveObjectType (vjpType),
1506
+ SILType::getPrimitiveObjectType (expectedVJPType),
1507
+ " VJP type does not match expected VJP type" );
1504
1508
}
1505
1509
}
1506
1510
}
0 commit comments