Skip to content

[5.5] [Refactoring] Support async for function extraction #37670

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions include/swift/AST/Effects.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#define SWIFT_EFFECTS_H

#include "swift/AST/Type.h"
#include "swift/Basic/OptionSet.h"

#include <utility>

Expand All @@ -34,11 +35,14 @@ class raw_ostream;
}

namespace swift {
class AbstractFunctionDecl;
class ProtocolDecl;

enum class EffectKind : uint8_t {
Throws,
Async
Throws = 1 << 0,
Async = 1 << 1
};
using PossibleEffects = OptionSet<EffectKind>;

void simple_display(llvm::raw_ostream &out, const EffectKind kind);

Expand Down
9 changes: 5 additions & 4 deletions include/swift/IDE/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "swift/Basic/LLVM.h"
#include "swift/AST/ASTNode.h"
#include "swift/AST/DeclNameLoc.h"
#include "swift/AST/Effects.h"
#include "swift/AST/Module.h"
#include "swift/AST/ASTPrinter.h"
#include "swift/IDE/SourceEntityWalker.h"
Expand Down Expand Up @@ -345,7 +346,7 @@ struct ResolvedRangeInfo {
ArrayRef<Token> TokensInRange;
CharSourceRange ContentRange;
bool HasSingleEntry;
bool ThrowingUnhandledError;
PossibleEffects UnhandledEffects;
OrphanKind Orphan;

// The topmost ast nodes contained in the given range.
Expand All @@ -359,15 +360,15 @@ struct ResolvedRangeInfo {
ArrayRef<Token> TokensInRange,
DeclContext* RangeContext,
Expr *CommonExprParent, bool HasSingleEntry,
bool ThrowingUnhandledError,
PossibleEffects UnhandledEffects,
OrphanKind Orphan, ArrayRef<ASTNode> ContainedNodes,
ArrayRef<DeclaredDecl> DeclaredDecls,
ArrayRef<ReferencedDecl> ReferencedDecls): Kind(Kind),
ExitInfo(ExitInfo),
TokensInRange(TokensInRange),
ContentRange(calculateContentRange(TokensInRange)),
HasSingleEntry(HasSingleEntry),
ThrowingUnhandledError(ThrowingUnhandledError),
UnhandledEffects(UnhandledEffects),
Orphan(Orphan), ContainedNodes(ContainedNodes),
DeclaredDecls(DeclaredDecls),
ReferencedDecls(ReferencedDecls),
Expand All @@ -376,7 +377,7 @@ struct ResolvedRangeInfo {
ResolvedRangeInfo(ArrayRef<Token> TokensInRange) :
ResolvedRangeInfo(RangeKind::Invalid, {nullptr, ExitState::Unsure},
TokensInRange, nullptr, /*Commom Expr Parent*/nullptr,
/*Single entry*/true, /*unhandled error*/false,
/*Single entry*/true, /*UnhandledEffects*/{},
OrphanKind::None, {}, {}, {}) {}
ResolvedRangeInfo(): ResolvedRangeInfo(ArrayRef<Token>()) {}
void print(llvm::raw_ostream &OS) const;
Expand Down
57 changes: 31 additions & 26 deletions lib/IDE/IDERequests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include "swift/AST/ASTPrinter.h"
#include "swift/AST/Decl.h"
#include "swift/AST/Effects.h"
#include "swift/AST/NameLookup.h"
#include "swift/AST/ASTDemangler.h"
#include "swift/Basic/SourceManager.h"
Expand Down Expand Up @@ -377,41 +378,45 @@ class RangeResolver : public SourceEntityWalker {
ResolvedRangeInfo resolve();
};

static bool hasUnhandledError(ArrayRef<ASTNode> Nodes) {
class ThrowingEntityAnalyzer : public SourceEntityWalker {
bool Throwing;
static PossibleEffects getUnhandledEffects(ArrayRef<ASTNode> Nodes) {
class EffectsAnalyzer : public SourceEntityWalker {
PossibleEffects Effects;
public:
ThrowingEntityAnalyzer(): Throwing(false) {}
bool walkToStmtPre(Stmt *S) override {
if (auto DCS = dyn_cast<DoCatchStmt>(S)) {
if (DCS->isSyntacticallyExhaustive())
return false;
Throwing = true;
Effects |= EffectKind::Throws;
} else if (isa<ThrowStmt>(S)) {
Throwing = true;
Effects |= EffectKind::Throws;
}
return !Throwing;
return true;
}
bool walkToExprPre(Expr *E) override {
if (isa<TryExpr>(E)) {
Throwing = true;
}
return !Throwing;
// Don't walk into closures, they only produce effects when called.
if (isa<ClosureExpr>(E))
return false;

if (isa<TryExpr>(E))
Effects |= EffectKind::Throws;
if (isa<AwaitExpr>(E))
Effects |= EffectKind::Async;

return true;
}
bool walkToDeclPre(Decl *D, CharSourceRange Range) override {
return false;
}
bool walkToDeclPost(Decl *D) override { return !Throwing; }
bool walkToStmtPost(Stmt *S) override { return !Throwing; }
bool walkToExprPost(Expr *E) override { return !Throwing; }
bool isThrowing() { return Throwing; }
PossibleEffects getEffects() const { return Effects; }
};

return Nodes.end() != std::find_if(Nodes.begin(), Nodes.end(), [](ASTNode N) {
ThrowingEntityAnalyzer Analyzer;
PossibleEffects Effects;
for (auto N : Nodes) {
EffectsAnalyzer Analyzer;
Analyzer.walk(N);
return Analyzer.isThrowing();
});
Effects |= Analyzer.getEffects();
}
return Effects;
}

struct RangeResolver::Implementation {
Expand Down Expand Up @@ -549,7 +554,7 @@ struct RangeResolver::Implementation {
assert(ContainedASTNodes.size() == 1);
// Single node implies single entry point, or is it?
bool SingleEntry = true;
bool UnhandledError = hasUnhandledError({Node});
auto UnhandledEffects = getUnhandledEffects({Node});
OrphanKind Kind = getOrphanKind(ContainedASTNodes);
if (Node.is<Expr*>())
return ResolvedRangeInfo(RangeKind::SingleExpression,
Expand All @@ -558,7 +563,7 @@ struct RangeResolver::Implementation {
getImmediateContext(),
/*Common Parent Expr*/nullptr,
SingleEntry,
UnhandledError, Kind,
UnhandledEffects, Kind,
llvm::makeArrayRef(ContainedASTNodes),
llvm::makeArrayRef(DeclaredDecls),
llvm::makeArrayRef(ReferencedDecls));
Expand All @@ -569,7 +574,7 @@ struct RangeResolver::Implementation {
getImmediateContext(),
/*Common Parent Expr*/nullptr,
SingleEntry,
UnhandledError, Kind,
UnhandledEffects, Kind,
llvm::makeArrayRef(ContainedASTNodes),
llvm::makeArrayRef(DeclaredDecls),
llvm::makeArrayRef(ReferencedDecls));
Expand All @@ -581,7 +586,7 @@ struct RangeResolver::Implementation {
getImmediateContext(),
/*Common Parent Expr*/nullptr,
SingleEntry,
UnhandledError, Kind,
UnhandledEffects, Kind,
llvm::makeArrayRef(ContainedASTNodes),
llvm::makeArrayRef(DeclaredDecls),
llvm::makeArrayRef(ReferencedDecls));
Expand Down Expand Up @@ -642,7 +647,7 @@ struct RangeResolver::Implementation {
getImmediateContext(),
Parent,
hasSingleEntryPoint(ContainedASTNodes),
hasUnhandledError(ContainedASTNodes),
getUnhandledEffects(ContainedASTNodes),
getOrphanKind(ContainedASTNodes),
llvm::makeArrayRef(ContainedASTNodes),
llvm::makeArrayRef(DeclaredDecls),
Expand Down Expand Up @@ -889,7 +894,7 @@ struct RangeResolver::Implementation {
TokensInRange,
getImmediateContext(), nullptr,
hasSingleEntryPoint(ContainedASTNodes),
hasUnhandledError(ContainedASTNodes),
getUnhandledEffects(ContainedASTNodes),
getOrphanKind(ContainedASTNodes),
llvm::makeArrayRef(ContainedASTNodes),
llvm::makeArrayRef(DeclaredDecls),
Expand All @@ -904,7 +909,7 @@ struct RangeResolver::Implementation {
getImmediateContext(),
/*Common Parent Expr*/ nullptr,
/*SinleEntry*/ true,
hasUnhandledError(ContainedASTNodes),
getUnhandledEffects(ContainedASTNodes),
getOrphanKind(ContainedASTNodes),
llvm::makeArrayRef(ContainedASTNodes),
llvm::makeArrayRef(DeclaredDecls),
Expand Down
10 changes: 9 additions & 1 deletion lib/IDE/Refactoring.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1304,7 +1304,9 @@ bool RefactoringActionExtractFunction::performChange() {
}
OS << ")";

if (RangeInfo.ThrowingUnhandledError)
if (RangeInfo.UnhandledEffects.contains(EffectKind::Async))
OS << " async";
if (RangeInfo.UnhandledEffects.contains(EffectKind::Throws))
OS << " " << tok::kw_throws;

bool InsertedReturnType = false;
Expand Down Expand Up @@ -1332,6 +1334,12 @@ bool RefactoringActionExtractFunction::performChange() {
llvm::raw_svector_ostream OS(Buffer);
if (RangeInfo.exit() == ExitState::Positive)
OS << tok::kw_return <<" ";

if (RangeInfo.UnhandledEffects.contains(EffectKind::Throws))
OS << tok::kw_try << " ";
if (RangeInfo.UnhandledEffects.contains(EffectKind::Async))
OS << "await ";

CallNameOffset = Buffer.size() - ReplaceBegin;
OS << PreferredName << "(";
for (auto &RD : Parameters) {
Expand Down
5 changes: 4 additions & 1 deletion lib/IDE/SwiftSourceDocInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -724,9 +724,12 @@ void ResolvedRangeInfo::print(llvm::raw_ostream &OS) const {
OS << "<Entry>Multi</Entry>\n";
}

if (ThrowingUnhandledError) {
if (UnhandledEffects.contains(EffectKind::Throws)) {
OS << "<Error>Throwing</Error>\n";
}
if (UnhandledEffects.contains(EffectKind::Async)) {
OS << "<Effect>Async</Effect>\n";
}

if (Orphan != OrphanKind::None) {
OS << "<Orphan>";
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
func longLongLongJourney() async -> Int { 0 }
func longLongLongAwryJourney() async throws -> Int { 0 }
func consumesAsync(_ fn: () async throws -> Void) rethrows {}

fileprivate func new_name() async -> Int {
return await longLongLongJourney()
}

func testThrowingClosure() async throws -> Int {
let x = await new_name()
let y = try await longLongLongAwryJourney() + 1
try consumesAsync { try await longLongLongAwryJourney() }
return x + y
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
func longLongLongJourney() async -> Int { 0 }
func longLongLongAwryJourney() async throws -> Int { 0 }
func consumesAsync(_ fn: () async throws -> Void) rethrows {}

fileprivate func new_name() async throws -> Int {
return try await longLongLongAwryJourney() + 1
}

func testThrowingClosure() async throws -> Int {
let x = await longLongLongJourney()
let y = try await new_name()
try consumesAsync { try await longLongLongAwryJourney() }
return x + y
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
func longLongLongJourney() async -> Int { 0 }
func longLongLongAwryJourney() async throws -> Int { 0 }
func consumesAsync(_ fn: () async throws -> Void) rethrows {}

fileprivate func new_name() throws {
try consumesAsync { try await longLongLongAwryJourney() }
}

func testThrowingClosure() async throws -> Int {
let x = await longLongLongJourney()
let y = try await longLongLongAwryJourney() + 1
try new_name()
return x + y
}

Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ func foo2() throws {
do {
try foo1()
} catch {}
new_name()
try new_name()
}

Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ return try foo1()
}

func foo2() throws {
new_name()
try new_name()
try! foo1()
do {
try foo1()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
enum Err : Error {
case wat
}

func throwsSomething() throws { throw Err.wat }
func consumesErrClosure(_ fn: () throws -> Void) {}
func rethrowsErrClosure(_ fn: () throws -> Void) rethrows {}

fileprivate func new_name() {
consumesErrClosure { throw Err.wat }
consumesErrClosure { try throwsSomething() }
}

func testThrowingClosure() throws {
new_name()
try rethrowsErrClosure { try throwsSomething() }
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
enum Err : Error {
case wat
}

func throwsSomething() throws { throw Err.wat }
func consumesErrClosure(_ fn: () throws -> Void) {}
func rethrowsErrClosure(_ fn: () throws -> Void) rethrows {}

fileprivate func new_name() throws {
consumesErrClosure { throw Err.wat }
consumesErrClosure { try throwsSomething() }
try rethrowsErrClosure { try throwsSomething() }
}

func testThrowingClosure() throws {
try new_name()
}

18 changes: 18 additions & 0 deletions test/refactoring/ExtractFunction/await.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
func longLongLongJourney() async -> Int { 0 }
func longLongLongAwryJourney() async throws -> Int { 0 }
func consumesAsync(_ fn: () async throws -> Void) rethrows {}

func testThrowingClosure() async throws -> Int {
let x = await longLongLongJourney()
let y = try await longLongLongAwryJourney() + 1
try consumesAsync { try await longLongLongAwryJourney() }
return x + y
}

// RUN: %empty-directory(%t.result)
// RUN: %refactor -extract-function -source-filename %s -pos=6:11 -end-pos=6:38 >> %t.result/async1.swift
// RUN: diff -u %S/Outputs/await/async1.swift.expected %t.result/async1.swift
// RUN: %refactor -extract-function -source-filename %s -pos=7:11 -end-pos=7:50 >> %t.result/async2.swift
// RUN: diff -u %S/Outputs/await/async2.swift.expected %t.result/async2.swift
// RUN: %refactor -extract-function -source-filename %s -pos=8:1 -end-pos=8:60 >> %t.result/consumes_async.swift
// RUN: diff -u %S/Outputs/await/consumes_async.swift.expected %t.result/consumes_async.swift
19 changes: 19 additions & 0 deletions test/refactoring/ExtractFunction/throw_errors3.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
enum Err : Error {
case wat
}

func throwsSomething() throws { throw Err.wat }
func consumesErrClosure(_ fn: () throws -> Void) {}
func rethrowsErrClosure(_ fn: () throws -> Void) rethrows {}

func testThrowingClosure() throws {
consumesErrClosure { throw Err.wat }
consumesErrClosure { try throwsSomething() }
try rethrowsErrClosure { try throwsSomething() }
}

// RUN: %empty-directory(%t.result)
// RUN: %refactor -extract-function -source-filename %s -pos=10:1 -end-pos=11:47 >> %t.result/consumes_err.swift
// RUN: diff -u %S/Outputs/throw_errors3/consumes_err.swift.expected %t.result/consumes_err.swift
// RUN: %refactor -extract-function -source-filename %s -pos=10:1 -end-pos=12:51 >> %t.result/rethrows_err.swift
// RUN: diff -u %S/Outputs/throw_errors3/rethrows_err.swift.expected %t.result/rethrows_err.swift