Skip to content

Sema: Small fixes for result builder inference #76749

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions include/swift/AST/AnyFunctionRef.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,18 @@ class AnyFunctionRef {
llvm_unreachable("autoclosures don't have statement bodies");
}

/// Returns a boolean value indicating whether the body, if any, contains
/// an explicit `return` statement.
///
/// \returns `true` if the body contains an explicit `return` statement,
/// `false` otherwise.
bool bodyHasExplicitReturnStmt() const;

/// Finds occurrences of explicit `return` statements within the body, if any.
///
/// \param results An out container to which the results are added.
void getExplicitReturnStmts(SmallVectorImpl<ReturnStmt *> &results) const;

DeclContext *getAsDeclContext() const {
if (auto *AFD = TheFunction.dyn_cast<AbstractFunctionDecl *>())
return AFD;
Expand Down
13 changes: 13 additions & 0 deletions include/swift/AST/Decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ namespace swift {
class ProtocolType;
struct RawComment;
enum class ResilienceExpansion : unsigned;
class ReturnStmt;
enum class EffectKind : uint8_t;
enum class PolymorphicEffectKind : uint8_t;
class TrailingWhereClause;
Expand Down Expand Up @@ -7675,6 +7676,18 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
/// parsed.
bool hasBody() const;

/// Returns a boolean value indicating whether the body, if any, contains
/// an explicit `return` statement.
///
/// \returns `true` if the body contains an explicit `return` statement,
/// `false` otherwise.
bool bodyHasExplicitReturnStmt() const;

/// Finds occurrences of explicit `return` statements within the body, if any.
///
/// \param results An out container to which the results are added.
void getExplicitReturnStmts(SmallVectorImpl<ReturnStmt *> &results) const;

/// Returns true if the text of this function's body can be retrieved either
/// by extracting the text from the source buffer or reading the inlinable
/// body from a deserialized swiftmodule.
Expand Down
13 changes: 13 additions & 0 deletions include/swift/AST/Expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ namespace swift {
class KeyPathExpr;
class CaptureListExpr;
class ThenStmt;
class ReturnStmt;

enum class ExprKind : uint8_t {
#define EXPR(Id, Parent) Id,
Expand Down Expand Up @@ -4051,6 +4052,18 @@ class AbstractClosureExpr : public DeclContext, public Expr {
/// returns nullptr if the closure doesn't have a body
BraceStmt *getBody() const;

/// Returns a boolean value indicating whether the body, if any, contains
/// an explicit `return` statement.
///
/// \returns `true` if the body contains an explicit `return` statement,
/// `false` otherwise.
bool bodyHasExplicitReturnStmt() const;

/// Finds occurrences of explicit `return` statements within the body, if any.
///
/// \param results An out container to which the results are added.
void getExplicitReturnStmts(SmallVectorImpl<ReturnStmt *> &results) const;

ActorIsolation getActorIsolation() const {
return actorIsolation;
}
Expand Down
6 changes: 3 additions & 3 deletions include/swift/AST/TypeCheckRequests.h
Original file line number Diff line number Diff line change
Expand Up @@ -3050,9 +3050,9 @@ class AssociatedConformanceRequest
void cacheResult(ProtocolConformanceRef value) const;
};

class BraceHasReturnRequest
: public SimpleRequest<BraceHasReturnRequest, bool(const BraceStmt *),
RequestFlags::Cached> {
class BraceHasExplicitReturnStmtRequest
: public SimpleRequest<BraceHasExplicitReturnStmtRequest,
bool(const BraceStmt *), RequestFlags::Cached> {
public:
using SimpleRequest::SimpleRequest;

Expand Down
2 changes: 1 addition & 1 deletion include/swift/AST/TypeCheckerTypeIDZone.def
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ SWIFT_REQUEST(TypeChecker, HasUserDefinedDesignatedInitRequest,
bool(NominalTypeDecl *), Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, HasMemberwiseInitRequest,
bool(StructDecl *), Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, BraceHasReturnRequest,
SWIFT_REQUEST(TypeChecker, BraceHasExplicitReturnStmtRequest,
bool(const BraceStmt *),
Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, ResolveImplicitMemberRequest,
Expand Down
11 changes: 11 additions & 0 deletions lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9356,6 +9356,17 @@ bool AbstractFunctionDecl::hasBody() const {
}
}

bool AbstractFunctionDecl::bodyHasExplicitReturnStmt() const {
return AnyFunctionRef(const_cast<AbstractFunctionDecl *>(this))
.bodyHasExplicitReturnStmt();
}

void AbstractFunctionDecl::getExplicitReturnStmts(
SmallVectorImpl<ReturnStmt *> &results) const {
AnyFunctionRef(const_cast<AbstractFunctionDecl *>(this))
.getExplicitReturnStmts(results);
}

/// Expand all preamble macros attached to the given function declaration.
static std::vector<ASTNode> expandPreamble(AbstractFunctionDecl *func) {
std::vector<ASTNode> preamble;
Expand Down
11 changes: 11 additions & 0 deletions lib/AST/Expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1962,6 +1962,17 @@ BraceStmt * AbstractClosureExpr::getBody() const {
llvm_unreachable("Unknown closure expression");
}

bool AbstractClosureExpr::bodyHasExplicitReturnStmt() const {
return AnyFunctionRef(const_cast<AbstractClosureExpr *>(this))
.bodyHasExplicitReturnStmt();
}

void AbstractClosureExpr::getExplicitReturnStmts(
SmallVectorImpl<ReturnStmt *> &results) const {
AnyFunctionRef(const_cast<AbstractClosureExpr *>(this))
.getExplicitReturnStmts(results);
}

Type AbstractClosureExpr::getResultType(
llvm::function_ref<Type(Expr *)> getType) const {
auto *E = const_cast<AbstractClosureExpr *>(this);
Expand Down
114 changes: 74 additions & 40 deletions lib/Sema/BuilderTransform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -918,13 +918,13 @@ std::optional<BraceStmt *>
TypeChecker::applyResultBuilderBodyTransform(FuncDecl *func, Type builderType) {
// First look for any return statements, and bail if we have any.
auto &ctx = func->getASTContext();
if (evaluateOrDefault(ctx.evaluator, BraceHasReturnRequest{func->getBody()},
false)) {

SmallVector<ReturnStmt *> returnStmts;
func->getExplicitReturnStmts(returnStmts);

if (!returnStmts.empty()) {
// One or more explicit 'return' statements were encountered, which
// disables the result builder transform. Warn when we do this.
auto returnStmts = findReturnStatements(func);
assert(!returnStmts.empty());

ctx.Diags.diagnose(
returnStmts.front()->getReturnLoc(),
diag::result_builder_disabled_by_return_warn, builderType);
Expand Down Expand Up @@ -1126,8 +1126,7 @@ ConstraintSystem::matchResultBuilder(AnyFunctionRef fn, Type builderType,
// not apply the result builder transform if it contained an explicit return.
// To maintain source compatibility, we still need to check for HasReturnStmt.
// https://github.com/apple/swift/issues/64332.
if (evaluateOrDefault(getASTContext().evaluator,
BraceHasReturnRequest{fn.getBody()}, false)) {
if (fn.bodyHasExplicitReturnStmt()) {
// Diagnostic mode means that solver couldn't reach any viable
// solution, so let's diagnose presence of a `return` statement
// in the closure body.
Expand Down Expand Up @@ -1235,49 +1234,84 @@ void ConstraintSystem::removeResultBuilderTransform(AnyFunctionRef fn) {
ASSERT(erased);
}

namespace {
class ReturnStmtFinder : public ASTWalker {
std::vector<ReturnStmt *> ReturnStmts;
/// Walks the given brace statement and calls the given function reference on
/// every occurrence of an explicit `return` statement.
///
/// \param callback A function reference that takes a `return` statement and
/// returns a boolean value indicating whether to abort the walk.
///
/// \returns `true` if the walk was aborted, `false` otherwise.
static bool walkExplicitReturnStmts(const BraceStmt *BS,
function_ref<bool(ReturnStmt *)> callback) {
class Walker : public ASTWalker {
function_ref<bool(ReturnStmt *)> callback;

public:
Walker(decltype(Walker::callback) callback) : callback(callback) {}

MacroWalking getMacroWalkingBehavior() const override {
return MacroWalking::Arguments;
}

public:
static std::vector<ReturnStmt *> find(const BraceStmt *BS) {
ReturnStmtFinder finder;
const_cast<BraceStmt *>(BS)->walk(finder);
return std::move(finder.ReturnStmts);
}
PreWalkResult<Expr *> walkToExprPre(Expr *E) override {
return Action::SkipNode(E);
}

MacroWalking getMacroWalkingBehavior() const override {
return MacroWalking::Arguments;
}
PreWalkResult<Stmt *> walkToStmtPre(Stmt *S) override {
if (S->isImplicit()) {
return Action::SkipNode(S);
}

PreWalkResult<Expr *> walkToExprPre(Expr *E) override {
return Action::SkipNode(E);
}
auto *returnStmt = dyn_cast<ReturnStmt>(S);
if (!returnStmt) {
return Action::Continue(S);
}

PreWalkResult<Stmt *> walkToStmtPre(Stmt *S) override {
// If we see a return statement, note it..
auto *returnStmt = dyn_cast<ReturnStmt>(S);
if (!returnStmt || returnStmt->isImplicit())
return Action::Continue(S);
if (callback(returnStmt)) {
return Action::Stop();
}

ReturnStmts.push_back(returnStmt);
return Action::SkipNode(S);
}
// Skip children & post walk and continue.
return Action::SkipNode(S);
}

/// Ignore patterns.
PreWalkResult<Pattern *> walkToPatternPre(Pattern *pat) override {
return Action::SkipNode(pat);
/// Ignore patterns.
PreWalkResult<Pattern *> walkToPatternPre(Pattern *pat) override {
return Action::SkipNode(pat);
}
};

Walker walker(callback);

return const_cast<BraceStmt *>(BS)->walk(walker) == nullptr;
}

bool BraceHasExplicitReturnStmtRequest::evaluate(Evaluator &evaluator,
const BraceStmt *BS) const {
return walkExplicitReturnStmts(BS, [](ReturnStmt *) { return true; });
}

bool AnyFunctionRef::bodyHasExplicitReturnStmt() const {
auto *body = getBody();
if (!body) {
return false;
}
};
} // end anonymous namespace

bool BraceHasReturnRequest::evaluate(Evaluator &evaluator,
const BraceStmt *BS) const {
return !ReturnStmtFinder::find(BS).empty();
auto &ctx = getAsDeclContext()->getASTContext();
return evaluateOrDefault(ctx.evaluator,
BraceHasExplicitReturnStmtRequest{body}, false);
}

std::vector<ReturnStmt *> TypeChecker::findReturnStatements(AnyFunctionRef fn) {
return ReturnStmtFinder::find(fn.getBody());
void AnyFunctionRef::getExplicitReturnStmts(
SmallVectorImpl<ReturnStmt *> &results) const {
if (!bodyHasExplicitReturnStmt()) {
return;
}

walkExplicitReturnStmts(getBody(), [&results](ReturnStmt *RS) {
results.push_back(RS);
return false;
});
}

ResultBuilderOpSupport TypeChecker::checkBuilderOpSupport(
Expand Down
3 changes: 2 additions & 1 deletion lib/Sema/CSDiagnostics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8862,7 +8862,8 @@ bool ReferenceToInvalidDeclaration::diagnoseAsError() {
bool InvalidReturnInResultBuilderBody::diagnoseAsError() {
auto *closure = castToExpr<ClosureExpr>(getAnchor());

auto returnStmts = TypeChecker::findReturnStatements(closure);
SmallVector<ReturnStmt *> returnStmts;
closure->getExplicitReturnStmts(returnStmts);
assert(!returnStmts.empty());

auto loc = returnStmts.front()->getReturnLoc();
Expand Down
Loading