Skip to content

Commit af0b164

Browse files
committed
[Refactoring] Add async wrapper refactoring action
This allows an async alternative function to be created that forwards onto the user's completion handler function through the use of `withCheckedContinuation`/`withCheckedThrowingContinuation`. rdar://77802486
1 parent b6c986a commit af0b164

File tree

5 files changed

+397
-11
lines changed

5 files changed

+397
-11
lines changed

include/swift/IDE/RefactoringKinds.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ CURSOR_REFACTORING(ConvertToAsync, "Convert Function to Async", convert.func-to-
6060

6161
CURSOR_REFACTORING(AddAsyncAlternative, "Add Async Alternative", add.async-alternative)
6262

63+
CURSOR_REFACTORING(AddAsyncWrapper, "Add Async Wrapper", add.async-wrapper)
64+
6365
RANGE_REFACTORING(ExtractExpr, "Extract Expression", extract.expr)
6466

6567
RANGE_REFACTORING(ExtractFunction, "Extract Method", extract.function)

lib/IDE/Refactoring.cpp

Lines changed: 208 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4204,6 +4204,13 @@ struct AsyncHandlerDesc {
42044204
return params();
42054205
}
42064206

4207+
/// If the completion handler has an Error parameter, return it.
4208+
Optional<AnyFunctionType::Param> getErrorParam() const {
4209+
if (HasError && Type == HandlerType::PARAMS)
4210+
return params().back();
4211+
return None;
4212+
}
4213+
42074214
/// Get the type of the error that will be thrown by the \c async method or \c
42084215
/// None if the completion handler doesn't accept an error parameter.
42094216
/// This may be more specialized than the generic 'Error' type if the
@@ -5405,6 +5412,41 @@ class AsyncConverter : private SourceEntityWalker {
54055412
return true;
54065413
}
54075414

5415+
/// Creates an async alternative function that forwards onto the completion
5416+
/// handler function through
5417+
/// withCheckedContinuation/withCheckedThrowingContinuation.
5418+
bool createAsyncWrapper() {
5419+
assert(Buffer.empty() && "AsyncConverter can only be used once");
5420+
auto *FD = cast<FuncDecl>(StartNode.get<Decl *>());
5421+
5422+
// First add the new async function declaration.
5423+
addFuncDecl(FD);
5424+
OS << tok::l_brace << "\n";
5425+
5426+
// Then add the body.
5427+
OS << tok::kw_return << " ";
5428+
if (TopHandler.HasError)
5429+
OS << tok::kw_try << " ";
5430+
5431+
OS << "await ";
5432+
5433+
// withChecked[Throwing]Continuation { cont in
5434+
if (TopHandler.HasError) {
5435+
OS << "withCheckedThrowingContinuation";
5436+
} else {
5437+
OS << "withCheckedContinuation";
5438+
}
5439+
OS << " " << tok::l_brace << " cont " << tok::kw_in << "\n";
5440+
5441+
// fnWithHandler(args...) { ... }
5442+
auto ClosureStr = getAsyncWrapperCompletionClosure("cont", TopHandler);
5443+
addForwardingCallTo(FD, TopHandler, /*HandlerReplacement*/ ClosureStr);
5444+
5445+
OS << tok::r_brace << "\n"; // end continuation closure
5446+
OS << tok::r_brace << "\n"; // end function body
5447+
return true;
5448+
}
5449+
54085450
void replace(ASTNode Node, SourceEditConsumer &EditConsumer,
54095451
SourceLoc StartOverride = SourceLoc()) {
54105452
SourceRange Range = Node.getSourceRange();
@@ -5454,6 +5496,116 @@ class AsyncConverter : private SourceEntityWalker {
54545496
OS << tok::r_paren;
54555497
}
54565498

5499+
/// Retrieve the completion handler closure argument for an async wrapper
5500+
/// function.
5501+
std::string
5502+
getAsyncWrapperCompletionClosure(StringRef ContName,
5503+
const AsyncHandlerParamDesc &HandlerDesc) {
5504+
std::string OutputStr;
5505+
llvm::raw_string_ostream OS(OutputStr);
5506+
5507+
OS << " " << tok::l_brace; // start closure
5508+
5509+
// Prepare parameter names for the closure.
5510+
auto SuccessParams = HandlerDesc.getSuccessParams();
5511+
SmallVector<SmallString<4>, 2> SuccessParamNames;
5512+
for (auto idx : indices(SuccessParams)) {
5513+
SuccessParamNames.emplace_back("res");
5514+
5515+
// If we have multiple success params, number them e.g res1, res2...
5516+
if (SuccessParams.size() > 1)
5517+
SuccessParamNames.back().append(std::to_string(idx + 1));
5518+
}
5519+
Optional<SmallString<4>> ErrName;
5520+
if (HandlerDesc.getErrorParam())
5521+
ErrName.emplace("err");
5522+
5523+
auto HasAnyParams = !SuccessParamNames.empty() || ErrName;
5524+
if (HasAnyParams)
5525+
OS << " ";
5526+
5527+
// res1, res2
5528+
llvm::interleave(
5529+
SuccessParamNames, [&](auto Name) { OS << Name; },
5530+
[&]() { OS << tok::comma << " "; });
5531+
5532+
// , err
5533+
if (ErrName) {
5534+
if (!SuccessParamNames.empty())
5535+
OS << tok::comma << " ";
5536+
5537+
OS << *ErrName;
5538+
}
5539+
if (HasAnyParams)
5540+
OS << " " << tok::kw_in;
5541+
5542+
OS << "\n";
5543+
5544+
// The closure body.
5545+
switch (HandlerDesc.Type) {
5546+
case HandlerType::PARAMS: {
5547+
// For a (Success?, Error?) -> Void handler, we do an if let on the error.
5548+
if (ErrName) {
5549+
// if let err = err {
5550+
OS << tok::kw_if << " " << tok::kw_let << " ";
5551+
OS << *ErrName << " " << tok::equal << " " << *ErrName << " ";
5552+
OS << tok::l_brace << "\n";
5553+
5554+
// cont.resume(throwing: err)
5555+
OS << ContName << tok::period << "resume" << tok::l_paren;
5556+
OS << "throwing" << tok::colon << " " << *ErrName;
5557+
OS << tok::r_paren << "\n";
5558+
5559+
// return }
5560+
OS << tok::kw_return << "\n";
5561+
OS << tok::r_brace << "\n";
5562+
}
5563+
5564+
// If we have any success params that we need to unwrap, insert a guard.
5565+
for (auto Idx : indices(SuccessParamNames)) {
5566+
auto &Name = SuccessParamNames[Idx];
5567+
auto ParamTy = SuccessParams[Idx].getParameterType();
5568+
if (!HandlerDesc.shouldUnwrap(ParamTy))
5569+
continue;
5570+
5571+
// guard let res = res else {
5572+
OS << tok::kw_guard << " " << tok::kw_let << " ";
5573+
OS << Name << " " << tok::equal << " " << Name << " " << tok::kw_else;
5574+
OS << " " << tok::l_brace << "\n";
5575+
5576+
// fatalError(...)
5577+
OS << "fatalError" << tok::l_paren;
5578+
OS << "\"Expected non-nil success param '" << Name;
5579+
OS << "' for nil error\"";
5580+
OS << tok::r_paren << "\n";
5581+
5582+
// End guard.
5583+
OS << tok::r_brace << "\n";
5584+
}
5585+
5586+
// cont.resume(returning: (res1, res2, ...))
5587+
OS << ContName << tok::period << "resume" << tok::l_paren;
5588+
OS << "returning" << tok::colon << " ";
5589+
addTupleOf(llvm::makeArrayRef(SuccessParamNames), OS,
5590+
[&](auto Ref) { OS << Ref; });
5591+
OS << tok::r_paren << "\n";
5592+
break;
5593+
}
5594+
case HandlerType::RESULT: {
5595+
// cont.resume(with: res)
5596+
assert(SuccessParamNames.size() == 1);
5597+
OS << ContName << tok::period << "resume" << tok::l_paren;
5598+
OS << "with" << tok::colon << " " << SuccessParamNames[0];
5599+
OS << tok::r_paren << "\n";
5600+
break;
5601+
}
5602+
case HandlerType::INVALID:
5603+
llvm_unreachable("Should not have an invalid handler here");
5604+
}
5605+
5606+
OS << tok::r_brace << "\n"; // end closure
5607+
return OutputStr;
5608+
}
54575609

54585610
/// Retrieves the location for the start of a comment attached to the token
54595611
/// at the provided location, or the location itself if there is no comment.
@@ -6480,6 +6632,24 @@ class AsyncConverter : private SourceEntityWalker {
64806632
}
64816633
};
64826634

6635+
/// Adds an attribute to describe a completion handler function's async
6636+
/// alternative if necessary.
6637+
void addCompletionHandlerAsyncAttrIfNeccessary(
6638+
ASTContext &Ctx, const FuncDecl *FD,
6639+
const AsyncHandlerParamDesc &HandlerDesc,
6640+
SourceEditConsumer &EditConsumer) {
6641+
if (!Ctx.LangOpts.EnableExperimentalConcurrency)
6642+
return;
6643+
6644+
llvm::SmallString<0> HandlerAttribute;
6645+
llvm::raw_svector_ostream OS(HandlerAttribute);
6646+
OS << "@completionHandlerAsync(\"";
6647+
HandlerDesc.printAsyncFunctionName(OS);
6648+
OS << "\", completionHandlerIndex: " << HandlerDesc.Index << ")\n";
6649+
EditConsumer.accept(Ctx.SourceMgr, FD->getAttributeInsertionLoc(false),
6650+
HandlerAttribute);
6651+
}
6652+
64836653
} // namespace asyncrefactorings
64846654

64856655
bool RefactoringActionConvertCallToAsyncAlternative::isApplicable(
@@ -6601,16 +6771,7 @@ bool RefactoringActionAddAsyncAlternative::performChange() {
66016771
"@available(*, deprecated, message: \"Prefer async "
66026772
"alternative instead\")\n");
66036773

6604-
if (Ctx.LangOpts.EnableExperimentalConcurrency) {
6605-
// Add an attribute to describe its async alternative
6606-
llvm::SmallString<0> HandlerAttribute;
6607-
llvm::raw_svector_ostream OS(HandlerAttribute);
6608-
OS << "@completionHandlerAsync(\"";
6609-
HandlerDesc.printAsyncFunctionName(OS);
6610-
OS << "\", completionHandlerIndex: " << HandlerDesc.Index << ")\n";
6611-
EditConsumer.accept(SM, FD->getAttributeInsertionLoc(false),
6612-
HandlerAttribute);
6613-
}
6774+
addCompletionHandlerAsyncAttrIfNeccessary(Ctx, FD, HandlerDesc, EditConsumer);
66146775

66156776
AsyncConverter LegacyBodyCreator(TheFile, SM, DiagEngine, FD, HandlerDesc);
66166777
if (LegacyBodyCreator.createLegacyBody()) {
@@ -6622,6 +6783,43 @@ bool RefactoringActionAddAsyncAlternative::performChange() {
66226783

66236784
return false;
66246785
}
6786+
6787+
bool RefactoringActionAddAsyncWrapper::isApplicable(
6788+
const ResolvedCursorInfo &CursorInfo, DiagnosticEngine &Diag) {
6789+
using namespace asyncrefactorings;
6790+
6791+
auto *FD = findFunction(CursorInfo);
6792+
if (!FD)
6793+
return false;
6794+
6795+
auto HandlerDesc =
6796+
AsyncHandlerParamDesc::find(FD, /*RequireAttributeOrName=*/false);
6797+
return HandlerDesc.isValid();
6798+
}
6799+
6800+
bool RefactoringActionAddAsyncWrapper::performChange() {
6801+
using namespace asyncrefactorings;
6802+
6803+
auto *FD = findFunction(CursorInfo);
6804+
assert(FD &&
6805+
"Should not run performChange when refactoring is not applicable");
6806+
6807+
auto HandlerDesc =
6808+
AsyncHandlerParamDesc::find(FD, /*RequireAttributeOrName=*/false);
6809+
assert(HandlerDesc.isValid() &&
6810+
"Should not run performChange when refactoring is not applicable");
6811+
6812+
AsyncConverter Converter(TheFile, SM, DiagEngine, FD, HandlerDesc);
6813+
if (!Converter.createAsyncWrapper())
6814+
return true;
6815+
6816+
addCompletionHandlerAsyncAttrIfNeccessary(Ctx, FD, HandlerDesc, EditConsumer);
6817+
6818+
// Add the async wrapper.
6819+
Converter.insertAfter(FD, EditConsumer);
6820+
return false;
6821+
}
6822+
66256823
} // end of anonymous namespace
66266824

66276825
StringRef swift::ide::

test/SourceKit/Refactoring/basic.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,16 +244,20 @@ func hasCallToAsyncAlternative() {
244244
// CHECK-ASYNC-NEXT: Convert Function to Async
245245
// CHECK-ASYNC-NEXT: source.refactoring.kind.add.async-alternative
246246
// CHECK-ASYNC-NEXT: Add Async Alternative
247+
// CHECK-ASYNC-NEXT: source.refactoring.kind.add.async-wrapper
248+
// CHECK-ASYNC-NEXT: Add Async Wrapper
247249
// CHECK-ASYNC-NOT: source.refactoring.kind.convert.call-to-async
248250
// CHECK-ASYNC: ACTIONS END
249251

250252
// CHECK-CALLASYNC: ACTIONS BEGIN
251253
// CHECK-CALLASYNC-NOT: source.refactoring.kind.add.async-alternative
252254
// CHECK-CALLASYNC-NOT: source.refactoring.kind.convert.func-to-async
255+
// CHECK-CALLASYNC-NOT: source.refactoring.kind.add.async-wrapper
253256
// CHECK-CALLASYNC: source.refactoring.kind.convert.call-to-async
254257
// CHECK-CALLASYNC-NEXT: Convert Call to Async Alternative
255258
// CHECK-CALLASYNC-NOT: source.refactoring.kind.add.async-alternative
256259
// CHECK-CALLASYNC-NOT: source.refactoring.kind.convert.func-to-async
260+
// CHECK-CALLASYNC-NOT: source.refactoring.kind.add.async-wrapper
257261
// CHECK-CALLASYNC: ACTIONS END
258262

259263
// REQUIRES: OS=macosx || OS=linux-gnu

0 commit comments

Comments
 (0)