Skip to content

[CS] Unify ReturnStmt handling #71272

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions include/swift/Sema/ConstraintSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -4562,8 +4562,7 @@ class ConstraintSystem {
///
/// \returns a possibly-sanitized expression, or null if an error occurred.
[[nodiscard]]
Expr *generateConstraints(Expr *E, DeclContext *dc,
bool isInputExpression = true);
Expr *generateConstraints(Expr *E, DeclContext *dc);

/// Generate constraints for binding the given pattern to the
/// value of the given expression.
Expand Down
5 changes: 5 additions & 0 deletions include/swift/Sema/SyntacticElementTarget.h
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,11 @@ class SyntacticElementTarget {
expression.contextualInfo.typeLoc = type;
}

void setExprContextualTypePurpose(ContextualTypePurpose ctp) {
assert(kind == Kind::expression);
expression.contextualInfo.purpose = ctp;
}

/// Whether this target is for an initialization expression and pattern.
bool isForInitialization() const {
return kind == Kind::expression &&
Expand Down
18 changes: 2 additions & 16 deletions lib/Sema/CSApply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9595,22 +9595,8 @@ ExprWalker::rewriteTarget(SyntacticElementTarget target) {

auto *locator = target.getExprConvertTypeLocator();
if (!locator) {
// Bodies of single-expression closures use a special locator
// for contextual type conversion to make sure that result is
// convertible to `Void` when `return` is not used explicitly.
auto *closure = dyn_cast<ClosureExpr>(target.getDeclContext());
if (closure && closure->hasSingleExpressionBody() &&
contextualTypePurpose == CTP_ClosureResult) {
auto *returnStmt =
castToStmt<ReturnStmt>(closure->getBody()->getLastElement());

locator = cs.getConstraintLocator(
closure, LocatorPathElt::ClosureBody(
/*hasImpliedReturn*/ returnStmt->isImplied()));
} else {
locator = cs.getConstraintLocator(
expr, LocatorPathElt::ContextualType(contextualTypePurpose));
}
locator = cs.getConstraintLocator(
expr, LocatorPathElt::ContextualType(contextualTypePurpose));
}
assert(locator);

Expand Down
6 changes: 2 additions & 4 deletions lib/Sema/CSGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4959,10 +4959,8 @@ bool ConstraintSystem::generateConstraints(
}
}

Expr *ConstraintSystem::generateConstraints(
Expr *expr, DeclContext *dc, bool isInputExpression) {
if (isInputExpression)
InputExprs.insert(expr);
Expr *ConstraintSystem::generateConstraints(Expr *expr, DeclContext *dc) {
InputExprs.insert(expr);
return generateConstraintsFor(*this, expr, dc);
}

Expand Down
79 changes: 28 additions & 51 deletions lib/Sema/CSSyntacticElement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ struct SyntacticElementContext
}
}

bool isSingleExpressionClosure(ConstraintSystem &cs) {
bool isSingleExpressionClosure(ConstraintSystem &cs) const {
if (auto ref = getAsAnyFunctionRef()) {
if (cs.getAppliedResultBuilderTransform(*ref))
return false;
Expand Down Expand Up @@ -1115,8 +1115,8 @@ class SyntacticElementConstraintGenerator

for (auto node : braceStmt->getElements()) {
if (auto expr = node.dyn_cast<Expr *>()) {
auto generatedExpr = cs.generateConstraints(
expr, context.getAsDeclContext(), /*isInputExpression=*/false);
auto generatedExpr =
cs.generateConstraints(expr, context.getAsDeclContext());
if (!generatedExpr) {
hadError = true;
}
Expand Down Expand Up @@ -1242,33 +1242,7 @@ class SyntacticElementConstraintGenerator
}

void visitReturnStmt(ReturnStmt *returnStmt) {
// Single-expression closures are effectively a `return` statement,
// so let's give them a special locator as to indicate that.
// Return statements might not have a result if we have a closure whose
// implicit returned value is coerced to Void.
if (context.isSingleExpressionClosure(cs) && returnStmt->hasResult()) {
auto *expr = returnStmt->getResult();
assert(expr && "single expression closure without expression?");

expr = cs.generateConstraints(expr, context.getAsDeclContext(),
/*isInputExpression=*/false);
if (!expr) {
hadError = true;
return;
}

auto contextualResultInfo = getContextualResultInfo();
cs.addConstraint(ConstraintKind::Conversion, cs.getType(expr),
contextualResultInfo.getType(),
cs.getConstraintLocator(
context.getAsAbstractClosureExpr().get(),
LocatorPathElt::ClosureBody(
/*hasImpliedReturn=*/returnStmt->isImplied())));
return;
}

Expr *resultExpr;

if (returnStmt->hasResult()) {
resultExpr = returnStmt->getResult();
assert(resultExpr && "non-empty result without expression?");
Expand All @@ -1280,10 +1254,10 @@ class SyntacticElementConstraintGenerator
resultExpr = getVoidExpr(cs.getASTContext(), returnStmt->getEndLoc());
}

auto contextualResultInfo = getContextualResultInfo();
auto contextualResultInfo = getContextualResultInfoFor(returnStmt);

SyntacticElementTarget target(resultExpr, context.getAsDeclContext(),
contextualResultInfo,
/*isDiscarded=*/false);
contextualResultInfo, /*isDiscarded=*/false);

if (cs.generateConstraints(target)) {
hadError = true;
Expand Down Expand Up @@ -1328,7 +1302,7 @@ class SyntacticElementConstraintGenerator
createConjunction({resultElt}, locator);
}

ContextualTypeInfo getContextualResultInfo() const {
ContextualTypeInfo getContextualResultInfoFor(ReturnStmt *returnStmt) const {
auto funcRef = AnyFunctionRef::fromDeclContext(context.getAsDeclContext());
if (!funcRef)
return {Type(), CTP_Unused};
Expand All @@ -1337,8 +1311,18 @@ class SyntacticElementConstraintGenerator
return {transform->bodyResultType, CTP_ReturnStmt};

if (auto *closure =
getAsExpr<ClosureExpr>(funcRef->getAbstractClosureExpr()))
return {cs.getClosureType(closure)->getResult(), CTP_ClosureResult};
getAsExpr<ClosureExpr>(funcRef->getAbstractClosureExpr())) {
// Single-expression closures need their contextual type locator anchored
// on the closure itself. Otherwise we use the default contextual type
// locator, which will be created for us.
ConstraintLocator *loc = nullptr;
if (context.isSingleExpressionClosure(cs) && returnStmt->hasResult()) {
loc = cs.getConstraintLocator(
closure, {LocatorPathElt::ClosureBody(
/*hasImpliedReturn=*/returnStmt->isImplied())});
}
return {cs.getClosureType(closure)->getResult(), CTP_ClosureResult, loc};
}

return {funcRef->getBodyResultType(), CTP_ReturnStmt};
}
Expand Down Expand Up @@ -2156,22 +2140,15 @@ class SyntacticElementSolutionApplication
mode = convertToResult;
}

llvm::Optional<SyntacticElementTarget> resultTarget;
if (auto target = cs.getTargetFor(returnStmt)) {
resultTarget = *target;
} else {
// Single-expression closures have to handle returns in a special
// way so the target has to be created for them during solution
// application based on the resolved type.
assert(context.isSingleExpressionClosure(cs));
resultTarget = SyntacticElementTarget(
resultExpr, context.getAsDeclContext(),
mode == convertToResult ? CTP_ClosureResult : CTP_Unused,
mode == convertToResult ? resultType : Type(),
/*isDiscarded=*/false);
}

if (auto newResultTarget = rewriteTarget(*resultTarget)) {
auto target = *cs.getTargetFor(returnStmt);

// If we're not converting to a result, unset the contextual type.
if (mode != convertToResult) {
target.setExprConversionType(Type());
target.setExprContextualTypePurpose(CTP_Unused);
}

if (auto newResultTarget = rewriteTarget(target)) {
resultExpr = newResultTarget->getAsExpr();
}

Expand Down
6 changes: 6 additions & 0 deletions lib/Sema/ConstraintSystem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -822,6 +822,12 @@ void ConstraintSystem::addPackEnvironment(PackElementExpr *packElement,
static void extendDepthMap(
Expr *expr,
llvm::DenseMap<Expr *, std::pair<unsigned, Expr *>> &depthMap) {
// If we already have an entry in the map, we don't need to update it. This
// avoids invalidating previous entries when solving a smaller component of a
// larger AST node, e.g during conjunction solving.
if (depthMap.contains(expr))
return;

class RecordingTraversal : public ASTWalker {
SmallVector<ClosureExpr *, 4> Closures;

Expand Down
21 changes: 19 additions & 2 deletions test/IDE/complete_single_expression_return.swift
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,27 @@ struct TestExplicitSingleExprClosureBinding {
return self.#^TestExplicitSingleExprClosureBinding^#
}
}
// FIXME: Because we have an explicit return, and no expected type, we shouldn't suggest Void.
// We have an explicit return, and no expected type, so we don't suggest Void.
// TestExplicitSingleExprClosureBinding-DAG: Decl[InstanceMethod]/CurrNominal: str()[#String#];
// TestExplicitSingleExprClosureBinding-DAG: Decl[InstanceMethod]/CurrNominal: int()[#Int#];
// TestExplicitSingleExprClosureBinding-DAG: Decl[InstanceMethod]/CurrNominal: void()[#Void#];
// TestExplicitSingleExprClosureBinding-DAG: Decl[InstanceMethod]/CurrNominal/TypeRelation[Invalid]: void()[#Void#];
}

struct TestExplicitMultiStmtClosureBinding {
func void() -> Void {}
func str() -> String { return "" }
func int() -> Int { return 0 }

func test() {
let fn = {
()
return self.#^TestExplicitMultiStmtClosureBinding^#
}
}
// We have an explicit return, and no expected type, so we don't suggest Void.
// TestExplicitMultiStmtClosureBinding-DAG: Decl[InstanceMethod]/CurrNominal: str()[#String#];
// TestExplicitMultiStmtClosureBinding-DAG: Decl[InstanceMethod]/CurrNominal: int()[#Int#];
// TestExplicitMultiStmtClosureBinding-DAG: Decl[InstanceMethod]/CurrNominal/TypeRelation[Invalid]: void()[#Void#];
}

struct TestExplicitSingleExprClosureBindingWithContext {
Expand Down
47 changes: 44 additions & 3 deletions unittests/Sema/ConstraintGenerationTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ TEST_F(SemaTest, TestImplicitForceCastConstraintGeneration) {
auto *castExpr = ForcedCheckedCastExpr::createImplicit(Context, literal,
Context.TheAnyType);

auto *expr = cs.generateConstraints(castExpr, DC, /*isInputExpression=*/true);
auto *expr = cs.generateConstraints(castExpr, DC);

ASSERT_NE(expr, nullptr);

Expand Down Expand Up @@ -66,7 +66,7 @@ TEST_F(SemaTest, TestImplicitCoercionConstraintGeneration) {
auto *castExpr = CoerceExpr::createImplicit(Context, literal,
getStdlibType("Double"));

auto *expr = cs.generateConstraints(castExpr, DC, /*isInputExpression=*/true);
auto *expr = cs.generateConstraints(castExpr, DC);

ASSERT_NE(expr, nullptr);

Expand Down Expand Up @@ -95,7 +95,7 @@ TEST_F(SemaTest, TestImplicitConditionalCastConstraintGeneration) {
auto *castExpr = ConditionalCheckedCastExpr::createImplicit(
Context, literal, getStdlibType("Double"));

auto *expr = cs.generateConstraints(castExpr, DC, /*isInputExpression=*/true);
auto *expr = cs.generateConstraints(castExpr, DC);

ASSERT_NE(expr, nullptr);

Expand Down Expand Up @@ -179,3 +179,44 @@ TEST_F(SemaTest, TestCaptureListIsNotOpenedEarly) {
ASSERT_TRUE(cs.hasType(capture.getVar()));
}
}

TEST_F(SemaTest, TestMultiStmtClosureBodyParentAndDepth) {
// {
// ()
// return ()
// }
DeclAttributes attrs;
auto *closure = new (Context) ClosureExpr(attrs,
/*braceRange=*/SourceRange(),
/*capturedSelfDecl=*/nullptr,
ParameterList::createEmpty(Context),
/*asyncLoc=*/SourceLoc(),
/*throwsLoc=*/SourceLoc(),
/*thrownType*/ nullptr,
/*arrowLoc=*/SourceLoc(),
/*inLoc=*/SourceLoc(),
/*explicitResultType=*/nullptr, DC);
closure->setImplicit();

auto *RS = ReturnStmt::createImplicit(
Context, TupleExpr::createImplicit(Context, {}, {}));

closure->setBody(BraceStmt::createImplicit(Context, {
TupleExpr::createImplicit(Context, {}, {}), RS
}));

SyntacticElementTarget target(closure, DC, ContextualTypeInfo(),
/*isDiscarded*/ true);

ConstraintSystem cs(DC, ConstraintSystemOptions());
cs.solve(target);

ASSERT_EQ(cs.getParentExpr(closure), nullptr);
ASSERT_EQ(cs.getExprDepth(closure), 0);

// We visit the ReturnStmt twice when computing the parent map, ensure we
// don't invalidate its parent on the second walk during the conjunction.
auto *result = RS->getResult();
ASSERT_EQ(cs.getParentExpr(result), closure);
ASSERT_EQ(cs.getExprDepth(result), 1);
}