Skip to content

Commit 45d34a6

Browse files
authored
Merge pull request #37533 from hamishknight/so-in-sync-5.5
2 parents d2572cd + af0b164 commit 45d34a6

File tree

6 files changed

+459
-48
lines changed

6 files changed

+459
-48
lines changed

include/swift/AST/ParameterList.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,10 @@ class alignas(ParamDecl *) ParameterList final :
7676
iterator end() { return getArray().end(); }
7777
const_iterator begin() const { return getArray().begin(); }
7878
const_iterator end() const { return getArray().end(); }
79-
79+
80+
ParamDecl *front() const { return getArray().front(); }
81+
ParamDecl *back() const { return getArray().back(); }
82+
8083
MutableArrayRef<ParamDecl*> getArray() {
8184
return {getTrailingObjects<ParamDecl*>(), numParameters};
8285
}

include/swift/IDE/RefactoringKinds.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ CURSOR_REFACTORING(ConvertToAsync, "Convert Function to Async", convert.func-to-
6060

6161
CURSOR_REFACTORING(AddAsyncAlternative, "Add Async Alternative", add.async-alternative)
6262

63+
CURSOR_REFACTORING(AddAsyncWrapper, "Add Async Wrapper", add.async-wrapper)
64+
6365
RANGE_REFACTORING(ExtractExpr, "Extract Expression", extract.expr)
6466

6567
RANGE_REFACTORING(ExtractFunction, "Extract Method", extract.function)

lib/IDE/Refactoring.cpp

Lines changed: 262 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -4204,6 +4204,13 @@ struct AsyncHandlerDesc {
42044204
return params();
42054205
}
42064206

4207+
/// If the completion handler has an Error parameter, return it.
4208+
Optional<AnyFunctionType::Param> getErrorParam() const {
4209+
if (HasError && Type == HandlerType::PARAMS)
4210+
return params().back();
4211+
return None;
4212+
}
4213+
42074214
/// Get the type of the error that will be thrown by the \c async method or \c
42084215
/// None if the completion handler doesn't accept an error parameter.
42094216
/// This may be more specialized than the generic 'Error' type if the
@@ -5405,6 +5412,41 @@ class AsyncConverter : private SourceEntityWalker {
54055412
return true;
54065413
}
54075414

5415+
/// Creates an async alternative function that forwards onto the completion
5416+
/// handler function through
5417+
/// withCheckedContinuation/withCheckedThrowingContinuation.
5418+
bool createAsyncWrapper() {
5419+
assert(Buffer.empty() && "AsyncConverter can only be used once");
5420+
auto *FD = cast<FuncDecl>(StartNode.get<Decl *>());
5421+
5422+
// First add the new async function declaration.
5423+
addFuncDecl(FD);
5424+
OS << tok::l_brace << "\n";
5425+
5426+
// Then add the body.
5427+
OS << tok::kw_return << " ";
5428+
if (TopHandler.HasError)
5429+
OS << tok::kw_try << " ";
5430+
5431+
OS << "await ";
5432+
5433+
// withChecked[Throwing]Continuation { cont in
5434+
if (TopHandler.HasError) {
5435+
OS << "withCheckedThrowingContinuation";
5436+
} else {
5437+
OS << "withCheckedContinuation";
5438+
}
5439+
OS << " " << tok::l_brace << " cont " << tok::kw_in << "\n";
5440+
5441+
// fnWithHandler(args...) { ... }
5442+
auto ClosureStr = getAsyncWrapperCompletionClosure("cont", TopHandler);
5443+
addForwardingCallTo(FD, TopHandler, /*HandlerReplacement*/ ClosureStr);
5444+
5445+
OS << tok::r_brace << "\n"; // end continuation closure
5446+
OS << tok::r_brace << "\n"; // end function body
5447+
return true;
5448+
}
5449+
54085450
void replace(ASTNode Node, SourceEditConsumer &EditConsumer,
54095451
SourceLoc StartOverride = SourceLoc()) {
54105452
SourceRange Range = Node.getSourceRange();
@@ -5440,6 +5482,130 @@ class AsyncConverter : private SourceEntityWalker {
54405482
return TopHandler.isValid();
54415483
}
54425484

5485+
/// Prints a tuple of elements, or a lone single element if only one is
5486+
/// present, using the provided printing function.
5487+
template <typename T, typename PrintFn>
5488+
void addTupleOf(ArrayRef<T> Elements, llvm::raw_ostream &OS,
5489+
PrintFn PrintElt) {
5490+
if (Elements.size() == 1) {
5491+
PrintElt(Elements[0]);
5492+
return;
5493+
}
5494+
OS << tok::l_paren;
5495+
llvm::interleave(Elements, PrintElt, [&]() { OS << tok::comma << " "; });
5496+
OS << tok::r_paren;
5497+
}
5498+
5499+
/// Retrieve the completion handler closure argument for an async wrapper
5500+
/// function.
5501+
std::string
5502+
getAsyncWrapperCompletionClosure(StringRef ContName,
5503+
const AsyncHandlerParamDesc &HandlerDesc) {
5504+
std::string OutputStr;
5505+
llvm::raw_string_ostream OS(OutputStr);
5506+
5507+
OS << " " << tok::l_brace; // start closure
5508+
5509+
// Prepare parameter names for the closure.
5510+
auto SuccessParams = HandlerDesc.getSuccessParams();
5511+
SmallVector<SmallString<4>, 2> SuccessParamNames;
5512+
for (auto idx : indices(SuccessParams)) {
5513+
SuccessParamNames.emplace_back("res");
5514+
5515+
// If we have multiple success params, number them e.g res1, res2...
5516+
if (SuccessParams.size() > 1)
5517+
SuccessParamNames.back().append(std::to_string(idx + 1));
5518+
}
5519+
Optional<SmallString<4>> ErrName;
5520+
if (HandlerDesc.getErrorParam())
5521+
ErrName.emplace("err");
5522+
5523+
auto HasAnyParams = !SuccessParamNames.empty() || ErrName;
5524+
if (HasAnyParams)
5525+
OS << " ";
5526+
5527+
// res1, res2
5528+
llvm::interleave(
5529+
SuccessParamNames, [&](auto Name) { OS << Name; },
5530+
[&]() { OS << tok::comma << " "; });
5531+
5532+
// , err
5533+
if (ErrName) {
5534+
if (!SuccessParamNames.empty())
5535+
OS << tok::comma << " ";
5536+
5537+
OS << *ErrName;
5538+
}
5539+
if (HasAnyParams)
5540+
OS << " " << tok::kw_in;
5541+
5542+
OS << "\n";
5543+
5544+
// The closure body.
5545+
switch (HandlerDesc.Type) {
5546+
case HandlerType::PARAMS: {
5547+
// For a (Success?, Error?) -> Void handler, we do an if let on the error.
5548+
if (ErrName) {
5549+
// if let err = err {
5550+
OS << tok::kw_if << " " << tok::kw_let << " ";
5551+
OS << *ErrName << " " << tok::equal << " " << *ErrName << " ";
5552+
OS << tok::l_brace << "\n";
5553+
5554+
// cont.resume(throwing: err)
5555+
OS << ContName << tok::period << "resume" << tok::l_paren;
5556+
OS << "throwing" << tok::colon << " " << *ErrName;
5557+
OS << tok::r_paren << "\n";
5558+
5559+
// return }
5560+
OS << tok::kw_return << "\n";
5561+
OS << tok::r_brace << "\n";
5562+
}
5563+
5564+
// If we have any success params that we need to unwrap, insert a guard.
5565+
for (auto Idx : indices(SuccessParamNames)) {
5566+
auto &Name = SuccessParamNames[Idx];
5567+
auto ParamTy = SuccessParams[Idx].getParameterType();
5568+
if (!HandlerDesc.shouldUnwrap(ParamTy))
5569+
continue;
5570+
5571+
// guard let res = res else {
5572+
OS << tok::kw_guard << " " << tok::kw_let << " ";
5573+
OS << Name << " " << tok::equal << " " << Name << " " << tok::kw_else;
5574+
OS << " " << tok::l_brace << "\n";
5575+
5576+
// fatalError(...)
5577+
OS << "fatalError" << tok::l_paren;
5578+
OS << "\"Expected non-nil success param '" << Name;
5579+
OS << "' for nil error\"";
5580+
OS << tok::r_paren << "\n";
5581+
5582+
// End guard.
5583+
OS << tok::r_brace << "\n";
5584+
}
5585+
5586+
// cont.resume(returning: (res1, res2, ...))
5587+
OS << ContName << tok::period << "resume" << tok::l_paren;
5588+
OS << "returning" << tok::colon << " ";
5589+
addTupleOf(llvm::makeArrayRef(SuccessParamNames), OS,
5590+
[&](auto Ref) { OS << Ref; });
5591+
OS << tok::r_paren << "\n";
5592+
break;
5593+
}
5594+
case HandlerType::RESULT: {
5595+
// cont.resume(with: res)
5596+
assert(SuccessParamNames.size() == 1);
5597+
OS << ContName << tok::period << "resume" << tok::l_paren;
5598+
OS << "with" << tok::colon << " " << SuccessParamNames[0];
5599+
OS << tok::r_paren << "\n";
5600+
break;
5601+
}
5602+
case HandlerType::INVALID:
5603+
llvm_unreachable("Should not have an invalid handler here");
5604+
}
5605+
5606+
OS << tok::r_brace << "\n"; // end closure
5607+
return OutputStr;
5608+
}
54435609

54445610
/// Retrieves the location for the start of a comment attached to the token
54455611
/// at the provided location, or the location itself if there is no comment.
@@ -6083,16 +6249,9 @@ class AsyncConverter : private SourceEntityWalker {
60836249
}
60846250
OS << " ";
60856251
}
6086-
if (SuccessParams.size() > 1)
6087-
OS << tok::l_paren;
6088-
OS << newNameFor(SuccessParams.front());
6089-
for (const auto Param : SuccessParams.drop_front()) {
6090-
OS << tok::comma << " ";
6091-
OS << newNameFor(Param);
6092-
}
6093-
if (SuccessParams.size() > 1) {
6094-
OS << tok::r_paren;
6095-
}
6252+
// 'res =' or '(res1, res2, ...) ='
6253+
addTupleOf(SuccessParams, OS,
6254+
[&](auto &Param) { OS << newNameFor(Param); });
60966255
OS << " " << tok::equal << " ";
60976256
}
60986257

@@ -6279,22 +6438,46 @@ class AsyncConverter : private SourceEntityWalker {
62796438
/// 'await' keyword.
62806439
void addCallToAsyncMethod(const FuncDecl *FD,
62816440
const AsyncHandlerDesc &HandlerDesc) {
6441+
// The call to the async function is the same as the call to the old
6442+
// completion handler function, minus the completion handler arg.
6443+
addForwardingCallTo(FD, HandlerDesc, /*HandlerReplacement*/ "");
6444+
}
6445+
6446+
/// Adds a forwarding call to the old completion handler function, with
6447+
/// \p HandlerReplacement that allows for a custom replacement or, if empty,
6448+
/// removal of the completion handler closure.
6449+
void addForwardingCallTo(
6450+
const FuncDecl *FD, const AsyncHandlerDesc &HandlerDesc,
6451+
StringRef HandlerReplacement, bool CanUseTrailingClosure = true) {
62826452
OS << FD->getBaseName() << tok::l_paren;
6283-
bool FirstParam = true;
6284-
for (auto Param : *FD->getParameters()) {
6453+
6454+
auto *Params = FD->getParameters();
6455+
for (auto Param : *Params) {
62856456
if (Param == HandlerDesc.getHandler()) {
6286-
/// We don't need to pass the completion handler to the async method.
6287-
continue;
6457+
/// If we're not replacing the handler with anything, drop it.
6458+
if (HandlerReplacement.empty())
6459+
continue;
6460+
6461+
// If this is the last param, and we can use a trailing closure, do so.
6462+
if (CanUseTrailingClosure && Param == Params->back()) {
6463+
OS << tok::r_paren << " ";
6464+
OS << HandlerReplacement;
6465+
return;
6466+
}
6467+
// Otherwise fall through to do the replacement.
62886468
}
6289-
if (!FirstParam) {
6469+
6470+
if (Param != Params->front())
62906471
OS << tok::comma << " ";
6291-
} else {
6292-
FirstParam = false;
6293-
}
6294-
if (!Param->getArgumentName().empty()) {
6472+
6473+
if (!Param->getArgumentName().empty())
62956474
OS << Param->getArgumentName() << tok::colon << " ";
6475+
6476+
if (Param == HandlerDesc.getHandler()) {
6477+
OS << HandlerReplacement;
6478+
} else {
6479+
OS << Param->getParameterName();
62966480
}
6297-
OS << Param->getParameterName();
62986481
}
62996482
OS << tok::r_paren;
63006483
}
@@ -6416,19 +6599,10 @@ class AsyncConverter : private SourceEntityWalker {
64166599
/// Adds the result type of a refactored async function that previously
64176600
/// returned results via a completion handler described by \p HandlerDesc.
64186601
void addAsyncFuncReturnType(const AsyncHandlerDesc &HandlerDesc) {
6602+
// Type or (Type1, Type2, ...)
64196603
SmallVector<Type, 2> Scratch;
6420-
auto ReturnTypes = HandlerDesc.getAsyncReturnTypes(Scratch);
6421-
if (ReturnTypes.size() > 1) {
6422-
OS << tok::l_paren;
6423-
}
6424-
6425-
llvm::interleave(
6426-
ReturnTypes, [&](Type Ty) { Ty->print(OS); },
6427-
[&]() { OS << tok::comma << " "; });
6428-
6429-
if (ReturnTypes.size() > 1) {
6430-
OS << tok::r_paren;
6431-
}
6604+
addTupleOf(HandlerDesc.getAsyncReturnTypes(Scratch), OS,
6605+
[&](auto Ty) { Ty->print(OS); });
64326606
}
64336607

64346608
/// If \p FD is generic, adds a type annotation with the return type of the
@@ -6458,6 +6632,24 @@ class AsyncConverter : private SourceEntityWalker {
64586632
}
64596633
};
64606634

6635+
/// Adds an attribute to describe a completion handler function's async
6636+
/// alternative if necessary.
6637+
void addCompletionHandlerAsyncAttrIfNeccessary(
6638+
ASTContext &Ctx, const FuncDecl *FD,
6639+
const AsyncHandlerParamDesc &HandlerDesc,
6640+
SourceEditConsumer &EditConsumer) {
6641+
if (!Ctx.LangOpts.EnableExperimentalConcurrency)
6642+
return;
6643+
6644+
llvm::SmallString<0> HandlerAttribute;
6645+
llvm::raw_svector_ostream OS(HandlerAttribute);
6646+
OS << "@completionHandlerAsync(\"";
6647+
HandlerDesc.printAsyncFunctionName(OS);
6648+
OS << "\", completionHandlerIndex: " << HandlerDesc.Index << ")\n";
6649+
EditConsumer.accept(Ctx.SourceMgr, FD->getAttributeInsertionLoc(false),
6650+
HandlerAttribute);
6651+
}
6652+
64616653
} // namespace asyncrefactorings
64626654

64636655
bool RefactoringActionConvertCallToAsyncAlternative::isApplicable(
@@ -6579,16 +6771,7 @@ bool RefactoringActionAddAsyncAlternative::performChange() {
65796771
"@available(*, deprecated, message: \"Prefer async "
65806772
"alternative instead\")\n");
65816773

6582-
if (Ctx.LangOpts.EnableExperimentalConcurrency) {
6583-
// Add an attribute to describe its async alternative
6584-
llvm::SmallString<0> HandlerAttribute;
6585-
llvm::raw_svector_ostream OS(HandlerAttribute);
6586-
OS << "@completionHandlerAsync(\"";
6587-
HandlerDesc.printAsyncFunctionName(OS);
6588-
OS << "\", completionHandlerIndex: " << HandlerDesc.Index << ")\n";
6589-
EditConsumer.accept(SM, FD->getAttributeInsertionLoc(false),
6590-
HandlerAttribute);
6591-
}
6774+
addCompletionHandlerAsyncAttrIfNeccessary(Ctx, FD, HandlerDesc, EditConsumer);
65926775

65936776
AsyncConverter LegacyBodyCreator(TheFile, SM, DiagEngine, FD, HandlerDesc);
65946777
if (LegacyBodyCreator.createLegacyBody()) {
@@ -6600,6 +6783,43 @@ bool RefactoringActionAddAsyncAlternative::performChange() {
66006783

66016784
return false;
66026785
}
6786+
6787+
bool RefactoringActionAddAsyncWrapper::isApplicable(
6788+
const ResolvedCursorInfo &CursorInfo, DiagnosticEngine &Diag) {
6789+
using namespace asyncrefactorings;
6790+
6791+
auto *FD = findFunction(CursorInfo);
6792+
if (!FD)
6793+
return false;
6794+
6795+
auto HandlerDesc =
6796+
AsyncHandlerParamDesc::find(FD, /*RequireAttributeOrName=*/false);
6797+
return HandlerDesc.isValid();
6798+
}
6799+
6800+
bool RefactoringActionAddAsyncWrapper::performChange() {
6801+
using namespace asyncrefactorings;
6802+
6803+
auto *FD = findFunction(CursorInfo);
6804+
assert(FD &&
6805+
"Should not run performChange when refactoring is not applicable");
6806+
6807+
auto HandlerDesc =
6808+
AsyncHandlerParamDesc::find(FD, /*RequireAttributeOrName=*/false);
6809+
assert(HandlerDesc.isValid() &&
6810+
"Should not run performChange when refactoring is not applicable");
6811+
6812+
AsyncConverter Converter(TheFile, SM, DiagEngine, FD, HandlerDesc);
6813+
if (!Converter.createAsyncWrapper())
6814+
return true;
6815+
6816+
addCompletionHandlerAsyncAttrIfNeccessary(Ctx, FD, HandlerDesc, EditConsumer);
6817+
6818+
// Add the async wrapper.
6819+
Converter.insertAfter(FD, EditConsumer);
6820+
return false;
6821+
}
6822+
66036823
} // end of anonymous namespace
66046824

66056825
StringRef swift::ide::

0 commit comments

Comments
 (0)