Skip to content

Commit 1da4430

Browse files
committed
Sema: Refactor findNamedWitnessImpl() to return a ConcreteDeclRef
1 parent 754e8e2 commit 1da4430

File tree

1 file changed

+29
-25
lines changed

1 file changed

+29
-25
lines changed

lib/Sema/CSApply.cpp

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -131,17 +131,17 @@ void Solution::computeSubstitutions(
131131
/// \param diag The diagnostic to emit if the protocol definition doesn't
132132
/// have a requirement with the given name.
133133
///
134-
/// \returns The named witness, or nullptr if no witness could be found.
135-
template <typename DeclTy>
136-
static DeclTy *findNamedWitnessImpl(
134+
/// \returns The named witness, or an empty ConcreteDeclRef if no witness
135+
/// could be found.
136+
ConcreteDeclRef findNamedWitnessImpl(
137137
TypeChecker &tc, DeclContext *dc, Type type,
138138
ProtocolDecl *proto, DeclName name,
139139
Diag<> diag,
140140
Optional<ProtocolConformanceRef> conformance = None) {
141141
// Find the named requirement.
142-
DeclTy *requirement = nullptr;
142+
ValueDecl *requirement = nullptr;
143143
for (auto member : proto->getMembers()) {
144-
auto d = dyn_cast<DeclTy>(member);
144+
auto d = dyn_cast<ValueDecl>(member);
145145
if (!d || !d->hasName())
146146
continue;
147147

@@ -159,7 +159,7 @@ static DeclTy *findNamedWitnessImpl(
159159
// Find the member used to satisfy the named requirement.
160160
if (!conformance) {
161161
conformance = tc.conformsToProtocol(type, proto, dc,
162-
ConformanceCheckFlags::InExpression);
162+
ConformanceCheckFlags::InExpression);
163163
if (!conformance)
164164
return nullptr;
165165
}
@@ -169,8 +169,7 @@ static DeclTy *findNamedWitnessImpl(
169169
if (!conformance->isConcrete())
170170
return requirement;
171171
auto concrete = conformance->getConcrete();
172-
// FIXME: Dropping substitutions here.
173-
return cast_or_null<DeclTy>(concrete->getWitnessDecl(requirement, &tc));
172+
return concrete->getWitnessDeclRef(requirement, &tc);
174173
}
175174

176175
static bool shouldAccessStorageDirectly(Expr *base, VarDecl *member,
@@ -2377,18 +2376,21 @@ namespace {
23772376
DeclName name(tc.Context, DeclBaseName::createConstructor(),
23782377
{ tc.Context.Id_stringInterpolation });
23792378
auto member
2380-
= findNamedWitnessImpl<ConstructorDecl>(
2379+
= findNamedWitnessImpl(
23812380
tc, dc, type,
23822381
interpolationProto, name,
23832382
diag::interpolation_broken_proto);
23842383

23852384
DeclName segmentName(tc.Context, DeclBaseName::createConstructor(),
23862385
{ tc.Context.Id_stringInterpolationSegment });
23872386
auto segmentMember
2388-
= findNamedWitnessImpl<ConstructorDecl>(
2387+
= findNamedWitnessImpl(
23892388
tc, dc, type, interpolationProto, segmentName,
23902389
diag::interpolation_broken_proto);
2391-
if (!member || !segmentMember)
2390+
if (!member ||
2391+
!segmentMember ||
2392+
!isa<ConstructorDecl>(member.getDecl()) ||
2393+
!isa<ConstructorDecl>(segmentMember.getDecl()))
23922394
return nullptr;
23932395

23942396
// Build a reference to the init(stringInterpolation:) initializer.
@@ -2399,7 +2401,7 @@ namespace {
23992401
Expr *memberRef =
24002402
new (tc.Context) MemberRefExpr(typeRef,
24012403
expr->getStartLoc(),
2402-
member,
2404+
member.getDecl(),
24032405
DeclNameLoc(expr->getStartLoc()),
24042406
/*Implicit=*/true);
24052407
cs.cacheSubExprTypes(memberRef);
@@ -7283,11 +7285,11 @@ Expr *ExprRewriter::convertLiteralInPlace(Expr *literal,
72837285

72847286
// Find the witness that we'll use to initialize the type via a builtin
72857287
// literal.
7286-
auto witness = findNamedWitnessImpl<AbstractFunctionDecl>(
7288+
auto witness = findNamedWitnessImpl(
72877289
tc, dc, type->getRValueType(), builtinProtocol,
72887290
builtinLiteralFuncName, brokenBuiltinProtocolDiag,
72897291
*builtinConformance);
7290-
if (!witness)
7292+
if (!witness || !isa<AbstractFunctionDecl>(witness.getDecl()))
72917293
return nullptr;
72927294

72937295
// Form a reference to the builtin conversion function.
@@ -7297,7 +7299,7 @@ Expr *ExprRewriter::convertLiteralInPlace(Expr *literal,
72977299

72987300
Expr *unresolvedDot = new (tc.Context) UnresolvedDotExpr(
72997301
base, SourceLoc(),
7300-
witness->getFullName(),
7302+
witness.getDecl()->getFullName(),
73017303
DeclNameLoc(base->getEndLoc()),
73027304
/*Implicit=*/true);
73037305
(void)tc.typeCheckExpression(unresolvedDot, dc);
@@ -7348,11 +7350,11 @@ Expr *ExprRewriter::convertLiteralInPlace(Expr *literal,
73487350
}
73497351

73507352
// Find the witness that we'll use to initialize the literal value.
7351-
auto witness = findNamedWitnessImpl<AbstractFunctionDecl>(
7353+
auto witness = findNamedWitnessImpl(
73527354
tc, dc, type->getRValueType(), protocol,
73537355
literalFuncName, brokenProtocolDiag,
73547356
conformance);
7355-
if (!witness)
7357+
if (!witness || !isa<AbstractFunctionDecl>(witness.getDecl()))
73567358
return nullptr;
73577359

73587360
// Form a reference to the conversion function.
@@ -7362,7 +7364,7 @@ Expr *ExprRewriter::convertLiteralInPlace(Expr *literal,
73627364

73637365
Expr *unresolvedDot = new (tc.Context) UnresolvedDotExpr(
73647366
base, SourceLoc(),
7365-
witness->getFullName(),
7367+
witness.getDecl()->getFullName(),
73667368
DeclNameLoc(base->getEndLoc()),
73677369
/*Implicit=*/true);
73687370
(void)tc.typeCheckExpression(unresolvedDot, dc);
@@ -8363,18 +8365,20 @@ Expr *TypeChecker::callWitness(Expr *base, DeclContext *dc,
83638365
if (auto metaType = type->getAs<AnyMetatypeType>())
83648366
type = metaType->getInstanceType();
83658367

8366-
auto witness = findNamedWitnessImpl<AbstractFunctionDecl>(
8368+
auto witness = findNamedWitnessImpl(
83678369
*this, dc, type->getRValueType(), protocol,
83688370
name, brokenProtocolDiag);
8369-
if (!witness)
8371+
if (!witness || !isa<AbstractFunctionDecl>(witness.getDecl()))
83708372
return nullptr;
83718373

8374+
auto *witnessFn = cast<AbstractFunctionDecl>(witness.getDecl());
8375+
83728376
// Form a syntactic expression that describes the reference to the
83738377
// witness.
83748378
// FIXME: Egregious hack.
83758379
auto unresolvedDot = new (Context) UnresolvedDotExpr(
83768380
base, SourceLoc(),
8377-
witness->getFullName(),
8381+
witness.getDecl()->getFullName(),
83788382
DeclNameLoc(base->getEndLoc()),
83798383
/*Implicit=*/true);
83808384
unresolvedDot->setFunctionRefKind(FunctionRefKind::SingleApply);
@@ -8383,7 +8387,7 @@ Expr *TypeChecker::callWitness(Expr *base, DeclContext *dc,
83838387
// Form a reference to the witness itself.
83848388
Type openedFullType, openedType;
83858389
std::tie(openedFullType, openedType)
8386-
= cs.getTypeOfMemberReference(base->getType(), witness, dc,
8390+
= cs.getTypeOfMemberReference(base->getType(), witness.getDecl(), dc,
83878391
/*isDynamicResult=*/false,
83888392
FunctionRefKind::DoubleApply,
83898393
dotLocator);
@@ -8396,9 +8400,9 @@ Expr *TypeChecker::callWitness(Expr *base, DeclContext *dc,
83968400
// FIXME: Standardize all callers to always provide all argument names,
83978401
// rather than hack around this.
83988402
CallExpr *call;
8399-
auto argLabels = witness->getFullName().getArgumentNames();
8403+
auto argLabels = witness.getDecl()->getFullName().getArgumentNames();
84008404
if (arguments.size() == 1 &&
8401-
(isVariadicWitness(witness) ||
8405+
(isVariadicWitness(witnessFn) ||
84028406
argumentNamesMatch(cs.getType(arguments[0]), argLabels))) {
84038407
call = CallExpr::create(Context, unresolvedDot, arguments[0], {}, {},
84048408
/*hasTrailingClosure=*/false,
@@ -8449,7 +8453,7 @@ Expr *TypeChecker::callWitness(Expr *base, DeclContext *dc,
84498453
/*suppressDiagnostics=*/false);
84508454

84518455
auto choice =
8452-
OverloadChoice(openedFullType, witness, FunctionRefKind::SingleApply);
8456+
OverloadChoice(openedFullType, witnessFn, FunctionRefKind::SingleApply);
84538457
auto memberRef = rewriter.buildMemberRef(
84548458
base, openedFullType, base->getStartLoc(), choice,
84558459
DeclNameLoc(base->getEndLoc()), openedType, dotLocator, dotLocator,

0 commit comments

Comments
 (0)