Skip to content

Commit 7c65ed2

Browse files
committed
Use disjunction to resolve main function
Using `resolveValueMember` can result in the wrong, or unrelated main functions being selected. If an error occurs while resolving, such as an ambiguous resolution, the solution will contain everything called 'main'. During the resolution, the `BestOverload` will not be set, so the code skips down to iterating over the member decls and selecting the first declaration that is a `MainTypeMainMethod`. The order of candidates in the failure state is related to the order that the declarations appear in source. The selected and potentially unrelated main function is called from `$main`. The call in the expression will then detect the invalid state later. When only a single main function could exist in a given context, this was not a problem. Any other possible solutions would result in an error, either due to a duplicate declaration, or from calling an unrelated function. When we introduced the asynchronous main function, the resolution can be ambiguous because we allow asynchronous overloads of synchronous functions. This enables us to legally write ``` struct MainType { static func main() { } static func main() async { } } ``` From the perspective of duplicate declarations, this is not an issue. It is, however, ambiguous since we have no calling context with which to break the tie. In the original code, this example would select the synchronous main function because it comes first in the source. If we flip the declarations, the asynchronous main function is selected because it comes first in the source. Instead, using the constraint solver to solve for the correct overload will ensure that we only get back a valid main function. If the constraint solver reports no solutions or many solutions, we can emit an error immediately, as it will not consider unrelated functions.
1 parent 5dfc6b8 commit 7c65ed2

File tree

1 file changed

+79
-27
lines changed

1 file changed

+79
-27
lines changed

lib/Sema/TypeCheckAttr.cpp

Lines changed: 79 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2096,28 +2096,6 @@ synthesizeMainBody(AbstractFunctionDecl *fn, void *arg) {
20962096
return std::make_pair(body, /*typechecked=*/false);
20972097
}
20982098

2099-
static FuncDecl *resolveMainFunctionDecl(DeclContext *declContext,
2100-
ResolvedMemberResult &resolution,
2101-
ASTContext &ctx) {
2102-
// Choose the best overload if it's a main function
2103-
if (resolution.hasBestOverload()) {
2104-
ValueDecl *best = resolution.getBestOverload();
2105-
if (FuncDecl *func = dyn_cast<FuncDecl>(best)) {
2106-
if (func->isMainTypeMainMethod()) {
2107-
return func;
2108-
}
2109-
}
2110-
}
2111-
// Look for the most highly-ranked main-function candidate
2112-
for (ValueDecl *candidate : resolution.getMemberDecls(Viable)) {
2113-
if (FuncDecl *func = dyn_cast<FuncDecl>(candidate)) {
2114-
if (func->isMainTypeMainMethod())
2115-
return func;
2116-
}
2117-
}
2118-
return nullptr;
2119-
}
2120-
21212099
FuncDecl *
21222100
SynthesizeMainFunctionRequest::evaluate(Evaluator &evaluator,
21232101
Decl *D) const {
@@ -2171,11 +2149,85 @@ SynthesizeMainFunctionRequest::evaluate(Evaluator &evaluator,
21712149
// mainType.main() from the entry point, and that would require fully
21722150
// type-checking the call to mainType.main().
21732151

2174-
auto resolution = resolveValueMember(
2175-
*declContext, nominal->getInterfaceType(), context.Id_main,
2176-
constraints::ConstraintSystemFlags::IgnoreAsyncSyncMismatch);
2177-
FuncDecl *mainFunction =
2178-
resolveMainFunctionDecl(declContext, resolution, context);
2152+
constraints::ConstraintSystem CS(
2153+
declContext, constraints::ConstraintSystemFlags::IgnoreAsyncSyncMismatch);
2154+
constraints::ConstraintLocator *locator = CS.getConstraintLocator({});
2155+
// Allowed main function types
2156+
// `() -> Void`
2157+
// `() async -> Void`
2158+
// `() throws -> Void`
2159+
// `() async throws -> Void`
2160+
// `@MainActor () -> Void`
2161+
// `@MainActor () async -> Void`
2162+
// `@MainActor () throws -> Void`
2163+
// `@MainActor () async throws -> Void`
2164+
{
2165+
llvm::SmallVector<Type, 8> mainTypes = {
2166+
2167+
FunctionType::get(/*params*/ {}, context.TheEmptyTupleType,
2168+
ASTExtInfo()),
2169+
FunctionType::get(
2170+
/*params*/ {}, context.TheEmptyTupleType,
2171+
ASTExtInfoBuilder().withAsync().build()),
2172+
2173+
FunctionType::get(/*params*/ {}, context.TheEmptyTupleType,
2174+
ASTExtInfoBuilder().withThrows().build()),
2175+
2176+
FunctionType::get(
2177+
/*params*/ {}, context.TheEmptyTupleType,
2178+
ASTExtInfoBuilder().withAsync().withThrows().build())};
2179+
2180+
Type mainActor = context.getMainActorType();
2181+
if (mainActor) {
2182+
mainTypes.push_back(FunctionType::get(
2183+
/*params*/ {}, context.TheEmptyTupleType,
2184+
ASTExtInfoBuilder().withGlobalActor(mainActor).build()));
2185+
mainTypes.push_back(FunctionType::get(
2186+
/*params*/ {}, context.TheEmptyTupleType,
2187+
ASTExtInfoBuilder().withAsync().withGlobalActor(mainActor).build()));
2188+
mainTypes.push_back(FunctionType::get(
2189+
/*params*/ {}, context.TheEmptyTupleType,
2190+
ASTExtInfoBuilder().withThrows().withGlobalActor(mainActor).build()));
2191+
mainTypes.push_back(FunctionType::get(/*params*/ {},
2192+
context.TheEmptyTupleType,
2193+
ASTExtInfoBuilder()
2194+
.withAsync()
2195+
.withThrows()
2196+
.withGlobalActor(mainActor)
2197+
.build()));
2198+
}
2199+
2200+
llvm::SmallVector<constraints::Constraint *, 4> mainTypeConstraints;
2201+
for (const Type &mainType : mainTypes) {
2202+
constraints::Constraint *fnConstraint =
2203+
constraints::Constraint::createMember(
2204+
CS, constraints::ConstraintKind::ValueMember,
2205+
nominal->getInterfaceType(), mainType,
2206+
DeclNameRef(context.Id_main), declContext,
2207+
FunctionRefKind::SingleApply, locator);
2208+
mainTypeConstraints.push_back(fnConstraint);
2209+
}
2210+
2211+
CS.addDisjunctionConstraint(mainTypeConstraints, locator);
2212+
}
2213+
2214+
FuncDecl *mainFunction = nullptr;
2215+
llvm::SmallVector<constraints::Solution, 4> candidates;
2216+
2217+
if (!CS.solve(candidates, FreeTypeVariableBinding::Disallow)) {
2218+
if (candidates.size() != 1) {
2219+
context.Diags.diagnose(nominal->getLoc(), diag::ambiguous_decl_ref,
2220+
DeclNameRef(context.Id_main));
2221+
// TODO: CS.diagnoseAmbiguity doesn't report anything because the types
2222+
// are different. It would be good to get notes on the decls causing the
2223+
// ambiguity.
2224+
attr->setInvalid();
2225+
return nullptr;
2226+
}
2227+
mainFunction = dyn_cast<FuncDecl>(
2228+
candidates[0].overloadChoices[locator].choice.getDecl());
2229+
}
2230+
21792231
if (!mainFunction) {
21802232
const bool hasAsyncSupport =
21812233
AvailabilityContext::forDeploymentTarget(context).isContainedIn(

0 commit comments

Comments
 (0)