@@ -3292,25 +3292,34 @@ static bool checkFunctionSignature(
3292
3292
return false ;
3293
3293
}
3294
3294
3295
+ // Map type into the required function type's generic signature, if it exists.
3296
+ // This is significant when the required generic signature has same-type
3297
+ // requirements while the candidate generic signature does not.
3298
+ auto mapType = [&](Type type) {
3299
+ if (!requiredGenSig)
3300
+ return type->getCanonicalType ();
3301
+ return requiredGenSig->getCanonicalTypeInContext (type);
3302
+ };
3303
+
3295
3304
// Check that parameter types match, disregarding labels.
3296
3305
if (required->getNumParams () != candidateFnTy->getNumParams ())
3297
3306
return false ;
3298
3307
if (!std::equal (required->getParams ().begin (), required->getParams ().end (),
3299
3308
candidateFnTy->getParams ().begin (),
3300
- [](AnyFunctionType::Param x, AnyFunctionType::Param y) {
3301
- return x.getPlainType ()->isEqual (y.getPlainType ());
3309
+ [& ](AnyFunctionType::Param x, AnyFunctionType::Param y) {
3310
+ return x.getPlainType ()->isEqual (mapType ( y.getPlainType () ));
3302
3311
}))
3303
3312
return false ;
3304
3313
3305
3314
// If required result type is not a function type, check that result types
3306
3315
// match exactly.
3307
3316
auto requiredResultFnTy = dyn_cast<AnyFunctionType>(required.getResult ());
3317
+ auto candidateResultTy = mapType (candidateFnTy.getResult ());
3308
3318
if (!requiredResultFnTy) {
3309
3319
auto requiredResultTupleTy = dyn_cast<TupleType>(required.getResult ());
3310
- auto candidateResultTupleTy =
3311
- dyn_cast<TupleType>(candidateFnTy.getResult ());
3320
+ auto candidateResultTupleTy = dyn_cast<TupleType>(candidateResultTy);
3312
3321
if (!requiredResultTupleTy || !candidateResultTupleTy)
3313
- return required.getResult ()->isEqual (candidateFnTy. getResult () );
3322
+ return required.getResult ()->isEqual (candidateResultTy );
3314
3323
// If result types are tuple types, check that element types match,
3315
3324
// ignoring labels.
3316
3325
if (requiredResultTupleTy->getNumElements () !=
@@ -3323,7 +3332,7 @@ static bool checkFunctionSignature(
3323
3332
}
3324
3333
3325
3334
// Required result type is a function. Recurse.
3326
- return checkFunctionSignature (requiredResultFnTy, candidateFnTy. getResult () );
3335
+ return checkFunctionSignature (requiredResultFnTy, candidateResultTy );
3327
3336
};
3328
3337
3329
3338
// Returns an `AnyFunctionType` with the same `ExtInfo` as `fnType`, but with
@@ -3607,8 +3616,20 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
3607
3616
auto resultTanType = valueResultConf.getTypeWitnessByName (
3608
3617
valueResultType, Ctx.Id_TangentVector );
3609
3618
3619
+ // Compute the actual differential/pullback type that we use for comparison
3620
+ // with the expected type. We must canonicalize the derivative interface type
3621
+ // before extracting the differential/pullback type from it, so that the
3622
+ // derivative interface type generic signature is available for simplifying
3623
+ // types.
3624
+ CanType canActualResultType = derivativeInterfaceType->getCanonicalType ();
3625
+ while (isa<AnyFunctionType>(canActualResultType)) {
3626
+ canActualResultType =
3627
+ cast<AnyFunctionType>(canActualResultType).getResult ();
3628
+ }
3629
+ CanType actualFuncEltType =
3630
+ cast<TupleType>(canActualResultType).getElementType (1 );
3631
+
3610
3632
// Compute expected differential/pullback type.
3611
- auto funcEltType = funcResultElt.getType ();
3612
3633
Type expectedFuncEltType;
3613
3634
if (kind == AutoDiffDerivativeFunctionKind::JVP) {
3614
3635
auto diffParams = map<SmallVector<AnyFunctionType::Param, 4 >>(
@@ -3624,7 +3645,7 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
3624
3645
expectedFuncEltType = expectedFuncEltType->mapTypeOutOfContext ();
3625
3646
3626
3647
// Check if differential/pullback type matches expected type.
3627
- if (!funcEltType ->isEqual (expectedFuncEltType)) {
3648
+ if (!actualFuncEltType ->isEqual (expectedFuncEltType)) {
3628
3649
// Emit differential/pullback type mismatch error on attribute.
3629
3650
diagnoseAndRemoveAttr (attr, diag::derivative_attr_result_func_type_mismatch,
3630
3651
funcResultElt.getName (), originalAFD->getFullName ());
0 commit comments