Skip to content

Commit e4d306f

Browse files
authored
Merge pull request #17168 from huonw/validate-conditional-type-in-expr
[Sema] More aggressive consideration of conditionally defined types in expressions
2 parents 6ec69b7 + 28460bc commit e4d306f

File tree

8 files changed

+325
-60
lines changed

8 files changed

+325
-60
lines changed

lib/Sema/ConstraintSystem.cpp

Lines changed: 52 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -481,36 +481,74 @@ Type ConstraintSystem::openUnboundGenericType(UnboundGenericType *unbound,
481481
/*unsatisfiedDependency*/nullptr);
482482
}
483483

484-
Type ConstraintSystem::openUnboundGenericType(
485-
Type type,
486-
ConstraintLocatorBuilder locator) {
487-
assert(!type->hasTypeParameter());
484+
static void checkNestedTypeConstraints(ConstraintSystem &cs, Type type,
485+
ConstraintLocatorBuilder locator) {
486+
// If this is a type defined inside of constrainted extension, let's add all
487+
// of the generic requirements to the constraint system to make sure that it's
488+
// something we can use.
489+
GenericTypeDecl *decl = nullptr;
490+
Type parentTy;
491+
SubstitutionMap subMap;
488492

489-
// If this is a generic typealias defined inside of constrainted extension,
490-
// let's add all of the generic requirements to the constraint system to
491-
// make sure that it's something we can use.
492493
if (auto *NAT = dyn_cast<NameAliasType>(type.getPointer())) {
493-
auto *decl = NAT->getDecl();
494-
auto *extension = dyn_cast<ExtensionDecl>(decl->getDeclContext());
495-
auto parentTy = NAT->getParent();
496-
494+
decl = NAT->getDecl();
495+
parentTy = NAT->getParent();
496+
subMap = NAT->getSubstitutionMap();
497+
} else if (auto *AGT = type->getAs<AnyGenericType>()) {
498+
decl = AGT->getDecl();
499+
parentTy = AGT->getParent();
500+
// the context substitution map is fine here, since we can't be adding more
501+
// info than that, unlike a typealias
502+
}
503+
504+
// If this decl is generic, the constraints are handled when the generic
505+
// parameters are applied, so we don't have to handle them here (which makes
506+
// getting the right substitution maps easier).
507+
if (decl && !decl->isGeneric()) {
508+
auto extension = dyn_cast<ExtensionDecl>(decl->getDeclContext());
497509
if (parentTy && extension && extension->isConstrainedExtension()) {
498-
auto subMap = NAT->getSubstitutionMap();
499510
auto contextSubMap = parentTy->getContextSubstitutionMap(
500511
extension->getParentModule(),
501512
extension->getAsNominalTypeOrNominalTypeExtensionContext());
513+
if (!subMap) {
514+
// The substitution map wasn't set above, meaning we should grab the map
515+
// for the extension itself.
516+
subMap = parentTy->getContextSubstitutionMap(
517+
extension->getParentModule(), extension);
518+
}
502519

503-
if (auto *signature = NAT->getGenericSignature()) {
504-
openGenericRequirements(
520+
if (auto *signature = decl->getGenericSignature()) {
521+
cs.openGenericRequirements(
505522
extension, signature, /*skipProtocolSelfConstraint*/ true, locator,
506523
[&](Type type) {
524+
// Why do we look in two substitution maps? We have to use the
525+
// context substitution map to find types, because we need to
526+
// avoid thinking about them when handling the constraints, or all
527+
// the requirements in the signature become tautologies (if the
528+
// extension has 'T == Int', subMap will map T -> Int, so the
529+
// requirement becomes Int == Int no matter what the actual types
530+
// are here). However, we need the conformances for the extension
531+
// because the requirements might look like `T: P, T.U: Q`, where
532+
// U is an associated type of protocol P.
507533
return type.subst(QuerySubstitutionMap{contextSubMap},
508534
LookUpConformanceInSubstitutionMap(subMap),
509535
SubstFlags::UseErrorType);
510536
});
511537
}
512538
}
539+
540+
// And now make sure sure the parent is okay, for things like X<T>.Y.Z.
541+
if (parentTy) {
542+
checkNestedTypeConstraints(cs, parentTy, locator);
543+
}
513544
}
545+
}
546+
547+
Type ConstraintSystem::openUnboundGenericType(
548+
Type type, ConstraintLocatorBuilder locator) {
549+
assert(!type->hasTypeParameter());
550+
551+
checkNestedTypeConstraints(*this, type, locator);
514552

515553
if (!type->hasUnboundGenericType())
516554
return type;

lib/Sema/TypeCheckConstraints.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1215,9 +1215,9 @@ TypeExpr *PreCheckExpression::simplifyNestedTypeExpr(UnresolvedDotExpr *UDE) {
12151215
// If there is no nested type with this name, we have a lookup of
12161216
// a non-type member, so leave the expression as-is.
12171217
if (Result.size() == 1) {
1218-
return TypeExpr::createForMemberDecl(
1219-
DRE->getNameLoc().getBaseNameLoc(), TD,
1220-
NameLoc, Result.front().first);
1218+
return TypeExpr::createForMemberDecl(DRE->getNameLoc().getBaseNameLoc(),
1219+
TD, NameLoc,
1220+
Result.front().Member);
12211221
}
12221222
}
12231223

@@ -1273,7 +1273,7 @@ TypeExpr *PreCheckExpression::simplifyNestedTypeExpr(UnresolvedDotExpr *UDE) {
12731273
// a non-type member, so leave the expression as-is.
12741274
if (Result.size() == 1) {
12751275
return TypeExpr::createForMemberDecl(ITR, NameLoc,
1276-
Result.front().first);
1276+
Result.front().Member);
12771277
}
12781278
}
12791279
}

lib/Sema/TypeCheckNameLookup.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ LookupTypeResult TypeChecker::lookupMemberType(DeclContext *dc,
447447
// Add the type to the result set, so that we can diagnose the
448448
// reference instead of just saying the member does not exist.
449449
if (types.insert(memberType->getCanonicalType()).second)
450-
result.Results.push_back({typeDecl, memberType});
450+
result.Results.push_back({typeDecl, memberType, nullptr});
451451

452452
continue;
453453
}
@@ -492,7 +492,7 @@ LookupTypeResult TypeChecker::lookupMemberType(DeclContext *dc,
492492

493493
// If we haven't seen this type result yet, add it to the result set.
494494
if (types.insert(memberType->getCanonicalType()).second)
495-
result.Results.push_back({typeDecl, memberType});
495+
result.Results.push_back({typeDecl, memberType, nullptr});
496496
}
497497

498498
if (result.Results.empty()) {
@@ -521,19 +521,21 @@ LookupTypeResult TypeChecker::lookupMemberType(DeclContext *dc,
521521

522522
// Use the type witness.
523523
auto concrete = conformance->getConcrete();
524-
Type memberType = concrete->getTypeWitness(assocType, this);
525524

526525
// This is the only case where NormalProtocolConformance::
527526
// getTypeWitnessAndDecl() returns a null type.
528527
if (concrete->getState() ==
529528
ProtocolConformanceState::CheckingTypeWitnesses)
530529
continue;
531530

532-
assert(memberType && "Missing type witness?");
531+
auto typeDecl = concrete->getTypeWitnessAndDecl(assocType, this).second;
533532

534-
// If we haven't seen this type result yet, add it to the result set.
533+
assert(typeDecl && "Missing type witness?");
534+
535+
auto memberType =
536+
substMemberTypeWithBase(dc->getParentModule(), typeDecl, type);
535537
if (types.insert(memberType->getCanonicalType()).second)
536-
result.Results.push_back({assocType, memberType});
538+
result.Results.push_back({typeDecl, memberType, assocType});
537539
}
538540
}
539541

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3098,23 +3098,23 @@ ResolveWitnessResult ConformanceChecker::resolveTypeWitnessViaLookup(
30983098
}
30993099

31003100
// Determine which of the candidates is viable.
3101-
SmallVector<std::pair<TypeDecl *, Type>, 2> viable;
3101+
SmallVector<LookupTypeResultEntry, 2> viable;
31023102
SmallVector<std::pair<TypeDecl *, CheckTypeWitnessResult>, 2> nonViable;
31033103
for (auto candidate : candidates) {
31043104
// Skip nested generic types.
3105-
if (auto *genericDecl = dyn_cast<GenericTypeDecl>(candidate.first))
3105+
if (auto *genericDecl = dyn_cast<GenericTypeDecl>(candidate.Member))
31063106
if (genericDecl->getGenericParams())
31073107
continue;
31083108

31093109
// Skip typealiases with an unbound generic type as their underlying type.
3110-
if (auto *typeAliasDecl = dyn_cast<TypeAliasDecl>(candidate.first))
3110+
if (auto *typeAliasDecl = dyn_cast<TypeAliasDecl>(candidate.Member))
31113111
if (typeAliasDecl->getDeclaredInterfaceType()->is<UnboundGenericType>())
31123112
continue;
31133113

31143114
// Check this type against the protocol requirements.
3115-
if (auto checkResult = checkTypeWitness(TC, DC, Proto, assocType,
3116-
candidate.second)) {
3117-
nonViable.push_back({candidate.first, checkResult});
3115+
if (auto checkResult =
3116+
checkTypeWitness(TC, DC, Proto, assocType, candidate.MemberType)) {
3117+
nonViable.push_back({candidate.Member, checkResult});
31183118
} else {
31193119
viable.push_back(candidate);
31203120
}
@@ -3132,10 +3132,10 @@ ResolveWitnessResult ConformanceChecker::resolveTypeWitnessViaLookup(
31323132

31333133
// If there is a single viable candidate, form a substitution for it.
31343134
if (viable.size() == 1) {
3135-
auto interfaceType = viable.front().second;
3135+
auto interfaceType = viable.front().MemberType;
31363136
if (interfaceType->hasArchetype())
31373137
interfaceType = interfaceType->mapTypeOutOfContext();
3138-
recordTypeWitness(assocType, interfaceType, viable.front().first, true);
3138+
recordTypeWitness(assocType, interfaceType, viable.front().Member, true);
31393139
return ResolveWitnessResult::Success;
31403140
}
31413141

@@ -3151,7 +3151,7 @@ ResolveWitnessResult ConformanceChecker::resolveTypeWitnessViaLookup(
31513151
assocType->getName());
31523152

31533153
for (auto candidate : viable)
3154-
diags.diagnose(candidate.first, diag::protocol_witness_type);
3154+
diags.diagnose(candidate.Member, diag::protocol_witness_type);
31553155
});
31563156

31573157
return ResolveWitnessResult::ExplicitFailed;

lib/Sema/TypeCheckType.cpp

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ static Type getStdlibType(TypeChecker &TC, Type &cached, DeclContext *dc,
123123
TC.Context.getIdentifier(
124124
name));
125125
if (lookup)
126-
cached = lookup.back().second;
126+
cached = lookup.back().MemberType;
127127
}
128128
return cached;
129129
}
@@ -203,7 +203,7 @@ static Type getObjectiveCNominalType(TypeChecker &TC,
203203
if (auto result = TC.lookupMemberType(dc, ModuleType::get(module), TypeName,
204204
lookupOptions)) {
205205
for (auto pair : result) {
206-
if (auto nominal = dyn_cast<NominalTypeDecl>(pair.first)) {
206+
if (auto nominal = dyn_cast<NominalTypeDecl>(pair.Member)) {
207207
cache = nominal->getDeclaredType();
208208
return cache;
209209
}
@@ -934,14 +934,14 @@ static Type diagnoseUnknownType(TypeChecker &tc, DeclContext *dc,
934934
relookupOptions);
935935
if (inaccessibleMembers) {
936936
// FIXME: What if the unviable candidates have different levels of access?
937-
const TypeDecl *first = inaccessibleMembers.front().first;
937+
const TypeDecl *first = inaccessibleMembers.front().Member;
938938
tc.diagnose(comp->getIdLoc(), diag::candidate_inaccessible,
939939
comp->getIdentifier(), first->getFormalAccess());
940940

941941
// FIXME: If any of the candidates (usually just one) are in the same module
942942
// we could offer a fix-it.
943943
for (auto lookupResult : inaccessibleMembers)
944-
tc.diagnose(lookupResult.first, diag::kind_declared_here,
944+
tc.diagnose(lookupResult.Member, diag::kind_declared_here,
945945
DescriptiveDeclKind::Type);
946946

947947
// Don't try to recover here; we'll get more access-related diagnostics
@@ -1254,7 +1254,8 @@ static Type resolveNestedIdentTypeComponent(
12541254
bool diagnoseErrors,
12551255
GenericTypeResolver *resolver,
12561256
UnsatisfiedDependency *unsatisfiedDependency) {
1257-
auto maybeDiagnoseBadMemberType = [&](TypeDecl *member, Type memberType) {
1257+
auto maybeDiagnoseBadMemberType = [&](TypeDecl *member, Type memberType,
1258+
AssociatedTypeDecl *inferredAssocType) {
12581259
// Diagnose invalid cases.
12591260
if (TC.isUnsupportedMemberTypeAccess(parentTy, member)) {
12601261
if (diagnoseErrors) {
@@ -1285,20 +1286,21 @@ static Type resolveNestedIdentTypeComponent(
12851286
}
12861287
}
12871288

1289+
// Diagnose a bad conformance reference if we need to.
1290+
if (inferredAssocType && diagnoseErrors && memberType &&
1291+
memberType->hasError()) {
1292+
maybeDiagnoseBadConformanceRef(TC, DC, parentTy, comp->getLoc(),
1293+
inferredAssocType);
1294+
}
1295+
12881296
// If we found a reference to an associated type or other member type that
12891297
// was marked invalid, just return ErrorType to silence downstream errors.
12901298
if (member->isInvalid())
12911299
return ErrorType::get(TC.Context);
12921300

1293-
// Diagnose a bad conformance reference if we need to.
1294-
if (isa<AssociatedTypeDecl>(member) && diagnoseErrors &&
1295-
memberType && memberType->hasError()) {
1296-
maybeDiagnoseBadConformanceRef(TC, DC, parentTy, comp->getLoc(),
1297-
cast<AssociatedTypeDecl>(member));
1298-
}
1299-
13001301
// At this point, we need to have resolved the type of the member.
1301-
if (!memberType || memberType->hasError()) return memberType;
1302+
if (!memberType || memberType->hasError())
1303+
return memberType;
13021304

13031305
// If there are generic arguments, apply them now.
13041306
if (!options.contains(TypeResolutionFlags::ResolveStructure)) {
@@ -1335,7 +1337,7 @@ static Type resolveNestedIdentTypeComponent(
13351337
if (auto *typeDecl = comp->getBoundDecl()) {
13361338
auto memberType = TC.substMemberTypeWithBase(DC->getParentModule(),
13371339
typeDecl, parentTy);
1338-
return maybeDiagnoseBadMemberType(typeDecl, memberType);
1340+
return maybeDiagnoseBadMemberType(typeDecl, memberType, nullptr);
13391341
}
13401342

13411343
// Phase 1: Find and bind the component decl.
@@ -1390,6 +1392,7 @@ static Type resolveNestedIdentTypeComponent(
13901392
// If we didn't find anything, complain.
13911393
Type memberType;
13921394
TypeDecl *member = nullptr;
1395+
AssociatedTypeDecl *inferredAssocType = nullptr;
13931396
if (!memberTypes) {
13941397
// If we're not allowed to complain or we couldn't fix the
13951398
// source, bail out.
@@ -1403,12 +1406,13 @@ static Type resolveNestedIdentTypeComponent(
14031406
if (!member)
14041407
return ErrorType::get(TC.Context);
14051408
} else {
1406-
memberType = memberTypes.back().second;
1407-
member = memberTypes.back().first;
1409+
memberType = memberTypes.back().MemberType;
1410+
member = memberTypes.back().Member;
1411+
inferredAssocType = memberTypes.back().InferredAssociatedType;
14081412
comp->setValue(member, nullptr);
14091413
}
14101414

1411-
return maybeDiagnoseBadMemberType(member, memberType);
1415+
return maybeDiagnoseBadMemberType(member, memberType, inferredAssocType);
14121416
}
14131417

14141418
static Type resolveIdentTypeComponent(

lib/Sema/TypeChecker.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -942,8 +942,7 @@ void TypeChecker::diagnoseAmbiguousMemberType(Type baseTy,
942942
.highlight(baseRange);
943943
}
944944
for (const auto &member : lookup) {
945-
diagnose(member.first, diag::found_candidate_type,
946-
member.second);
945+
diagnose(member.Member, diag::found_candidate_type, member.MemberType);
947946
}
948947
}
949948

lib/Sema/TypeChecker.h

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -151,30 +151,37 @@ class LookupResult {
151151
filter(llvm::function_ref<bool(LookupResultEntry, /*isOuter*/ bool)> pred);
152152
};
153153

154+
/// An individual result of a name lookup for a type.
155+
struct LookupTypeResultEntry {
156+
TypeDecl *Member;
157+
Type MemberType;
158+
/// The associated type that the Member/MemberType were inferred for, but only
159+
/// if inference happened when creating this entry.
160+
AssociatedTypeDecl *InferredAssociatedType;
161+
};
162+
154163
/// The result of name lookup for types.
155164
class LookupTypeResult {
156165
/// The set of results found.
157-
SmallVector<std::pair<TypeDecl *, Type>, 4> Results;
166+
SmallVector<LookupTypeResultEntry, 4> Results;
158167

159168
friend class TypeChecker;
160169

161170
public:
162-
using iterator = SmallVectorImpl<std::pair<TypeDecl *, Type>>::iterator;
171+
using iterator = SmallVectorImpl<LookupTypeResultEntry>::iterator;
163172
iterator begin() { return Results.begin(); }
164173
iterator end() { return Results.end(); }
165174
unsigned size() const { return Results.size(); }
166175

167-
std::pair<TypeDecl *, Type> operator[](unsigned index) const {
176+
LookupTypeResultEntry operator[](unsigned index) const {
168177
return Results[index];
169178
}
170179

171-
std::pair<TypeDecl *, Type> front() const { return Results.front(); }
172-
std::pair<TypeDecl *, Type> back() const { return Results.back(); }
180+
LookupTypeResultEntry front() const { return Results.front(); }
181+
LookupTypeResultEntry back() const { return Results.back(); }
173182

174183
/// Add a result to the set of results.
175-
void addResult(std::pair<TypeDecl *, Type> result) {
176-
Results.push_back(result);
177-
}
184+
void addResult(LookupTypeResultEntry result) { Results.push_back(result); }
178185

179186
/// \brief Determine whether this result set is ambiguous.
180187
bool isAmbiguous() const {

0 commit comments

Comments
 (0)