Skip to content

Commit cb69f2e

Browse files
authored
Merge pull request swiftlang#37491 from hamishknight/so-in-sync
2 parents 7739d78 + 10038b6 commit cb69f2e

File tree

6 files changed

+459
-48
lines changed

6 files changed

+459
-48
lines changed

include/swift/AST/ParameterList.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,10 @@ class alignas(ParamDecl *) ParameterList final :
7676
iterator end() { return getArray().end(); }
7777
const_iterator begin() const { return getArray().begin(); }
7878
const_iterator end() const { return getArray().end(); }
79-
79+
80+
ParamDecl *front() const { return getArray().front(); }
81+
ParamDecl *back() const { return getArray().back(); }
82+
8083
MutableArrayRef<ParamDecl*> getArray() {
8184
return {getTrailingObjects<ParamDecl*>(), numParameters};
8285
}

include/swift/IDE/RefactoringKinds.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ CURSOR_REFACTORING(ConvertToAsync, "Convert Function to Async", convert.func-to-
6060

6161
CURSOR_REFACTORING(AddAsyncAlternative, "Add Async Alternative", add.async-alternative)
6262

63+
CURSOR_REFACTORING(AddAsyncWrapper, "Add Async Wrapper", add.async-wrapper)
64+
6365
RANGE_REFACTORING(ExtractExpr, "Extract Expression", extract.expr)
6466

6567
RANGE_REFACTORING(ExtractFunction, "Extract Method", extract.function)

lib/IDE/Refactoring.cpp

Lines changed: 262 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -4195,6 +4195,13 @@ struct AsyncHandlerDesc {
41954195
return params();
41964196
}
41974197

4198+
/// If the completion handler has an Error parameter, return it.
4199+
Optional<AnyFunctionType::Param> getErrorParam() const {
4200+
if (HasError && Type == HandlerType::PARAMS)
4201+
return params().back();
4202+
return None;
4203+
}
4204+
41984205
/// Get the type of the error that will be thrown by the \c async method or \c
41994206
/// None if the completion handler doesn't accept an error parameter.
42004207
/// This may be more specialized than the generic 'Error' type if the
@@ -5397,6 +5404,41 @@ class AsyncConverter : private SourceEntityWalker {
53975404
return true;
53985405
}
53995406

5407+
/// Creates an async alternative function that forwards onto the completion
5408+
/// handler function through
5409+
/// withCheckedContinuation/withCheckedThrowingContinuation.
5410+
bool createAsyncWrapper() {
5411+
assert(Buffer.empty() && "AsyncConverter can only be used once");
5412+
auto *FD = cast<FuncDecl>(StartNode.get<Decl *>());
5413+
5414+
// First add the new async function declaration.
5415+
addFuncDecl(FD);
5416+
OS << tok::l_brace << "\n";
5417+
5418+
// Then add the body.
5419+
OS << tok::kw_return << " ";
5420+
if (TopHandler.HasError)
5421+
OS << tok::kw_try << " ";
5422+
5423+
OS << "await ";
5424+
5425+
// withChecked[Throwing]Continuation { cont in
5426+
if (TopHandler.HasError) {
5427+
OS << "withCheckedThrowingContinuation";
5428+
} else {
5429+
OS << "withCheckedContinuation";
5430+
}
5431+
OS << " " << tok::l_brace << " cont " << tok::kw_in << "\n";
5432+
5433+
// fnWithHandler(args...) { ... }
5434+
auto ClosureStr = getAsyncWrapperCompletionClosure("cont", TopHandler);
5435+
addForwardingCallTo(FD, TopHandler, /*HandlerReplacement*/ ClosureStr);
5436+
5437+
OS << tok::r_brace << "\n"; // end continuation closure
5438+
OS << tok::r_brace << "\n"; // end function body
5439+
return true;
5440+
}
5441+
54005442
void replace(ASTNode Node, SourceEditConsumer &EditConsumer,
54015443
SourceLoc StartOverride = SourceLoc()) {
54025444
SourceRange Range = Node.getSourceRange();
@@ -5432,6 +5474,130 @@ class AsyncConverter : private SourceEntityWalker {
54325474
return TopHandler.isValid();
54335475
}
54345476

5477+
/// Prints a tuple of elements, or a lone single element if only one is
5478+
/// present, using the provided printing function.
5479+
template <typename T, typename PrintFn>
5480+
void addTupleOf(ArrayRef<T> Elements, llvm::raw_ostream &OS,
5481+
PrintFn PrintElt) {
5482+
if (Elements.size() == 1) {
5483+
PrintElt(Elements[0]);
5484+
return;
5485+
}
5486+
OS << tok::l_paren;
5487+
llvm::interleave(Elements, PrintElt, [&]() { OS << tok::comma << " "; });
5488+
OS << tok::r_paren;
5489+
}
5490+
5491+
/// Retrieve the completion handler closure argument for an async wrapper
5492+
/// function.
5493+
std::string
5494+
getAsyncWrapperCompletionClosure(StringRef ContName,
5495+
const AsyncHandlerParamDesc &HandlerDesc) {
5496+
std::string OutputStr;
5497+
llvm::raw_string_ostream OS(OutputStr);
5498+
5499+
OS << " " << tok::l_brace; // start closure
5500+
5501+
// Prepare parameter names for the closure.
5502+
auto SuccessParams = HandlerDesc.getSuccessParams();
5503+
SmallVector<SmallString<4>, 2> SuccessParamNames;
5504+
for (auto idx : indices(SuccessParams)) {
5505+
SuccessParamNames.emplace_back("res");
5506+
5507+
// If we have multiple success params, number them e.g res1, res2...
5508+
if (SuccessParams.size() > 1)
5509+
SuccessParamNames.back().append(std::to_string(idx + 1));
5510+
}
5511+
Optional<SmallString<4>> ErrName;
5512+
if (HandlerDesc.getErrorParam())
5513+
ErrName.emplace("err");
5514+
5515+
auto HasAnyParams = !SuccessParamNames.empty() || ErrName;
5516+
if (HasAnyParams)
5517+
OS << " ";
5518+
5519+
// res1, res2
5520+
llvm::interleave(
5521+
SuccessParamNames, [&](auto Name) { OS << Name; },
5522+
[&]() { OS << tok::comma << " "; });
5523+
5524+
// , err
5525+
if (ErrName) {
5526+
if (!SuccessParamNames.empty())
5527+
OS << tok::comma << " ";
5528+
5529+
OS << *ErrName;
5530+
}
5531+
if (HasAnyParams)
5532+
OS << " " << tok::kw_in;
5533+
5534+
OS << "\n";
5535+
5536+
// The closure body.
5537+
switch (HandlerDesc.Type) {
5538+
case HandlerType::PARAMS: {
5539+
// For a (Success?, Error?) -> Void handler, we do an if let on the error.
5540+
if (ErrName) {
5541+
// if let err = err {
5542+
OS << tok::kw_if << " " << tok::kw_let << " ";
5543+
OS << *ErrName << " " << tok::equal << " " << *ErrName << " ";
5544+
OS << tok::l_brace << "\n";
5545+
5546+
// cont.resume(throwing: err)
5547+
OS << ContName << tok::period << "resume" << tok::l_paren;
5548+
OS << "throwing" << tok::colon << " " << *ErrName;
5549+
OS << tok::r_paren << "\n";
5550+
5551+
// return }
5552+
OS << tok::kw_return << "\n";
5553+
OS << tok::r_brace << "\n";
5554+
}
5555+
5556+
// If we have any success params that we need to unwrap, insert a guard.
5557+
for (auto Idx : indices(SuccessParamNames)) {
5558+
auto &Name = SuccessParamNames[Idx];
5559+
auto ParamTy = SuccessParams[Idx].getParameterType();
5560+
if (!HandlerDesc.shouldUnwrap(ParamTy))
5561+
continue;
5562+
5563+
// guard let res = res else {
5564+
OS << tok::kw_guard << " " << tok::kw_let << " ";
5565+
OS << Name << " " << tok::equal << " " << Name << " " << tok::kw_else;
5566+
OS << " " << tok::l_brace << "\n";
5567+
5568+
// fatalError(...)
5569+
OS << "fatalError" << tok::l_paren;
5570+
OS << "\"Expected non-nil success param '" << Name;
5571+
OS << "' for nil error\"";
5572+
OS << tok::r_paren << "\n";
5573+
5574+
// End guard.
5575+
OS << tok::r_brace << "\n";
5576+
}
5577+
5578+
// cont.resume(returning: (res1, res2, ...))
5579+
OS << ContName << tok::period << "resume" << tok::l_paren;
5580+
OS << "returning" << tok::colon << " ";
5581+
addTupleOf(llvm::makeArrayRef(SuccessParamNames), OS,
5582+
[&](auto Ref) { OS << Ref; });
5583+
OS << tok::r_paren << "\n";
5584+
break;
5585+
}
5586+
case HandlerType::RESULT: {
5587+
// cont.resume(with: res)
5588+
assert(SuccessParamNames.size() == 1);
5589+
OS << ContName << tok::period << "resume" << tok::l_paren;
5590+
OS << "with" << tok::colon << " " << SuccessParamNames[0];
5591+
OS << tok::r_paren << "\n";
5592+
break;
5593+
}
5594+
case HandlerType::INVALID:
5595+
llvm_unreachable("Should not have an invalid handler here");
5596+
}
5597+
5598+
OS << tok::r_brace << "\n"; // end closure
5599+
return OutputStr;
5600+
}
54355601

54365602
/// Retrieves the location for the start of a comment attached to the token
54375603
/// at the provided location, or the location itself if there is no comment.
@@ -6075,16 +6241,9 @@ class AsyncConverter : private SourceEntityWalker {
60756241
}
60766242
OS << " ";
60776243
}
6078-
if (SuccessParams.size() > 1)
6079-
OS << tok::l_paren;
6080-
OS << newNameFor(SuccessParams.front());
6081-
for (const auto Param : SuccessParams.drop_front()) {
6082-
OS << tok::comma << " ";
6083-
OS << newNameFor(Param);
6084-
}
6085-
if (SuccessParams.size() > 1) {
6086-
OS << tok::r_paren;
6087-
}
6244+
// 'res =' or '(res1, res2, ...) ='
6245+
addTupleOf(SuccessParams, OS,
6246+
[&](auto &Param) { OS << newNameFor(Param); });
60886247
OS << " " << tok::equal << " ";
60896248
}
60906249

@@ -6271,22 +6430,46 @@ class AsyncConverter : private SourceEntityWalker {
62716430
/// 'await' keyword.
62726431
void addCallToAsyncMethod(const FuncDecl *FD,
62736432
const AsyncHandlerDesc &HandlerDesc) {
6433+
// The call to the async function is the same as the call to the old
6434+
// completion handler function, minus the completion handler arg.
6435+
addForwardingCallTo(FD, HandlerDesc, /*HandlerReplacement*/ "");
6436+
}
6437+
6438+
/// Adds a forwarding call to the old completion handler function, with
6439+
/// \p HandlerReplacement that allows for a custom replacement or, if empty,
6440+
/// removal of the completion handler closure.
6441+
void addForwardingCallTo(
6442+
const FuncDecl *FD, const AsyncHandlerDesc &HandlerDesc,
6443+
StringRef HandlerReplacement, bool CanUseTrailingClosure = true) {
62746444
OS << FD->getBaseName() << tok::l_paren;
6275-
bool FirstParam = true;
6276-
for (auto Param : *FD->getParameters()) {
6445+
6446+
auto *Params = FD->getParameters();
6447+
for (auto Param : *Params) {
62776448
if (Param == HandlerDesc.getHandler()) {
6278-
/// We don't need to pass the completion handler to the async method.
6279-
continue;
6449+
/// If we're not replacing the handler with anything, drop it.
6450+
if (HandlerReplacement.empty())
6451+
continue;
6452+
6453+
// If this is the last param, and we can use a trailing closure, do so.
6454+
if (CanUseTrailingClosure && Param == Params->back()) {
6455+
OS << tok::r_paren << " ";
6456+
OS << HandlerReplacement;
6457+
return;
6458+
}
6459+
// Otherwise fall through to do the replacement.
62806460
}
6281-
if (!FirstParam) {
6461+
6462+
if (Param != Params->front())
62826463
OS << tok::comma << " ";
6283-
} else {
6284-
FirstParam = false;
6285-
}
6286-
if (!Param->getArgumentName().empty()) {
6464+
6465+
if (!Param->getArgumentName().empty())
62876466
OS << Param->getArgumentName() << tok::colon << " ";
6467+
6468+
if (Param == HandlerDesc.getHandler()) {
6469+
OS << HandlerReplacement;
6470+
} else {
6471+
OS << Param->getParameterName();
62886472
}
6289-
OS << Param->getParameterName();
62906473
}
62916474
OS << tok::r_paren;
62926475
}
@@ -6408,19 +6591,10 @@ class AsyncConverter : private SourceEntityWalker {
64086591
/// Adds the result type of a refactored async function that previously
64096592
/// returned results via a completion handler described by \p HandlerDesc.
64106593
void addAsyncFuncReturnType(const AsyncHandlerDesc &HandlerDesc) {
6594+
// Type or (Type1, Type2, ...)
64116595
SmallVector<Type, 2> Scratch;
6412-
auto ReturnTypes = HandlerDesc.getAsyncReturnTypes(Scratch);
6413-
if (ReturnTypes.size() > 1) {
6414-
OS << tok::l_paren;
6415-
}
6416-
6417-
llvm::interleave(
6418-
ReturnTypes, [&](Type Ty) { Ty->print(OS); },
6419-
[&]() { OS << tok::comma << " "; });
6420-
6421-
if (ReturnTypes.size() > 1) {
6422-
OS << tok::r_paren;
6423-
}
6596+
addTupleOf(HandlerDesc.getAsyncReturnTypes(Scratch), OS,
6597+
[&](auto Ty) { Ty->print(OS); });
64246598
}
64256599

64266600
/// If \p FD is generic, adds a type annotation with the return type of the
@@ -6450,6 +6624,24 @@ class AsyncConverter : private SourceEntityWalker {
64506624
}
64516625
};
64526626

6627+
/// Adds an attribute to describe a completion handler function's async
6628+
/// alternative if necessary.
6629+
void addCompletionHandlerAsyncAttrIfNeccessary(
6630+
ASTContext &Ctx, const FuncDecl *FD,
6631+
const AsyncHandlerParamDesc &HandlerDesc,
6632+
SourceEditConsumer &EditConsumer) {
6633+
if (!Ctx.LangOpts.EnableExperimentalConcurrency)
6634+
return;
6635+
6636+
llvm::SmallString<0> HandlerAttribute;
6637+
llvm::raw_svector_ostream OS(HandlerAttribute);
6638+
OS << "@completionHandlerAsync(\"";
6639+
HandlerDesc.printAsyncFunctionName(OS);
6640+
OS << "\", completionHandlerIndex: " << HandlerDesc.Index << ")\n";
6641+
EditConsumer.accept(Ctx.SourceMgr, FD->getAttributeInsertionLoc(false),
6642+
HandlerAttribute);
6643+
}
6644+
64536645
} // namespace asyncrefactorings
64546646

64556647
bool RefactoringActionConvertCallToAsyncAlternative::isApplicable(
@@ -6571,16 +6763,7 @@ bool RefactoringActionAddAsyncAlternative::performChange() {
65716763
"@available(*, deprecated, message: \"Prefer async "
65726764
"alternative instead\")\n");
65736765

6574-
if (Ctx.LangOpts.EnableExperimentalConcurrency) {
6575-
// Add an attribute to describe its async alternative
6576-
llvm::SmallString<0> HandlerAttribute;
6577-
llvm::raw_svector_ostream OS(HandlerAttribute);
6578-
OS << "@completionHandlerAsync(\"";
6579-
HandlerDesc.printAsyncFunctionName(OS);
6580-
OS << "\", completionHandlerIndex: " << HandlerDesc.Index << ")\n";
6581-
EditConsumer.accept(SM, FD->getAttributeInsertionLoc(false),
6582-
HandlerAttribute);
6583-
}
6766+
addCompletionHandlerAsyncAttrIfNeccessary(Ctx, FD, HandlerDesc, EditConsumer);
65846767

65856768
AsyncConverter LegacyBodyCreator(TheFile, SM, DiagEngine, FD, HandlerDesc);
65866769
if (LegacyBodyCreator.createLegacyBody()) {
@@ -6592,6 +6775,43 @@ bool RefactoringActionAddAsyncAlternative::performChange() {
65926775

65936776
return false;
65946777
}
6778+
6779+
bool RefactoringActionAddAsyncWrapper::isApplicable(
6780+
const ResolvedCursorInfo &CursorInfo, DiagnosticEngine &Diag) {
6781+
using namespace asyncrefactorings;
6782+
6783+
auto *FD = findFunction(CursorInfo);
6784+
if (!FD)
6785+
return false;
6786+
6787+
auto HandlerDesc =
6788+
AsyncHandlerParamDesc::find(FD, /*RequireAttributeOrName=*/false);
6789+
return HandlerDesc.isValid();
6790+
}
6791+
6792+
bool RefactoringActionAddAsyncWrapper::performChange() {
6793+
using namespace asyncrefactorings;
6794+
6795+
auto *FD = findFunction(CursorInfo);
6796+
assert(FD &&
6797+
"Should not run performChange when refactoring is not applicable");
6798+
6799+
auto HandlerDesc =
6800+
AsyncHandlerParamDesc::find(FD, /*RequireAttributeOrName=*/false);
6801+
assert(HandlerDesc.isValid() &&
6802+
"Should not run performChange when refactoring is not applicable");
6803+
6804+
AsyncConverter Converter(TheFile, SM, DiagEngine, FD, HandlerDesc);
6805+
if (!Converter.createAsyncWrapper())
6806+
return true;
6807+
6808+
addCompletionHandlerAsyncAttrIfNeccessary(Ctx, FD, HandlerDesc, EditConsumer);
6809+
6810+
// Add the async wrapper.
6811+
Converter.insertAfter(FD, EditConsumer);
6812+
return false;
6813+
}
6814+
65956815
} // end of anonymous namespace
65966816

65976817
StringRef swift::ide::

0 commit comments

Comments
 (0)