Skip to content

[ConstraintSystem] Refactor openGeneric/openFunctionType #25291

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 8, 2019
1 change: 1 addition & 0 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -3172,6 +3172,7 @@ class GenericFunctionType final : public AnyFunctionType,
/// Substitute the given generic arguments into this generic
/// function type and return the resulting non-generic type.
FunctionType *substGenericArgs(SubstitutionMap subs);
FunctionType *substGenericArgs(llvm::function_ref<Type(Type)> substFn) const;

void Profile(llvm::FoldingSetNodeID &ID) {
Profile(ID, getGenericSignature(), getParams(), getResult(),
Expand Down
16 changes: 16 additions & 0 deletions lib/AST/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2878,6 +2878,22 @@ GenericFunctionType::substGenericArgs(SubstitutionMap subs) {
substFn->getResult(), getExtInfo());
}

FunctionType *GenericFunctionType::substGenericArgs(
llvm::function_ref<Type(Type)> substFn) const {
llvm::SmallVector<AnyFunctionType::Param, 4> params;
params.reserve(getNumParams());

llvm::transform(getParams(), std::back_inserter(params),
[&](const AnyFunctionType::Param &param) {
return param.withType(substFn(param.getPlainType()));
});

auto resultTy = substFn(getResult());

// Build the resulting (non-generic) function type.
return FunctionType::get(params, resultTy, getExtInfo());
}

CanFunctionType
CanGenericFunctionType::substGenericArgs(SubstitutionMap subs) const {
return cast<FunctionType>(
Expand Down
66 changes: 21 additions & 45 deletions lib/Sema/CSRanking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,10 +341,7 @@ static bool isProtocolExtensionAsSpecializedAs(TypeChecker &tc,
// the second protocol extension.
ConstraintSystem cs(tc, dc1, None);
OpenedTypeMap replacements;
cs.openGeneric(dc2, dc2, sig2,
/*skipProtocolSelfConstraint=*/false,
ConstraintLocatorBuilder(nullptr),
replacements);
cs.openGeneric(dc2, sig2, ConstraintLocatorBuilder(nullptr), replacements);

// Bind the 'Self' type from the first extension to the type parameter from
// opening 'Self' of the second extension.
Expand Down Expand Up @@ -507,54 +504,33 @@ static bool isDeclAsSpecializedAs(TypeChecker &tc, DeclContext *dc,
checkKind = CheckAll;
}

auto openType = [&](ConstraintSystem &cs, DeclContext *innerDC,
DeclContext *outerDC, Type type,
OpenedTypeMap &replacements,
ConstraintLocator *locator) -> Type {
if (auto *funcType = type->getAs<AnyFunctionType>()) {
return cs.openFunctionType(funcType, locator, replacements, outerDC);
}

cs.openGeneric(outerDC, innerDC->getGenericSignatureOfContext(),
locator, replacements);

return cs.openType(type, replacements);
};

// Construct a constraint system to compare the two declarations.
ConstraintSystem cs(tc, dc, ConstraintSystemOptions());
bool knownNonSubtype = false;

auto locator = cs.getConstraintLocator(nullptr);
auto *locator = cs.getConstraintLocator(nullptr);
// FIXME: Locator when anchored on a declaration.
// Get the type of a reference to the second declaration.
OpenedTypeMap unused;
Type openedType2;
if (auto *funcType = type2->getAs<AnyFunctionType>()) {
openedType2 = cs.openFunctionType(
funcType, /*numArgumentLabelsToRemove=*/0, locator,
/*replacements=*/unused,
innerDC2,
outerDC2,
/*skipProtocolSelfConstraint=*/false);
} else {
cs.openGeneric(innerDC2,
outerDC2,
innerDC2->getGenericSignatureOfContext(),
/*skipProtocolSelfConstraint=*/false,
locator,
unused);

openedType2 = cs.openType(type2, unused);
}

// Get the type of a reference to the first declaration, swapping in
// archetypes for the dependent types.
OpenedTypeMap replacements;
Type openedType1;
if (auto *funcType = type1->getAs<AnyFunctionType>()) {
openedType1 = cs.openFunctionType(
funcType, /*numArgumentLabelsToRemove=*/0, locator,
replacements,
innerDC1,
outerDC1,
/*skipProtocolSelfConstraint=*/false);
} else {
cs.openGeneric(innerDC1,
outerDC1,
innerDC1->getGenericSignatureOfContext(),
/*skipProtocolSelfConstraint=*/false,
locator,
replacements);

openedType1 = cs.openType(type1, replacements);
}
OpenedTypeMap unused, replacements;
auto openedType2 =
openType(cs, innerDC1, outerDC2, type2, unused, locator);
auto openedType1 =
openType(cs, innerDC2, outerDC1, type1, replacements, locator);

for (const auto &replacement : replacements) {
if (auto mapped = innerDC1->mapTypeIntoContext(replacement.first)) {
Expand Down
6 changes: 2 additions & 4 deletions lib/Sema/CSSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4883,10 +4883,8 @@ ConstraintSystem::simplifyOpaqueUnderlyingTypeConstraint(Type type1, Type type2,
// corresponding to the underlying type should be the constraints on the
// underlying return type.
OpenedTypeMap replacements;
openGeneric(nullptr, DC, opaque2->getBoundSignature(),
/*skip self*/ false,
locator, replacements);

openGeneric(DC, opaque2->getBoundSignature(), locator, replacements);

auto underlyingTyVar = openType(opaque2->getInterfaceType(),
replacements);
assert(underlyingTyVar);
Expand Down
154 changes: 63 additions & 91 deletions lib/Sema/ConstraintSystem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -465,12 +465,8 @@ Type ConstraintSystem::openUnboundGenericType(UnboundGenericType *unbound,
}

// Open up the generic type.
openGeneric(unboundDecl->getInnermostDeclContext(),
unboundDecl->getDeclContext(),
unboundDecl->getGenericSignature(),
/*skipProtocolSelfConstraint=*/false,
locator,
replacements);
openGeneric(unboundDecl->getDeclContext(), unboundDecl->getGenericSignature(),
locator, replacements);

if (parentTy) {
auto subs = parentTy->getContextSubstitutions(
Expand Down Expand Up @@ -633,45 +629,25 @@ Type ConstraintSystem::openType(Type type, OpenedTypeMap &replacements) {
});
}

Type ConstraintSystem::openFunctionType(
FunctionType *ConstraintSystem::openFunctionType(
AnyFunctionType *funcType,
unsigned numArgumentLabelsToRemove,
ConstraintLocatorBuilder locator,
OpenedTypeMap &replacements,
DeclContext *innerDC,
DeclContext *outerDC,
bool skipProtocolSelfConstraint,
bool skipGenericRequirements) {
Type type;

DeclContext *outerDC) {
if (auto *genericFn = funcType->getAs<GenericFunctionType>()) {
// Open up the generic parameters and requirements.
openGeneric(innerDC,
outerDC,
genericFn->getGenericSignature(),
skipProtocolSelfConstraint,
locator,
replacements,
skipGenericRequirements);

// Transform the parameters and output type.
llvm::SmallVector<AnyFunctionType::Param, 4> openedParams;
openedParams.reserve(genericFn->getNumParams());
for (const auto &param : genericFn->getParams()) {
auto type = openType(param.getPlainType(), replacements);
openedParams.push_back(AnyFunctionType::Param(type, param.getLabel(),
param.getParameterFlags()));
}
auto *signature = genericFn->getGenericSignature();

auto resultTy = openType(genericFn->getResult(), replacements);
openGenericParameters(outerDC, signature, replacements, locator);

// Build the resulting (non-generic) function type.
funcType = FunctionType::get(
openedParams, resultTy,
FunctionType::ExtInfo().withThrows(genericFn->throws()));
openGenericRequirements(
outerDC, signature, /*skipProtocolSelfConstraint=*/false, locator,
[&](Type type) -> Type { return openType(type, replacements); });

funcType = genericFn->substGenericArgs(
[&](Type type) { return openType(type, replacements); });
}

return funcType->removeArgumentLabels(numArgumentLabelsToRemove);
return funcType->castTo<FunctionType>();
}

Optional<Type> ConstraintSystem::isArrayType(Type type) {
Expand Down Expand Up @@ -934,14 +910,9 @@ ConstraintSystem::getTypeOfReference(ValueDecl *value,

OpenedTypeMap replacements;

auto openedType = openFunctionType(
func->getInterfaceType()->castTo<AnyFunctionType>(),
/*numArgumentLabelsToRemove=*/0,
locator, replacements,
func->getInnermostDeclContext(),
func->getDeclContext(),
/*skipProtocolSelfConstraint=*/false);
auto openedFnType = openedType->castTo<FunctionType>();
auto openedType =
openFunctionType(func->getInterfaceType()->castTo<AnyFunctionType>(),
locator, replacements, func->getDeclContext());

// If we opened up any type variables, record the replacements.
recordOpenedTypes(locator, replacements);
Expand All @@ -950,36 +921,32 @@ ConstraintSystem::getTypeOfReference(ValueDecl *value,
// DynamicSelf with the actual object type.
if (!func->getDeclContext()->getSelfProtocolDecl()) {
if (func->hasDynamicSelf()) {
auto params = openedFnType->getParams();
auto params = openedType->getParams();
assert(params.size() == 1);
Type selfTy = params.front().getPlainType()->getMetatypeInstanceType();
openedType = openedType->replaceCovariantResultType(selfTy, 2);
openedFnType = openedType->castTo<FunctionType>();
openedType = openedType->replaceCovariantResultType(selfTy, 2)
->castTo<FunctionType>();
}
} else {
openedType = openedType->eraseDynamicSelfType();
openedFnType = openedType->castTo<FunctionType>();
openedType = openedType->eraseDynamicSelfType()->castTo<FunctionType>();
}

// The reference implicitly binds 'self'.
return { openedType, openedFnType->getResult() };
return {openedType, openedType->getResult()};
}

// Unqualified reference to a local or global function.
if (auto funcDecl = dyn_cast<AbstractFunctionDecl>(value)) {
OpenedTypeMap replacements;

auto funcType = funcDecl->getInterfaceType()->castTo<AnyFunctionType>();
auto openedType =
openFunctionType(
funcType,
getNumRemovedArgumentLabels(TC, funcDecl,
/*isCurriedInstanceReference=*/false,
functionRefKind),
locator, replacements,
funcDecl->getInnermostDeclContext(),
funcDecl->getDeclContext(),
/*skipProtocolSelfConstraint=*/false);
auto numLabelsToRemove = getNumRemovedArgumentLabels(
TC, funcDecl,
/*isCurriedInstanceReference=*/false, functionRefKind);

auto openedType = openFunctionType(funcType, locator, replacements,
funcDecl->getDeclContext())
->removeArgumentLabels(numLabelsToRemove);

// If we opened up any type variables, record the replacements.
recordOpenedTypes(locator, replacements);
Expand Down Expand Up @@ -1103,46 +1070,44 @@ static void bindArchetypesFromContext(
}

void ConstraintSystem::openGeneric(
DeclContext *innerDC,
DeclContext *outerDC,
GenericSignature *sig,
bool skipProtocolSelfConstraint,
ConstraintLocatorBuilder locator,
OpenedTypeMap &replacements,
bool skipGenericRequirements) {
OpenedTypeMap &replacements) {
if (sig == nullptr)
return;

auto locatorPtr = getConstraintLocator(locator);
openGenericParameters(outerDC, sig, replacements, locator);

// Add the requirements as constraints.
openGenericRequirements(
outerDC, sig, /*skipProtocolSelfConstraint=*/false, locator,
[&](Type type) { return openType(type, replacements); });
}

void ConstraintSystem::openGenericParameters(DeclContext *outerDC,
GenericSignature *sig,
OpenedTypeMap &replacements,
ConstraintLocatorBuilder locator) {
assert(sig);

// Create the type variables for the generic parameters.
for (auto gp : sig->getGenericParams()) {
locatorPtr = getConstraintLocator(
locator.withPathElement(LocatorPathElt(gp)));

auto typeVar = createTypeVariable(locatorPtr,
TVO_PrefersSubtypeBinding);
auto result = replacements.insert(
std::make_pair(cast<GenericTypeParamType>(gp->getCanonicalType()),
typeVar));
auto *paramLocator =
getConstraintLocator(locator.withPathElement(LocatorPathElt(gp)));

auto typeVar = createTypeVariable(paramLocator, TVO_PrefersSubtypeBinding);
auto result = replacements.insert(std::make_pair(
cast<GenericTypeParamType>(gp->getCanonicalType()), typeVar));

assert(result.second);
(void) result;
(void)result;
}

// Remember that any new constraints generated by opening this generic are
// due to the opening.
locatorPtr = getConstraintLocator(
auto *baseLocator = getConstraintLocator(
locator.withPathElement(LocatorPathElt::getOpenedGeneric(sig)));

bindArchetypesFromContext(*this, outerDC, locatorPtr, replacements);

if (skipGenericRequirements)
return;

// Add the requirements as constraints.
openGenericRequirements(
outerDC, sig, skipProtocolSelfConstraint, locator,
[&](Type type) { return openType(type, replacements); });
bindArchetypesFromContext(*this, outerDC, baseLocator, replacements);
}

void ConstraintSystem::openGenericRequirements(
Expand Down Expand Up @@ -1331,10 +1296,17 @@ ConstraintSystem::getTypeOfMemberReference(

// While opening member function type, let's delay opening requirements
// to allow contextual types to affect the situation.
openedType = openFunctionType(funcType, numRemovedArgumentLabels,
locator, replacements, innerDC, outerDC,
/*skipProtocolSelfConstraint=*/true,
/*skipGenericRequirements=*/true);
if (auto *genericFn = funcType->getAs<GenericFunctionType>()) {
openGenericParameters(outerDC, genericFn->getGenericSignature(),
replacements, locator);

openedType = genericFn->substGenericArgs(
[&](Type type) { return openType(type, replacements); });
} else {
openedType = funcType;
}

openedType = openedType->removeArgumentLabels(numRemovedArgumentLabels);

if (!outerDC->getSelfProtocolDecl()) {
// Class methods returning Self as well as constructors get the
Expand Down
Loading