Skip to content

Commit 2878ab4

Browse files
committed
[Refactoring] When adding an async alternative refactor the old method to call the async method using detach
Instead of leaving two copies of the same implementation, rewrite the old method with the completion handler to call the newly added `async` method. Resolves rdar://74464833
1 parent 345c3b0 commit 2878ab4

File tree

3 files changed

+565
-41
lines changed

3 files changed

+565
-41
lines changed

lib/IDE/Refactoring.cpp

Lines changed: 293 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4158,6 +4158,36 @@ struct AsyncHandlerDesc {
41584158
return params();
41594159
}
41604160

4161+
/// Get the type of the error that will be thrown by the \c async method or \c
4162+
/// None if the completion handler doesn't accept an error parameter.
4163+
/// This may be more specialized than the generic 'Error' type if the
4164+
/// completion handler of the converted function takes a more specialized
4165+
/// error type.
4166+
Optional<swift::Type> getErrorType() const {
4167+
if (HasError) {
4168+
switch (Type) {
4169+
case HandlerType::INVALID:
4170+
return None;
4171+
case HandlerType::PARAMS:
4172+
// The last parameter of the completion handler is the error param
4173+
return params().back().getPlainType()->lookThroughSingleOptionalType();
4174+
case HandlerType::RESULT:
4175+
assert(
4176+
params().size() == 1 &&
4177+
"Result handler should have the Result type as the only parameter");
4178+
auto ResultType =
4179+
params().back().getPlainType()->getAs<BoundGenericType>();
4180+
auto GenericArgs = ResultType->getGenericArgs();
4181+
assert(GenericArgs.size() == 2 && "Result should have two params");
4182+
// The second (last) generic parameter of the Result type is the error
4183+
// type.
4184+
return GenericArgs.back();
4185+
}
4186+
} else {
4187+
return None;
4188+
}
4189+
}
4190+
41614191
/// The `CallExpr` if the given node is a call to the `Handler`
41624192
CallExpr *getAsHandlerCall(ASTNode Node) const {
41634193
if (!isValid())
@@ -5318,6 +5348,262 @@ class AsyncConverter : private SourceEntityWalker {
53185348
}
53195349
}
53205350
};
5351+
5352+
/// When adding an async alternative method for the function declaration \c FD,
5353+
/// this class tries to create a function body for the legacy function (the one
5354+
/// with a completion handler), which calls the newly converted async function.
5355+
/// There are certain situations in which we fail to create such a body, e.g.
5356+
/// if the completion handler has the signature `(String, Error?) -> Void` in
5357+
/// which case we can't synthesize the result of type \c String in the error
5358+
/// case.
5359+
class LegacyAlternativeBodyCreator {
5360+
/// The old function declaration for which an async alternative has been added
5361+
/// and whose body shall be rewritten to call the newly added async
5362+
/// alternative.
5363+
FuncDecl *FD;
5364+
5365+
/// The description of the completion handler in the old function declaration.
5366+
AsyncHandlerDesc HandlerDesc;
5367+
5368+
std::string Buffer;
5369+
llvm::raw_string_ostream OS;
5370+
5371+
/// Adds the call to the refactored 'async' method without the 'await'
5372+
/// keyword to the output stream.
5373+
void addCallToAsyncMethod() {
5374+
OS << FD->getBaseName() << "(";
5375+
bool FirstParam = true;
5376+
for (auto Param : *FD->getParameters()) {
5377+
if (Param == HandlerDesc.Handler) {
5378+
/// We don't need to pass the completion handler to the async method.
5379+
continue;
5380+
}
5381+
if (!FirstParam) {
5382+
OS << ", ";
5383+
} else {
5384+
FirstParam = false;
5385+
}
5386+
if (!Param->getArgumentName().empty()) {
5387+
OS << Param->getArgumentName() << ": ";
5388+
}
5389+
OS << Param->getParameterName();
5390+
}
5391+
OS << ")";
5392+
}
5393+
5394+
/// If the returned error type is more specialized than \c Error, adds an
5395+
/// 'as! CustomError' cast to the more specialized error type to the output
5396+
/// stream.
5397+
void addCastToCustomErrorTypeIfNecessary() {
5398+
auto ErrorType = *HandlerDesc.getErrorType();
5399+
if (ErrorType->getCanonicalType() !=
5400+
FD->getASTContext().getExceptionType()) {
5401+
OS << " as! ";
5402+
ErrorType->lookThroughSingleOptionalType()->print(OS);
5403+
}
5404+
}
5405+
5406+
/// Adds the \c Index -th parameter to the completion handler.
5407+
/// If \p HasResult is \c true, it is assumed that a variable named 'result'
5408+
/// contains the result returned from the async alternative. If the callback
5409+
/// also takes an error parameter, \c nil passed to the completion handler for
5410+
/// the error.
5411+
/// If \p HasResult is \c false, it is a assumed that a variable named 'error'
5412+
/// contains the error thrown from the async method and 'nil' will be passed
5413+
/// to the completion handler for all result parameters.
5414+
void addCompletionHandlerArgument(size_t Index, bool HasResult) {
5415+
if (HandlerDesc.HasError && Index == HandlerDesc.params().size() - 1) {
5416+
// The error parameter is the last argument of the completion handler.
5417+
if (!HasResult) {
5418+
OS << "error";
5419+
addCastToCustomErrorTypeIfNecessary();
5420+
} else {
5421+
OS << "nil";
5422+
}
5423+
} else {
5424+
if (!HasResult) {
5425+
OS << "nil";
5426+
} else if (HandlerDesc
5427+
.getSuccessParamAsyncReturnType(
5428+
HandlerDesc.params()[Index].getPlainType())
5429+
->isVoid()) {
5430+
// Void return types are not returned by the async function, synthesize
5431+
// a Void instance.
5432+
OS << "()";
5433+
} else if (HandlerDesc.getSuccessParams().size() > 1) {
5434+
// If the async method returns a tuple, we need to pass its elements to
5435+
// the completion handler separately. For example:
5436+
//
5437+
// func foo() async -> (String, Int) {}
5438+
//
5439+
// causes the following legacy body to be created:
5440+
//
5441+
// func foo(completion: (String, Int) -> Void) {
5442+
// async {
5443+
// let result = await foo()
5444+
// completion(result.0, result.1)
5445+
// }
5446+
// }
5447+
OS << "result." << Index;
5448+
} else {
5449+
OS << "result";
5450+
}
5451+
}
5452+
}
5453+
5454+
/// Adds the call to the completion handler. See \c
5455+
/// getCompletionHandlerArgument for how the arguments are synthesized if the
5456+
/// completion handler takes arguments, not a \c Result type.
5457+
void addCallToCompletionHandler(bool HasResult) {
5458+
OS << HandlerDesc.Handler->getParameterName() << "(";
5459+
5460+
// Construct arguments to pass to the completion handler
5461+
switch (HandlerDesc.Type) {
5462+
case HandlerType::INVALID:
5463+
llvm_unreachable("Cannot be rewritten");
5464+
break;
5465+
case HandlerType::PARAMS: {
5466+
for (size_t I = 0; I < HandlerDesc.params().size(); ++I) {
5467+
if (I > 0) {
5468+
OS << ", ";
5469+
}
5470+
addCompletionHandlerArgument(I, HasResult);
5471+
}
5472+
break;
5473+
}
5474+
case HandlerType::RESULT: {
5475+
if (HasResult) {
5476+
OS << ".success(result)";
5477+
} else {
5478+
OS << ".failure(error";
5479+
addCastToCustomErrorTypeIfNecessary();
5480+
OS << ")";
5481+
}
5482+
break;
5483+
}
5484+
}
5485+
OS << ")"; // Close the call to the completion handler
5486+
}
5487+
5488+
/// Adds the result type of the converted async function.
5489+
void addAsyncFuncReturnType() {
5490+
SmallVector<Type, 2> Scratch;
5491+
auto ReturnTypes = HandlerDesc.getAsyncReturnTypes(Scratch);
5492+
if (ReturnTypes.size() > 1) {
5493+
OS << "(";
5494+
}
5495+
5496+
llvm::interleave(
5497+
ReturnTypes, [&](Type Ty) { Ty->print(OS); }, [&]() { OS << ", "; });
5498+
5499+
if (ReturnTypes.size() > 1) {
5500+
OS << ")";
5501+
}
5502+
}
5503+
5504+
/// If the async alternative function is generic, adds the type annotation
5505+
/// to the 'return' variable in the legacy function so that the generic
5506+
/// parameters of the legacy function are passed to the generic function.
5507+
/// For example for
5508+
/// \code
5509+
/// func foo<GenericParam>() async -> GenericParam {}
5510+
/// \endcode
5511+
/// we generate
5512+
/// \code
5513+
/// func foo<GenericParam>(completion: (T) -> Void) {
5514+
/// async {
5515+
/// let result: GenericParam = await foo()
5516+
/// <------------>
5517+
/// completion(result)
5518+
/// }
5519+
/// }
5520+
/// \endcode
5521+
/// This function adds the range marked by \c <----->
5522+
void addResultTypeAnnotationIfNecessary() {
5523+
if (FD->isGeneric()) {
5524+
OS << ": ";
5525+
addAsyncFuncReturnType();
5526+
}
5527+
}
5528+
5529+
public:
5530+
LegacyAlternativeBodyCreator(FuncDecl *FD, AsyncHandlerDesc HandlerDesc)
5531+
: FD(FD), HandlerDesc(HandlerDesc), OS(Buffer) {}
5532+
5533+
bool canRewriteLegacyBody() {
5534+
if (FD == nullptr || FD->getBody() == nullptr) {
5535+
return false;
5536+
}
5537+
if (FD->hasThrows()) {
5538+
assert(!HandlerDesc.isValid() && "We shouldn't have found a handler desc "
5539+
"if the original function throws");
5540+
return false;
5541+
}
5542+
switch (HandlerDesc.Type) {
5543+
case HandlerType::INVALID:
5544+
return false;
5545+
case HandlerType::PARAMS: {
5546+
if (HandlerDesc.HasError) {
5547+
// The non-error parameters must be optional so that we can set them to
5548+
// nil in the error case.
5549+
// The error parameter must be optional so we can set it to nil in the
5550+
// success case.
5551+
// Otherwise we can't synthesize the values to return for these
5552+
// parameters.
5553+
return llvm::all_of(HandlerDesc.params(),
5554+
[](AnyFunctionType::Param Param) -> bool {
5555+
return Param.getPlainType()->isOptional();
5556+
});
5557+
} else {
5558+
return true;
5559+
}
5560+
}
5561+
case HandlerType::RESULT:
5562+
return true;
5563+
}
5564+
}
5565+
5566+
std::string create() {
5567+
assert(Buffer.empty() &&
5568+
"LegacyAlternativeBodyCreator can only be used once");
5569+
assert(canRewriteLegacyBody() &&
5570+
"Cannot create a legacy body if the body can't be rewritten");
5571+
OS << "{\n"; // start function body
5572+
OS << "async {\n";
5573+
if (HandlerDesc.HasError) {
5574+
OS << "do {\n";
5575+
if (!HandlerDesc.willAsyncReturnVoid()) {
5576+
OS << "let result";
5577+
addResultTypeAnnotationIfNecessary();
5578+
OS << " = ";
5579+
}
5580+
OS << "try await ";
5581+
addCallToAsyncMethod();
5582+
OS << "\n";
5583+
addCallToCompletionHandler(/*HasResult=*/true);
5584+
OS << "\n"
5585+
<< "} catch {\n";
5586+
addCallToCompletionHandler(/*HasResult=*/false);
5587+
OS << "\n"
5588+
<< "}\n"; // end catch
5589+
} else {
5590+
if (!HandlerDesc.willAsyncReturnVoid()) {
5591+
OS << "let result";
5592+
addResultTypeAnnotationIfNecessary();
5593+
OS << " = ";
5594+
}
5595+
OS << "await ";
5596+
addCallToAsyncMethod();
5597+
OS << "\n";
5598+
addCallToCompletionHandler(/*HasResult=*/true);
5599+
OS << "\n";
5600+
}
5601+
OS << "}\n"; // end 'async'
5602+
OS << "}\n"; // end function body
5603+
return Buffer;
5604+
}
5605+
};
5606+
53215607
} // namespace asyncrefactorings
53225608

53235609
bool RefactoringActionConvertCallToAsyncAlternative::isApplicable(
@@ -5424,6 +5710,13 @@ bool RefactoringActionAddAsyncAlternative::performChange() {
54245710
EditConsumer.accept(SM, FD->getAttributeInsertionLoc(false),
54255711
"@available(*, deprecated, message: \"Prefer async "
54265712
"alternative instead\")\n");
5713+
LegacyAlternativeBodyCreator LegacyBody(FD, HandlerDesc);
5714+
if (LegacyBody.canRewriteLegacyBody()) {
5715+
EditConsumer.accept(SM,
5716+
Lexer::getCharSourceRangeFromSourceRange(
5717+
SM, FD->getBody()->getSourceRange()),
5718+
LegacyBody.create());
5719+
}
54275720
Converter.insertAfter(FD, EditConsumer);
54285721

54295722
return false;

0 commit comments

Comments
 (0)