@@ -3731,18 +3731,23 @@ enum class AbstractFunctionDeclLookupErrorKind {
3731
3731
CandidateNotFunctionDeclaration
3732
3732
};
3733
3733
3734
- // / Returns the function declaration corresponding to the given base type
3735
- // / (optional), function name, and lookup context.
3734
+ // / Returns the original function (in the context of a derivative or transpose
3735
+ // / function) declaration corresponding to the given base type (optional),
3736
+ // / function name, lookup context, and the expected original function type.
3736
3737
// /
3737
3738
// / If the base type of the function is specified, member lookup is performed.
3738
3739
// / Otherwise, unqualified lookup is performed.
3739
3740
// /
3741
+ // / If the expected original function type has a generic signature, any
3742
+ // / candidate with a less constrained type signature than the expected original
3743
+ // / function type will be treated as a viable candidate.
3744
+ // /
3740
3745
// / If the function declaration cannot be resolved, emits a diagnostic and
3741
3746
// / returns nullptr.
3742
3747
// /
3743
3748
// / Used for resolving the referenced declaration in `@derivative` and
3744
3749
// / `@transpose` attributes.
3745
- static AbstractFunctionDecl *findAbstractFunctionDecl (
3750
+ static AbstractFunctionDecl *findAutoDiffOriginalFunctionDecl (
3746
3751
DeclAttribute *attr, Type baseType, DeclNameRefWithLoc funcNameWithLoc,
3747
3752
DeclContext *lookupContext, NameLookupOptions lookupOptions,
3748
3753
const llvm::function_ref<Optional<AbstractFunctionDeclLookupErrorKind>(
@@ -4671,7 +4676,7 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
4671
4676
}
4672
4677
4673
4678
// Look up original function.
4674
- auto *originalAFD = findAbstractFunctionDecl (
4679
+ auto *originalAFD = findAutoDiffOriginalFunctionDecl (
4675
4680
attr, baseType, originalName, derivativeTypeCtx, lookupOptions,
4676
4681
isValidOriginalCandidate, originalFnType);
4677
4682
if (!originalAFD) {
@@ -5230,7 +5235,7 @@ void AttributeChecker::visitTransposeAttr(TransposeAttr *attr) {
5230
5235
auto funcLoc = originalName.Loc .getBaseNameLoc ();
5231
5236
if (attr->getBaseTypeRepr ())
5232
5237
funcLoc = attr->getBaseTypeRepr ()->getLoc ();
5233
- auto *originalAFD = findAbstractFunctionDecl (
5238
+ auto *originalAFD = findAutoDiffOriginalFunctionDecl (
5234
5239
attr, baseType, originalName, transposeTypeCtx, lookupOptions,
5235
5240
isValidOriginalCandidate, expectedOriginalFnType);
5236
5241
if (!originalAFD) {
0 commit comments