Skip to content

Commit 1bbb941

Browse files
authored
Merge pull request #8065 from slavapestov/generic-typealias-missing-check
Sema: Actually check generic typealias requirements
2 parents 15431e9 + 4a46023 commit 1bbb941

File tree

2 files changed

+65
-61
lines changed

2 files changed

+65
-61
lines changed

lib/Sema/TypeCheckType.cpp

Lines changed: 52 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -615,88 +615,79 @@ Type TypeChecker::applyUnboundGenericArguments(
615615
if (!resolver)
616616
resolver = &defaultResolver;
617617

618-
// Validate the generic arguments and capture just the types.
619-
SmallVector<Type, 4> genericArgTypes;
620-
for (auto &genericArg : genericArgs) {
621-
// Validate the generic argument.
622-
if (validateType(genericArg, dc, options, resolver, unsatisfiedDependency))
623-
return ErrorType::get(Context);
624-
625-
if (!genericArg.getType())
626-
return nullptr;
618+
auto genericSig = decl->getGenericSignature();
619+
assert(genericSig != nullptr);
627620

628-
genericArgTypes.push_back(genericArg.getType());
629-
}
621+
TypeSubstitutionMap subs;
630622

631-
// If we're completing a generic TypeAlias, then we map the types provided
632-
// onto the underlying type.
633-
if (auto *TAD = dyn_cast<TypeAliasDecl>(decl)) {
634-
TypeSubstitutionMap subs;
623+
// Get the substitutions for outer generic parameters from the parent
624+
// type.
625+
auto *unboundType = type->castTo<UnboundGenericType>();
626+
if (auto parentType = unboundType->getParent())
627+
subs = parentType->getContextSubstitutions(decl->getDeclContext());
635628

636-
// The type should look like SomeNominal<T, U>.Alias<V, W>.
629+
SourceLoc noteLoc = decl->getLoc();
630+
if (noteLoc.isInvalid())
631+
noteLoc = loc;
637632

638-
// Get the substitutions for outer generic parameters from the parent
639-
// type.
640-
auto *unboundType = type->castTo<UnboundGenericType>();
641-
if (auto parentType = unboundType->getParent())
642-
subs = parentType->getContextSubstitutions(TAD->getDeclContext());
633+
// Realize the types of the generic arguments and add them to the
634+
// substitution map.
635+
bool hasTypeParameterOrVariable = false;
636+
for (unsigned i = 0, e = genericArgs.size(); i < e; i++) {
637+
auto &genericArg = genericArgs[i];
643638

644-
// Get the substitutions for the inner parameters.
645-
auto signature = TAD->getGenericSignature();
646-
for (unsigned i = 0, e = genericArgs.size(); i < e; i++) {
647-
auto t = signature->getInnermostGenericParams()[i];
648-
subs[t->getCanonicalType()->castTo<GenericTypeParamType>()] =
649-
genericArgs[i].getType();
650-
}
651-
652-
// Apply substitutions to the interface type of the typealias.
653-
type = TAD->getDeclaredInterfaceType();
654-
return type.subst(QueryTypeSubstitutionMap{subs},
655-
LookUpConformanceInModule(dc->getParentModule()),
656-
SubstFlags::UseErrorType);
657-
}
658-
659-
// Form the bound generic type.
660-
auto *UGT = type->castTo<UnboundGenericType>();
661-
auto *BGT = BoundGenericType::get(cast<NominalTypeDecl>(decl),
662-
UGT->getParent(), genericArgTypes);
639+
// Propagate failure.
640+
if (validateType(genericArg, dc, options, resolver, unsatisfiedDependency))
641+
return ErrorType::get(Context);
663642

664-
// Check protocol conformance.
665-
if (!BGT->hasTypeParameter() && !BGT->hasTypeVariable()) {
666-
SourceLoc noteLoc = decl->getLoc();
667-
if (noteLoc.isInvalid())
668-
noteLoc = loc;
643+
auto origTy = genericSig->getInnermostGenericParams()[i];
644+
auto substTy = genericArg.getType();
669645

670-
// FIXME: Record that we're checking substitutions, so we can't end up
671-
// with infinite recursion.
646+
// Unsatisfied dependency case.
647+
if (!substTy)
648+
return nullptr;
672649

673-
// Check the generic arguments against the generic signature.
674-
auto genericSig = decl->getGenericSignature();
650+
// Enter a substitution.
651+
subs[origTy->getCanonicalType()->castTo<GenericTypeParamType>()] =
652+
substTy;
675653

676-
// Collect the complete set of generic arguments.
677-
assert(genericSig != nullptr);
678-
auto substitutions = BGT->getContextSubstitutions(BGT->getDecl());
654+
hasTypeParameterOrVariable |=
655+
(substTy->hasTypeParameter() || substTy->hasTypeVariable());
656+
}
679657

658+
// Check the generic arguments against the requirements of the declaration's
659+
// generic signature.
660+
if (!hasTypeParameterOrVariable) {
680661
auto result =
681-
checkGenericArguments(dc, loc, noteLoc, UGT, genericSig,
682-
QueryTypeSubstitutionMap{substitutions},
683-
LookUpConformanceInModule{dc->getParentModule()},
684-
unsatisfiedDependency);
662+
checkGenericArguments(dc, loc, noteLoc, unboundType, genericSig,
663+
QueryTypeSubstitutionMap{subs},
664+
LookUpConformanceInModule{dc->getParentModule()},
665+
unsatisfiedDependency);
685666

686667
switch (result) {
687668
case RequirementCheckResult::UnsatisfiedDependency:
688669
return Type();
689670
case RequirementCheckResult::Failure:
690671
return ErrorType::get(Context);
691-
692672
case RequirementCheckResult::Success:
693-
if (useObjectiveCBridgeableConformancesOfArgs(dc, BGT,
694-
unsatisfiedDependency))
695-
return Type();
673+
break;
696674
}
697675
}
698676

699-
return BGT;
677+
// Apply the substitution map to the interface type of the declaration.
678+
type = decl->getDeclaredInterfaceType();
679+
type = type.subst(QueryTypeSubstitutionMap{subs},
680+
LookUpConformanceInModule(dc->getParentModule()),
681+
SubstFlags::UseErrorType);
682+
683+
if (isa<NominalTypeDecl>(decl)) {
684+
if (useObjectiveCBridgeableConformancesOfArgs(
685+
dc, type->castTo<BoundGenericType>(),
686+
unsatisfiedDependency))
687+
return Type();
688+
}
689+
690+
return type;
700691
}
701692

702693
/// \brief Diagnose a use of an unbound generic type.

test/decl/typealias/generic.swift

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,3 +340,16 @@ func f(x: S.G1<Int>, y: S.G2<Int>) {
340340
takesMyType(x: x)
341341
takesMyType(y: y)
342342
}
343+
344+
//
345+
// Generic typealiases with requirements
346+
//
347+
348+
typealias Element<S> = S.Iterator.Element where S : Sequence
349+
350+
func takesInt(_: Element<[Int]>) {}
351+
352+
takesInt(10)
353+
354+
func failsRequirementCheck(_: Element<Int>) {}
355+
// expected-error@-1 {{type 'Int' does not conform to protocol 'Sequence'}}

0 commit comments

Comments
 (0)