Skip to content

Commit 00e30dd

Browse files
authored
Merge pull request #25291 from xedin/refactor-open-generic
[ConstraintSystem] Refactor openGeneric/openFunctionType
2 parents 07e6940 + c1087b9 commit 00e30dd

File tree

6 files changed

+118
-164
lines changed

6 files changed

+118
-164
lines changed

include/swift/AST/Types.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3172,6 +3172,7 @@ class GenericFunctionType final : public AnyFunctionType,
31723172
/// Substitute the given generic arguments into this generic
31733173
/// function type and return the resulting non-generic type.
31743174
FunctionType *substGenericArgs(SubstitutionMap subs);
3175+
FunctionType *substGenericArgs(llvm::function_ref<Type(Type)> substFn) const;
31753176

31763177
void Profile(llvm::FoldingSetNodeID &ID) {
31773178
Profile(ID, getGenericSignature(), getParams(), getResult(),

lib/AST/Type.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2878,6 +2878,22 @@ GenericFunctionType::substGenericArgs(SubstitutionMap subs) {
28782878
substFn->getResult(), getExtInfo());
28792879
}
28802880

2881+
FunctionType *GenericFunctionType::substGenericArgs(
2882+
llvm::function_ref<Type(Type)> substFn) const {
2883+
llvm::SmallVector<AnyFunctionType::Param, 4> params;
2884+
params.reserve(getNumParams());
2885+
2886+
llvm::transform(getParams(), std::back_inserter(params),
2887+
[&](const AnyFunctionType::Param &param) {
2888+
return param.withType(substFn(param.getPlainType()));
2889+
});
2890+
2891+
auto resultTy = substFn(getResult());
2892+
2893+
// Build the resulting (non-generic) function type.
2894+
return FunctionType::get(params, resultTy, getExtInfo());
2895+
}
2896+
28812897
CanFunctionType
28822898
CanGenericFunctionType::substGenericArgs(SubstitutionMap subs) const {
28832899
return cast<FunctionType>(

lib/Sema/CSRanking.cpp

Lines changed: 21 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -341,10 +341,7 @@ static bool isProtocolExtensionAsSpecializedAs(TypeChecker &tc,
341341
// the second protocol extension.
342342
ConstraintSystem cs(tc, dc1, None);
343343
OpenedTypeMap replacements;
344-
cs.openGeneric(dc2, dc2, sig2,
345-
/*skipProtocolSelfConstraint=*/false,
346-
ConstraintLocatorBuilder(nullptr),
347-
replacements);
344+
cs.openGeneric(dc2, sig2, ConstraintLocatorBuilder(nullptr), replacements);
348345

349346
// Bind the 'Self' type from the first extension to the type parameter from
350347
// opening 'Self' of the second extension.
@@ -507,54 +504,33 @@ static bool isDeclAsSpecializedAs(TypeChecker &tc, DeclContext *dc,
507504
checkKind = CheckAll;
508505
}
509506

507+
auto openType = [&](ConstraintSystem &cs, DeclContext *innerDC,
508+
DeclContext *outerDC, Type type,
509+
OpenedTypeMap &replacements,
510+
ConstraintLocator *locator) -> Type {
511+
if (auto *funcType = type->getAs<AnyFunctionType>()) {
512+
return cs.openFunctionType(funcType, locator, replacements, outerDC);
513+
}
514+
515+
cs.openGeneric(outerDC, innerDC->getGenericSignatureOfContext(),
516+
locator, replacements);
517+
518+
return cs.openType(type, replacements);
519+
};
520+
510521
// Construct a constraint system to compare the two declarations.
511522
ConstraintSystem cs(tc, dc, ConstraintSystemOptions());
512523
bool knownNonSubtype = false;
513524

514-
auto locator = cs.getConstraintLocator(nullptr);
525+
auto *locator = cs.getConstraintLocator(nullptr);
515526
// FIXME: Locator when anchored on a declaration.
516527
// Get the type of a reference to the second declaration.
517-
OpenedTypeMap unused;
518-
Type openedType2;
519-
if (auto *funcType = type2->getAs<AnyFunctionType>()) {
520-
openedType2 = cs.openFunctionType(
521-
funcType, /*numArgumentLabelsToRemove=*/0, locator,
522-
/*replacements=*/unused,
523-
innerDC2,
524-
outerDC2,
525-
/*skipProtocolSelfConstraint=*/false);
526-
} else {
527-
cs.openGeneric(innerDC2,
528-
outerDC2,
529-
innerDC2->getGenericSignatureOfContext(),
530-
/*skipProtocolSelfConstraint=*/false,
531-
locator,
532-
unused);
533-
534-
openedType2 = cs.openType(type2, unused);
535-
}
536528

537-
// Get the type of a reference to the first declaration, swapping in
538-
// archetypes for the dependent types.
539-
OpenedTypeMap replacements;
540-
Type openedType1;
541-
if (auto *funcType = type1->getAs<AnyFunctionType>()) {
542-
openedType1 = cs.openFunctionType(
543-
funcType, /*numArgumentLabelsToRemove=*/0, locator,
544-
replacements,
545-
innerDC1,
546-
outerDC1,
547-
/*skipProtocolSelfConstraint=*/false);
548-
} else {
549-
cs.openGeneric(innerDC1,
550-
outerDC1,
551-
innerDC1->getGenericSignatureOfContext(),
552-
/*skipProtocolSelfConstraint=*/false,
553-
locator,
554-
replacements);
555-
556-
openedType1 = cs.openType(type1, replacements);
557-
}
529+
OpenedTypeMap unused, replacements;
530+
auto openedType2 =
531+
openType(cs, innerDC1, outerDC2, type2, unused, locator);
532+
auto openedType1 =
533+
openType(cs, innerDC2, outerDC1, type1, replacements, locator);
558534

559535
for (const auto &replacement : replacements) {
560536
if (auto mapped = innerDC1->mapTypeIntoContext(replacement.first)) {

lib/Sema/CSSimplify.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4885,10 +4885,8 @@ ConstraintSystem::simplifyOpaqueUnderlyingTypeConstraint(Type type1, Type type2,
48854885
// corresponding to the underlying type should be the constraints on the
48864886
// underlying return type.
48874887
OpenedTypeMap replacements;
4888-
openGeneric(nullptr, DC, opaque2->getBoundSignature(),
4889-
/*skip self*/ false,
4890-
locator, replacements);
4891-
4888+
openGeneric(DC, opaque2->getBoundSignature(), locator, replacements);
4889+
48924890
auto underlyingTyVar = openType(opaque2->getInterfaceType(),
48934891
replacements);
48944892
assert(underlyingTyVar);

lib/Sema/ConstraintSystem.cpp

Lines changed: 63 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -465,12 +465,8 @@ Type ConstraintSystem::openUnboundGenericType(UnboundGenericType *unbound,
465465
}
466466

467467
// Open up the generic type.
468-
openGeneric(unboundDecl->getInnermostDeclContext(),
469-
unboundDecl->getDeclContext(),
470-
unboundDecl->getGenericSignature(),
471-
/*skipProtocolSelfConstraint=*/false,
472-
locator,
473-
replacements);
468+
openGeneric(unboundDecl->getDeclContext(), unboundDecl->getGenericSignature(),
469+
locator, replacements);
474470

475471
if (parentTy) {
476472
auto subs = parentTy->getContextSubstitutions(
@@ -633,45 +629,25 @@ Type ConstraintSystem::openType(Type type, OpenedTypeMap &replacements) {
633629
});
634630
}
635631

636-
Type ConstraintSystem::openFunctionType(
632+
FunctionType *ConstraintSystem::openFunctionType(
637633
AnyFunctionType *funcType,
638-
unsigned numArgumentLabelsToRemove,
639634
ConstraintLocatorBuilder locator,
640635
OpenedTypeMap &replacements,
641-
DeclContext *innerDC,
642-
DeclContext *outerDC,
643-
bool skipProtocolSelfConstraint,
644-
bool skipGenericRequirements) {
645-
Type type;
646-
636+
DeclContext *outerDC) {
647637
if (auto *genericFn = funcType->getAs<GenericFunctionType>()) {
648-
// Open up the generic parameters and requirements.
649-
openGeneric(innerDC,
650-
outerDC,
651-
genericFn->getGenericSignature(),
652-
skipProtocolSelfConstraint,
653-
locator,
654-
replacements,
655-
skipGenericRequirements);
656-
657-
// Transform the parameters and output type.
658-
llvm::SmallVector<AnyFunctionType::Param, 4> openedParams;
659-
openedParams.reserve(genericFn->getNumParams());
660-
for (const auto &param : genericFn->getParams()) {
661-
auto type = openType(param.getPlainType(), replacements);
662-
openedParams.push_back(AnyFunctionType::Param(type, param.getLabel(),
663-
param.getParameterFlags()));
664-
}
638+
auto *signature = genericFn->getGenericSignature();
665639

666-
auto resultTy = openType(genericFn->getResult(), replacements);
640+
openGenericParameters(outerDC, signature, replacements, locator);
667641

668-
// Build the resulting (non-generic) function type.
669-
funcType = FunctionType::get(
670-
openedParams, resultTy,
671-
FunctionType::ExtInfo().withThrows(genericFn->throws()));
642+
openGenericRequirements(
643+
outerDC, signature, /*skipProtocolSelfConstraint=*/false, locator,
644+
[&](Type type) -> Type { return openType(type, replacements); });
645+
646+
funcType = genericFn->substGenericArgs(
647+
[&](Type type) { return openType(type, replacements); });
672648
}
673649

674-
return funcType->removeArgumentLabels(numArgumentLabelsToRemove);
650+
return funcType->castTo<FunctionType>();
675651
}
676652

677653
Optional<Type> ConstraintSystem::isArrayType(Type type) {
@@ -934,14 +910,9 @@ ConstraintSystem::getTypeOfReference(ValueDecl *value,
934910

935911
OpenedTypeMap replacements;
936912

937-
auto openedType = openFunctionType(
938-
func->getInterfaceType()->castTo<AnyFunctionType>(),
939-
/*numArgumentLabelsToRemove=*/0,
940-
locator, replacements,
941-
func->getInnermostDeclContext(),
942-
func->getDeclContext(),
943-
/*skipProtocolSelfConstraint=*/false);
944-
auto openedFnType = openedType->castTo<FunctionType>();
913+
auto openedType =
914+
openFunctionType(func->getInterfaceType()->castTo<AnyFunctionType>(),
915+
locator, replacements, func->getDeclContext());
945916

946917
// If we opened up any type variables, record the replacements.
947918
recordOpenedTypes(locator, replacements);
@@ -950,36 +921,32 @@ ConstraintSystem::getTypeOfReference(ValueDecl *value,
950921
// DynamicSelf with the actual object type.
951922
if (!func->getDeclContext()->getSelfProtocolDecl()) {
952923
if (func->hasDynamicSelf()) {
953-
auto params = openedFnType->getParams();
924+
auto params = openedType->getParams();
954925
assert(params.size() == 1);
955926
Type selfTy = params.front().getPlainType()->getMetatypeInstanceType();
956-
openedType = openedType->replaceCovariantResultType(selfTy, 2);
957-
openedFnType = openedType->castTo<FunctionType>();
927+
openedType = openedType->replaceCovariantResultType(selfTy, 2)
928+
->castTo<FunctionType>();
958929
}
959930
} else {
960-
openedType = openedType->eraseDynamicSelfType();
961-
openedFnType = openedType->castTo<FunctionType>();
931+
openedType = openedType->eraseDynamicSelfType()->castTo<FunctionType>();
962932
}
963933

964934
// The reference implicitly binds 'self'.
965-
return { openedType, openedFnType->getResult() };
935+
return {openedType, openedType->getResult()};
966936
}
967937

968938
// Unqualified reference to a local or global function.
969939
if (auto funcDecl = dyn_cast<AbstractFunctionDecl>(value)) {
970940
OpenedTypeMap replacements;
971941

972942
auto funcType = funcDecl->getInterfaceType()->castTo<AnyFunctionType>();
973-
auto openedType =
974-
openFunctionType(
975-
funcType,
976-
getNumRemovedArgumentLabels(TC, funcDecl,
977-
/*isCurriedInstanceReference=*/false,
978-
functionRefKind),
979-
locator, replacements,
980-
funcDecl->getInnermostDeclContext(),
981-
funcDecl->getDeclContext(),
982-
/*skipProtocolSelfConstraint=*/false);
943+
auto numLabelsToRemove = getNumRemovedArgumentLabels(
944+
TC, funcDecl,
945+
/*isCurriedInstanceReference=*/false, functionRefKind);
946+
947+
auto openedType = openFunctionType(funcType, locator, replacements,
948+
funcDecl->getDeclContext())
949+
->removeArgumentLabels(numLabelsToRemove);
983950

984951
// If we opened up any type variables, record the replacements.
985952
recordOpenedTypes(locator, replacements);
@@ -1103,46 +1070,44 @@ static void bindArchetypesFromContext(
11031070
}
11041071

11051072
void ConstraintSystem::openGeneric(
1106-
DeclContext *innerDC,
11071073
DeclContext *outerDC,
11081074
GenericSignature *sig,
1109-
bool skipProtocolSelfConstraint,
11101075
ConstraintLocatorBuilder locator,
1111-
OpenedTypeMap &replacements,
1112-
bool skipGenericRequirements) {
1076+
OpenedTypeMap &replacements) {
11131077
if (sig == nullptr)
11141078
return;
11151079

1116-
auto locatorPtr = getConstraintLocator(locator);
1080+
openGenericParameters(outerDC, sig, replacements, locator);
1081+
1082+
// Add the requirements as constraints.
1083+
openGenericRequirements(
1084+
outerDC, sig, /*skipProtocolSelfConstraint=*/false, locator,
1085+
[&](Type type) { return openType(type, replacements); });
1086+
}
1087+
1088+
void ConstraintSystem::openGenericParameters(DeclContext *outerDC,
1089+
GenericSignature *sig,
1090+
OpenedTypeMap &replacements,
1091+
ConstraintLocatorBuilder locator) {
1092+
assert(sig);
11171093

11181094
// Create the type variables for the generic parameters.
11191095
for (auto gp : sig->getGenericParams()) {
1120-
locatorPtr = getConstraintLocator(
1121-
locator.withPathElement(LocatorPathElt(gp)));
1122-
1123-
auto typeVar = createTypeVariable(locatorPtr,
1124-
TVO_PrefersSubtypeBinding);
1125-
auto result = replacements.insert(
1126-
std::make_pair(cast<GenericTypeParamType>(gp->getCanonicalType()),
1127-
typeVar));
1096+
auto *paramLocator =
1097+
getConstraintLocator(locator.withPathElement(LocatorPathElt(gp)));
1098+
1099+
auto typeVar = createTypeVariable(paramLocator, TVO_PrefersSubtypeBinding);
1100+
auto result = replacements.insert(std::make_pair(
1101+
cast<GenericTypeParamType>(gp->getCanonicalType()), typeVar));
1102+
11281103
assert(result.second);
1129-
(void) result;
1104+
(void)result;
11301105
}
11311106

1132-
// Remember that any new constraints generated by opening this generic are
1133-
// due to the opening.
1134-
locatorPtr = getConstraintLocator(
1107+
auto *baseLocator = getConstraintLocator(
11351108
locator.withPathElement(LocatorPathElt::getOpenedGeneric(sig)));
11361109

1137-
bindArchetypesFromContext(*this, outerDC, locatorPtr, replacements);
1138-
1139-
if (skipGenericRequirements)
1140-
return;
1141-
1142-
// Add the requirements as constraints.
1143-
openGenericRequirements(
1144-
outerDC, sig, skipProtocolSelfConstraint, locator,
1145-
[&](Type type) { return openType(type, replacements); });
1110+
bindArchetypesFromContext(*this, outerDC, baseLocator, replacements);
11461111
}
11471112

11481113
void ConstraintSystem::openGenericRequirements(
@@ -1331,10 +1296,17 @@ ConstraintSystem::getTypeOfMemberReference(
13311296

13321297
// While opening member function type, let's delay opening requirements
13331298
// to allow contextual types to affect the situation.
1334-
openedType = openFunctionType(funcType, numRemovedArgumentLabels,
1335-
locator, replacements, innerDC, outerDC,
1336-
/*skipProtocolSelfConstraint=*/true,
1337-
/*skipGenericRequirements=*/true);
1299+
if (auto *genericFn = funcType->getAs<GenericFunctionType>()) {
1300+
openGenericParameters(outerDC, genericFn->getGenericSignature(),
1301+
replacements, locator);
1302+
1303+
openedType = genericFn->substGenericArgs(
1304+
[&](Type type) { return openType(type, replacements); });
1305+
} else {
1306+
openedType = funcType;
1307+
}
1308+
1309+
openedType = openedType->removeArgumentLabels(numRemovedArgumentLabels);
13381310

13391311
if (!outerDC->getSelfProtocolDecl()) {
13401312
// Class methods returning Self as well as constructors get the

0 commit comments

Comments
 (0)