Skip to content

Commit 98e6680

Browse files
committed
[Refactoring] When adding an async alternative refactor the old method to call the async method using detach
Instead of leaving two copies of the same implementation, rewrite the old method with the completion handler to call the newly added `async` method. Resolves rdar://74464833
1 parent 9916a6d commit 98e6680

File tree

3 files changed

+565
-41
lines changed

3 files changed

+565
-41
lines changed

lib/IDE/Refactoring.cpp

Lines changed: 293 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4149,6 +4149,36 @@ struct AsyncHandlerDesc {
41494149
return params();
41504150
}
41514151

4152+
/// Get the type of the error that will be thrown by the \c async method or \c
4153+
/// None if the completion handler doesn't accept an error parameter.
4154+
/// This may be more specialized than the generic 'Error' type if the
4155+
/// completion handler of the converted function takes a more specialized
4156+
/// error type.
4157+
Optional<swift::Type> getErrorType() const {
4158+
if (HasError) {
4159+
switch (Type) {
4160+
case HandlerType::INVALID:
4161+
return None;
4162+
case HandlerType::PARAMS:
4163+
// The last parameter of the completion handler is the error param
4164+
return params().back().getPlainType()->lookThroughSingleOptionalType();
4165+
case HandlerType::RESULT:
4166+
assert(
4167+
params().size() == 1 &&
4168+
"Result handler should have the Result type as the only parameter");
4169+
auto ResultType =
4170+
params().back().getPlainType()->getAs<BoundGenericType>();
4171+
auto GenericArgs = ResultType->getGenericArgs();
4172+
assert(GenericArgs.size() == 2 && "Result should have two params");
4173+
// The second (last) generic parameter of the Result type is the error
4174+
// type.
4175+
return GenericArgs.back();
4176+
}
4177+
} else {
4178+
return None;
4179+
}
4180+
}
4181+
41524182
/// The `CallExpr` if the given node is a call to the `Handler`
41534183
CallExpr *getAsHandlerCall(ASTNode Node) const {
41544184
if (!isValid())
@@ -5310,6 +5340,262 @@ class AsyncConverter : private SourceEntityWalker {
53105340
}
53115341
}
53125342
};
5343+
5344+
/// When adding an async alternative method for the function declaration \c FD,
5345+
/// this class tries to create a function body for the legacy function (the one
5346+
/// with a completion handler), which calls the newly converted async function.
5347+
/// There are certain situations in which we fail to create such a body, e.g.
5348+
/// if the completion handler has the signature `(String, Error?) -> Void` in
5349+
/// which case we can't synthesize the result of type \c String in the error
5350+
/// case.
5351+
class LegacyAlternativeBodyCreator {
5352+
/// The old function declaration for which an async alternative has been added
5353+
/// and whose body shall be rewritten to call the newly added async
5354+
/// alternative.
5355+
FuncDecl *FD;
5356+
5357+
/// The description of the completion handler in the old function declaration.
5358+
AsyncHandlerDesc HandlerDesc;
5359+
5360+
std::string Buffer;
5361+
llvm::raw_string_ostream OS;
5362+
5363+
/// Adds the call to the refactored 'async' method without the 'await'
5364+
/// keyword to the output stream.
5365+
void addCallToAsyncMethod() {
5366+
OS << FD->getBaseName() << "(";
5367+
bool FirstParam = true;
5368+
for (auto Param : *FD->getParameters()) {
5369+
if (Param == HandlerDesc.Handler) {
5370+
/// We don't need to pass the completion handler to the async method.
5371+
continue;
5372+
}
5373+
if (!FirstParam) {
5374+
OS << ", ";
5375+
} else {
5376+
FirstParam = false;
5377+
}
5378+
if (!Param->getArgumentName().empty()) {
5379+
OS << Param->getArgumentName() << ": ";
5380+
}
5381+
OS << Param->getParameterName();
5382+
}
5383+
OS << ")";
5384+
}
5385+
5386+
/// If the returned error type is more specialized than \c Error, adds an
5387+
/// 'as! CustomError' cast to the more specialized error type to the output
5388+
/// stream.
5389+
void addCastToCustomErrorTypeIfNecessary() {
5390+
auto ErrorType = *HandlerDesc.getErrorType();
5391+
if (ErrorType->getCanonicalType() !=
5392+
FD->getASTContext().getExceptionType()) {
5393+
OS << " as! ";
5394+
ErrorType->lookThroughSingleOptionalType()->print(OS);
5395+
}
5396+
}
5397+
5398+
/// Adds the \c Index -th parameter to the completion handler.
5399+
/// If \p HasResult is \c true, it is assumed that a variable named 'result'
5400+
/// contains the result returned from the async alternative. If the callback
5401+
/// also takes an error parameter, \c nil passed to the completion handler for
5402+
/// the error.
5403+
/// If \p HasResult is \c false, it is a assumed that a variable named 'error'
5404+
/// contains the error thrown from the async method and 'nil' will be passed
5405+
/// to the completion handler for all result parameters.
5406+
void addCompletionHandlerArgument(size_t Index, bool HasResult) {
5407+
if (HandlerDesc.HasError && Index == HandlerDesc.params().size() - 1) {
5408+
// The error parameter is the last argument of the completion handler.
5409+
if (!HasResult) {
5410+
OS << "error";
5411+
addCastToCustomErrorTypeIfNecessary();
5412+
} else {
5413+
OS << "nil";
5414+
}
5415+
} else {
5416+
if (!HasResult) {
5417+
OS << "nil";
5418+
} else if (HandlerDesc
5419+
.getSuccessParamAsyncReturnType(
5420+
HandlerDesc.params()[Index].getPlainType())
5421+
->isVoid()) {
5422+
// Void return types are not returned by the async function, synthesize
5423+
// a Void instance.
5424+
OS << "()";
5425+
} else if (HandlerDesc.getSuccessParams().size() > 1) {
5426+
// If the async method returns a tuple, we need to pass its elements to
5427+
// the completion handler separately. For example:
5428+
//
5429+
// func foo() async -> (String, Int) {}
5430+
//
5431+
// causes the following legacy body to be created:
5432+
//
5433+
// func foo(completion: (String, Int) -> Void) {
5434+
// async {
5435+
// let result = await foo()
5436+
// completion(result.0, result.1)
5437+
// }
5438+
// }
5439+
OS << "result." << Index;
5440+
} else {
5441+
OS << "result";
5442+
}
5443+
}
5444+
}
5445+
5446+
/// Adds the call to the completion handler. See \c
5447+
/// getCompletionHandlerArgument for how the arguments are synthesized if the
5448+
/// completion handler takes arguments, not a \c Result type.
5449+
void addCallToCompletionHandler(bool HasResult) {
5450+
OS << HandlerDesc.Handler->getParameterName() << "(";
5451+
5452+
// Construct arguments to pass to the completion handler
5453+
switch (HandlerDesc.Type) {
5454+
case HandlerType::INVALID:
5455+
llvm_unreachable("Cannot be rewritten");
5456+
break;
5457+
case HandlerType::PARAMS: {
5458+
for (size_t I = 0; I < HandlerDesc.params().size(); ++I) {
5459+
if (I > 0) {
5460+
OS << ", ";
5461+
}
5462+
addCompletionHandlerArgument(I, HasResult);
5463+
}
5464+
break;
5465+
}
5466+
case HandlerType::RESULT: {
5467+
if (HasResult) {
5468+
OS << ".success(result)";
5469+
} else {
5470+
OS << ".failure(error";
5471+
addCastToCustomErrorTypeIfNecessary();
5472+
OS << ")";
5473+
}
5474+
break;
5475+
}
5476+
}
5477+
OS << ")"; // Close the call to the completion handler
5478+
}
5479+
5480+
/// Adds the result type of the converted async function.
5481+
void addAsyncFuncReturnType() {
5482+
SmallVector<Type, 2> Scratch;
5483+
auto ReturnTypes = HandlerDesc.getAsyncReturnTypes(Scratch);
5484+
if (ReturnTypes.size() > 1) {
5485+
OS << "(";
5486+
}
5487+
5488+
llvm::interleave(
5489+
ReturnTypes, [&](Type Ty) { Ty->print(OS); }, [&]() { OS << ", "; });
5490+
5491+
if (ReturnTypes.size() > 1) {
5492+
OS << ")";
5493+
}
5494+
}
5495+
5496+
/// If the async alternative function is generic, adds the type annotation
5497+
/// to the 'return' variable in the legacy function so that the generic
5498+
/// parameters of the legacy function are passed to the generic function.
5499+
/// For example for
5500+
/// \code
5501+
/// func foo<GenericParam>() async -> GenericParam {}
5502+
/// \endcode
5503+
/// we generate
5504+
/// \code
5505+
/// func foo<GenericParam>(completion: (T) -> Void) {
5506+
/// async {
5507+
/// let result: GenericParam = await foo()
5508+
/// <------------>
5509+
/// completion(result)
5510+
/// }
5511+
/// }
5512+
/// \endcode
5513+
/// This function adds the range marked by \c <----->
5514+
void addResultTypeAnnotationIfNecessary() {
5515+
if (FD->isGeneric()) {
5516+
OS << ": ";
5517+
addAsyncFuncReturnType();
5518+
}
5519+
}
5520+
5521+
public:
5522+
LegacyAlternativeBodyCreator(FuncDecl *FD, AsyncHandlerDesc HandlerDesc)
5523+
: FD(FD), HandlerDesc(HandlerDesc), OS(Buffer) {}
5524+
5525+
bool canRewriteLegacyBody() {
5526+
if (FD == nullptr || FD->getBody() == nullptr) {
5527+
return false;
5528+
}
5529+
if (FD->hasThrows()) {
5530+
assert(!HandlerDesc.isValid() && "We shouldn't have found a handler desc "
5531+
"if the original function throws");
5532+
return false;
5533+
}
5534+
switch (HandlerDesc.Type) {
5535+
case HandlerType::INVALID:
5536+
return false;
5537+
case HandlerType::PARAMS: {
5538+
if (HandlerDesc.HasError) {
5539+
// The non-error parameters must be optional so that we can set them to
5540+
// nil in the error case.
5541+
// The error parameter must be optional so we can set it to nil in the
5542+
// success case.
5543+
// Otherwise we can't synthesize the values to return for these
5544+
// parameters.
5545+
return llvm::all_of(HandlerDesc.params(),
5546+
[](AnyFunctionType::Param Param) -> bool {
5547+
return Param.getPlainType()->isOptional();
5548+
});
5549+
} else {
5550+
return true;
5551+
}
5552+
}
5553+
case HandlerType::RESULT:
5554+
return true;
5555+
}
5556+
}
5557+
5558+
std::string create() {
5559+
assert(Buffer.empty() &&
5560+
"LegacyAlternativeBodyCreator can only be used once");
5561+
assert(canRewriteLegacyBody() &&
5562+
"Cannot create a legacy body if the body can't be rewritten");
5563+
OS << "{\n"; // start function body
5564+
OS << "async {\n";
5565+
if (HandlerDesc.HasError) {
5566+
OS << "do {\n";
5567+
if (!HandlerDesc.willAsyncReturnVoid()) {
5568+
OS << "let result";
5569+
addResultTypeAnnotationIfNecessary();
5570+
OS << " = ";
5571+
}
5572+
OS << "try await ";
5573+
addCallToAsyncMethod();
5574+
OS << "\n";
5575+
addCallToCompletionHandler(/*HasResult=*/true);
5576+
OS << "\n"
5577+
<< "} catch {\n";
5578+
addCallToCompletionHandler(/*HasResult=*/false);
5579+
OS << "\n"
5580+
<< "}\n"; // end catch
5581+
} else {
5582+
if (!HandlerDesc.willAsyncReturnVoid()) {
5583+
OS << "let result";
5584+
addResultTypeAnnotationIfNecessary();
5585+
OS << " = ";
5586+
}
5587+
OS << "await ";
5588+
addCallToAsyncMethod();
5589+
OS << "\n";
5590+
addCallToCompletionHandler(/*HasResult=*/true);
5591+
OS << "\n";
5592+
}
5593+
OS << "}\n"; // end 'async'
5594+
OS << "}\n"; // end function body
5595+
return Buffer;
5596+
}
5597+
};
5598+
53135599
} // namespace asyncrefactorings
53145600

53155601
bool RefactoringActionConvertCallToAsyncAlternative::isApplicable(
@@ -5416,6 +5702,13 @@ bool RefactoringActionAddAsyncAlternative::performChange() {
54165702
EditConsumer.accept(SM, FD->getAttributeInsertionLoc(false),
54175703
"@available(*, deprecated, message: \"Prefer async "
54185704
"alternative instead\")\n");
5705+
LegacyAlternativeBodyCreator LegacyBody(FD, HandlerDesc);
5706+
if (LegacyBody.canRewriteLegacyBody()) {
5707+
EditConsumer.accept(SM,
5708+
Lexer::getCharSourceRangeFromSourceRange(
5709+
SM, FD->getBody()->getSourceRange()),
5710+
LegacyBody.create());
5711+
}
54195712
Converter.insertAfter(FD, EditConsumer);
54205713

54215714
return false;

0 commit comments

Comments
 (0)