Skip to content

[Constraint solver] Compute common apply result type in the solver. #22977

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions lib/AST/NameLookup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ static void recordShadowedDeclsAfterSignatureMatch(
/// Look through the given set of declarations (that all have the same name),
/// recording those that are shadowed by another declaration in the
/// \c shadowed set.
static void recordShadowDeclsAfterObjCInitMatch(
static void recordShadowedDeclsForImportedInits(
ArrayRef<ConstructorDecl *> ctors,
llvm::SmallPtrSetImpl<ValueDecl *> &shadowed) {
assert(ctors.size() > 1 && "No collisions");
Expand Down Expand Up @@ -372,18 +372,21 @@ static void recordShadowedDecls(ArrayRef<ValueDecl *> decls,
llvm::SmallDenseMap<CanType, llvm::TinyPtrVector<ValueDecl *>> collisions;
llvm::SmallVector<CanType, 2> collisionTypes;
llvm::SmallDenseMap<NominalTypeDecl *, llvm::TinyPtrVector<ConstructorDecl *>>
objCInitializerCollisions;
llvm::TinyPtrVector<NominalTypeDecl *> objCInitializerCollisionNominals;
importedInitializerCollisions;
llvm::TinyPtrVector<NominalTypeDecl *> importedInitializerCollectionTypes;

for (auto decl : decls) {
// Specifically keep track of Objective-C initializers, which can come from
// either init methods or factory methods.
if (decl->hasClangNode()) {
// Specifically keep track of imported initializers, which can come from
// Objective-C init methods, Objective-C factory methods, renamed C
// functions, or be synthesized by the importer.
if (decl->hasClangNode() ||
(isa<NominalTypeDecl>(decl->getDeclContext()) &&
cast<NominalTypeDecl>(decl->getDeclContext())->hasClangNode())) {
if (auto ctor = dyn_cast<ConstructorDecl>(decl)) {
auto nominal = ctor->getDeclContext()->getSelfNominalTypeDecl();
auto &knownInits = objCInitializerCollisions[nominal];
auto &knownInits = importedInitializerCollisions[nominal];
if (knownInits.size() == 1) {
objCInitializerCollisionNominals.push_back(nominal);
importedInitializerCollectionTypes.push_back(nominal);
}
knownInits.push_back(ctor);
}
Expand Down Expand Up @@ -431,9 +434,9 @@ static void recordShadowedDecls(ArrayRef<ValueDecl *> decls,
shadowed);
}

// Check whether we have shadowing for Objective-C initializer collisions.
for (auto nominal : objCInitializerCollisionNominals) {
recordShadowDeclsAfterObjCInitMatch(objCInitializerCollisions[nominal],
// Check whether we have shadowing for imported initializer collisions.
for (auto nominal : importedInitializerCollectionTypes) {
recordShadowedDeclsForImportedInits(importedInitializerCollisions[nominal],
shadowed);
}
}
Expand Down
88 changes: 21 additions & 67 deletions lib/Sema/CSGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1123,10 +1123,6 @@ namespace {

if (outputTy.isNull()) {
outputTy = CS.createTypeVariable(resultLocator, TVO_CanBindToLValue);
} else {
// TODO: Generalize this for non-subscript-expr anchors, so that e.g.
// keypath lookup benefits from the peephole as well.
CS.setFavoredType(anchor, outputTy.getPointer());
}

// FIXME: This can only happen when diagnostics successfully type-checked
Expand Down Expand Up @@ -1171,6 +1167,13 @@ namespace {
memberTy,
fnLocator);

Type fixedOutputType =
CS.getFixedTypeRecursive(outputTy, /*wantRValue=*/false);
if (!fixedOutputType->isTypeVariableOrMember()) {
CS.setFavoredType(anchor, fixedOutputType.getPointer());
outputTy = fixedOutputType;
}

return outputTy;
}

Expand Down Expand Up @@ -2516,8 +2519,6 @@ namespace {
}

Type visitApplyExpr(ApplyExpr *expr) {
Type outputTy;

auto fnExpr = expr->getFn();

if (auto *UDE = dyn_cast<UnresolvedDotExpr>(fnExpr)) {
Expand All @@ -2526,65 +2527,9 @@ namespace {
return resultOfTypeOperation(typeOperation, expr->getArg());
}

if (auto fnType = CS.getType(fnExpr)->getAs<AnyFunctionType>()) {
outputTy = fnType->getResult();
} else if (auto OSR = dyn_cast<OverloadedDeclRefExpr>(fnExpr)) {
// Determine if the overloads are all functions that share a common
// return type.
Type commonType;
for (auto OD : OSR->getDecls()) {
auto OFD = dyn_cast<AbstractFunctionDecl>(OD);
if (!OFD) {
commonType = Type();
break;
}

auto OFT = OFD->getInterfaceType()->getAs<AnyFunctionType>();
if (!OFT) {
commonType = Type();
break;
}

// Look past the self parameter.
if (OFD->getDeclContext()->isTypeContext()) {
OFT = OFT->getResult()->getAs<AnyFunctionType>();
if (!OFT) {
commonType = Type();
break;
}
}

Type resultType = OFT->getResult();

// If there are any type parameters in the result,
if (resultType->hasTypeParameter()) {
commonType = Type();
break;
}

if (commonType.isNull()) {
commonType = resultType;
} else if (!commonType->isEqual(resultType)) {
commonType = Type();
break;
}
}

if (commonType) {
outputTy = commonType;
}
}

// The function subexpression has some rvalue type T1 -> T2 for fresh
// variables T1 and T2.
if (outputTy.isNull()) {
outputTy = CS.createTypeVariable(
CS.getConstraintLocator(expr, ConstraintLocator::FunctionResult));
} else {
// Since we know what the output type is, we can set it as the favored
// type of this expression.
CS.setFavoredType(expr, outputTy.getPointer());
}
// The result type is a fresh type variable.
Type resultType = CS.createTypeVariable(
CS.getConstraintLocator(expr, ConstraintLocator::FunctionResult));

// A direct call to a ClosureExpr makes it noescape.
FunctionType::ExtInfo extInfo;
Expand All @@ -2598,11 +2543,20 @@ namespace {
AnyFunctionType::decomposeInput(CS.getType(expr->getArg()), params);

CS.addConstraint(ConstraintKind::ApplicableFunction,
FunctionType::get(params, outputTy, extInfo),
FunctionType::get(params, resultType, extInfo),
CS.getType(expr->getFn()),
CS.getConstraintLocator(expr, ConstraintLocator::ApplyFunction));

return outputTy;
// If we ended up resolving the result type variable to a concrete type,
// set it as the favored type for this expression.
Type fixedType =
CS.getFixedTypeRecursive(resultType, /*wantRvalue=*/true);
if (!fixedType->isTypeVariableOrMember()) {
CS.setFavoredType(expr, fixedType.getPointer());
resultType = fixedType;
}

return resultType;
}

Type getSuperType(VarDecl *selfDecl,
Expand Down
22 changes: 20 additions & 2 deletions lib/Sema/CSSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4777,7 +4777,7 @@ ConstraintSystem::simplifyApplicableFnConstraint(

// By construction, the left hand side is a type that looks like the
// following: $T1 -> $T2.
assert(type1->is<FunctionType>());
auto func1 = type1->castTo<FunctionType>();

// Let's check if this member couldn't be found and is fixed
// to exist based on its usage.
Expand Down Expand Up @@ -4821,6 +4821,25 @@ ConstraintSystem::simplifyApplicableFnConstraint(

};

// If the right-hand side is a type variable, try to find a common result
// type in the overload set.
if (auto typeVar = desugar2->getAs<TypeVariableType>()) {
auto choices = getUnboundBindOverloads(typeVar);
if (Type resultType = findCommonResultType(choices)) {
ASTContext &ctx = getASTContext();
if (ctx.LangOpts.DebugConstraintSolver) {
auto &log = ctx.TypeCheckerDebug->getStream();
log.indent(solverState ? solverState->depth * 2 + 2 : 0)
<< "(common result type for $T" << typeVar->getID() << " is "
<< resultType.getString()
<< ")\n";
}

addConstraint(ConstraintKind::Bind, func1->getResult(), resultType,
locator);
}
}

// If right-hand side is a type variable, the constraint is unsolved.
if (desugar2->isTypeVariableOrMember())
return formUnsolved();
Expand Down Expand Up @@ -4856,7 +4875,6 @@ ConstraintSystem::simplifyApplicableFnConstraint(
}

// For a function, bind the output and convert the argument to the input.
auto func1 = type1->castTo<FunctionType>();
if (auto func2 = dyn_cast<FunctionType>(desugar2)) {
// The argument type must be convertible to the input type.
if (::matchCallArguments(
Expand Down
50 changes: 38 additions & 12 deletions lib/Sema/CSSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1567,19 +1567,13 @@ bool ConstraintSystem::haveTypeInformationForAllArguments(
});
}

// Given a type variable representing the RHS of an ApplicableFunction
// constraint, attempt to find the disjunction of bind overloads
// associated with it. This may return null in cases where have not
// yet created a disjunction because we need to resolve a base type,
// e.g.: [1].map{ ... } does not have a disjunction until we decide on
// a type for [1].
static Constraint *getUnboundBindOverloadDisjunction(TypeVariableType *tyvar,
ConstraintSystem &cs) {
auto *rep = cs.getRepresentative(tyvar);
assert(!cs.getFixedType(rep));
Constraint *ConstraintSystem::getUnboundBindOverloadDisjunction(
TypeVariableType *tyvar) {
auto *rep = getRepresentative(tyvar);
assert(!getFixedType(rep));

llvm::SetVector<Constraint *> disjunctions;
cs.getConstraintGraph().gatherConstraints(
getConstraintGraph().gatherConstraints(
rep, disjunctions, ConstraintGraph::GatheringKind::EquivalenceClass,
[](Constraint *match) {
return match->getKind() == ConstraintKind::Disjunction &&
Expand All @@ -1593,6 +1587,38 @@ static Constraint *getUnboundBindOverloadDisjunction(TypeVariableType *tyvar,
return disjunctions[0];
}

/// solely resolved by an overload set.
SmallVector<OverloadChoice, 2> ConstraintSystem::getUnboundBindOverloads(
TypeVariableType *tyvar) {
// Always work on the representation.
tyvar = getRepresentative(tyvar);

SmallVector<OverloadChoice, 2> choices;

auto disjunction = getUnboundBindOverloadDisjunction(tyvar);
if (!disjunction) return choices;

for (auto constraint : disjunction->getNestedConstraints()) {
// We must have bind-overload constraints.
if (constraint->getKind() != ConstraintKind::BindOverload) {
choices.clear();
return choices;
}

// We must be binding the type variable (or a type variable equivalent to
// it).
auto boundTypeVar = constraint->getFirstType()->getAs<TypeVariableType>();
if (!boundTypeVar || getRepresentative(boundTypeVar) != tyvar) {
choices.clear();
return choices;
}

choices.push_back(constraint->getOverloadChoice());
}

return choices;
}

// Find a disjunction associated with an ApplicableFunction constraint
// where we have some information about all of the types of in the
// function application (even if we only know something about what the
Expand All @@ -1608,7 +1634,7 @@ Constraint *ConstraintSystem::selectApplyDisjunction() {
auto *tyvar = applicable->getSecondType()->castTo<TypeVariableType>();

// If we have created the disjunction for this apply, find it.
auto *disjunction = getUnboundBindOverloadDisjunction(tyvar, *this);
auto *disjunction = getUnboundBindOverloadDisjunction(tyvar);
if (disjunction)
return disjunction;
}
Expand Down
Loading