@@ -3263,25 +3263,34 @@ static bool checkFunctionSignature(
3263
3263
return false ;
3264
3264
}
3265
3265
3266
+ // Map type into the required function type's generic signature, if it exists.
3267
+ // This is significant when the required generic signature has same-type
3268
+ // requirements while the candidate generic signature does not.
3269
+ auto mapType = [&](Type type) {
3270
+ if (!requiredGenSig)
3271
+ return type->getCanonicalType ();
3272
+ return requiredGenSig->getCanonicalTypeInContext (type);
3273
+ };
3274
+
3266
3275
// Check that parameter types match, disregarding labels.
3267
3276
if (required->getNumParams () != candidateFnTy->getNumParams ())
3268
3277
return false ;
3269
3278
if (!std::equal (required->getParams ().begin (), required->getParams ().end (),
3270
3279
candidateFnTy->getParams ().begin (),
3271
- [](AnyFunctionType::Param x, AnyFunctionType::Param y) {
3272
- return x.getPlainType ()->isEqual (y.getPlainType ());
3280
+ [& ](AnyFunctionType::Param x, AnyFunctionType::Param y) {
3281
+ return x.getPlainType ()->isEqual (mapType ( y.getPlainType () ));
3273
3282
}))
3274
3283
return false ;
3275
3284
3276
3285
// If required result type is not a function type, check that result types
3277
3286
// match exactly.
3278
3287
auto requiredResultFnTy = dyn_cast<AnyFunctionType>(required.getResult ());
3288
+ auto candidateResultTy = mapType (candidateFnTy.getResult ());
3279
3289
if (!requiredResultFnTy) {
3280
3290
auto requiredResultTupleTy = dyn_cast<TupleType>(required.getResult ());
3281
- auto candidateResultTupleTy =
3282
- dyn_cast<TupleType>(candidateFnTy.getResult ());
3291
+ auto candidateResultTupleTy = dyn_cast<TupleType>(candidateResultTy);
3283
3292
if (!requiredResultTupleTy || !candidateResultTupleTy)
3284
- return required.getResult ()->isEqual (candidateFnTy. getResult () );
3293
+ return required.getResult ()->isEqual (candidateResultTy );
3285
3294
// If result types are tuple types, check that element types match,
3286
3295
// ignoring labels.
3287
3296
if (requiredResultTupleTy->getNumElements () !=
@@ -3294,7 +3303,7 @@ static bool checkFunctionSignature(
3294
3303
}
3295
3304
3296
3305
// Required result type is a function. Recurse.
3297
- return checkFunctionSignature (requiredResultFnTy, candidateFnTy. getResult () );
3306
+ return checkFunctionSignature (requiredResultFnTy, candidateResultTy );
3298
3307
};
3299
3308
3300
3309
// Returns an `AnyFunctionType` with the same `ExtInfo` as `fnType`, but with
@@ -3578,8 +3587,20 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
3578
3587
auto resultTanType = valueResultConf.getTypeWitnessByName (
3579
3588
valueResultType, Ctx.Id_TangentVector );
3580
3589
3590
+ // Compute the actual differential/pullback type that we use for comparison
3591
+ // with the expected type. We must canonicalize the derivative interface type
3592
+ // before extracting the differential/pullback type from it, so that the
3593
+ // derivative interface type generic signature is available for simplifying
3594
+ // types.
3595
+ CanType canActualResultType = derivativeInterfaceType->getCanonicalType ();
3596
+ while (isa<AnyFunctionType>(canActualResultType)) {
3597
+ canActualResultType =
3598
+ cast<AnyFunctionType>(canActualResultType).getResult ();
3599
+ }
3600
+ CanType actualFuncEltType =
3601
+ cast<TupleType>(canActualResultType).getElementType (1 );
3602
+
3581
3603
// Compute expected differential/pullback type.
3582
- auto funcEltType = funcResultElt.getType ();
3583
3604
Type expectedFuncEltType;
3584
3605
if (kind == AutoDiffDerivativeFunctionKind::JVP) {
3585
3606
auto diffParams = map<SmallVector<AnyFunctionType::Param, 4 >>(
@@ -3595,7 +3616,7 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
3595
3616
expectedFuncEltType = expectedFuncEltType->mapTypeOutOfContext ();
3596
3617
3597
3618
// Check if differential/pullback type matches expected type.
3598
- if (!funcEltType ->isEqual (expectedFuncEltType)) {
3619
+ if (!actualFuncEltType ->isEqual (expectedFuncEltType)) {
3599
3620
// Emit differential/pullback type mismatch error on attribute.
3600
3621
diagnoseAndRemoveAttr (attr, diag::derivative_attr_result_func_type_mismatch,
3601
3622
funcResultElt.getName (), originalAFD->getFullName ());
0 commit comments