Skip to content

Commit e835b77

Browse files
authored
Merge pull request #37185 from ahoppen/pr/legacy-async-method-refactor
[Refactoring] When adding an async alternative refactor the old method to call the async method using `async`
2 parents 8b59e13 + 98e6680 commit e835b77

File tree

5 files changed

+666
-40
lines changed

5 files changed

+666
-40
lines changed

lib/IDE/Refactoring.cpp

Lines changed: 293 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4149,6 +4149,36 @@ struct AsyncHandlerDesc {
41494149
return params();
41504150
}
41514151

4152+
/// Get the type of the error that will be thrown by the \c async method or \c
4153+
/// None if the completion handler doesn't accept an error parameter.
4154+
/// This may be more specialized than the generic 'Error' type if the
4155+
/// completion handler of the converted function takes a more specialized
4156+
/// error type.
4157+
Optional<swift::Type> getErrorType() const {
4158+
if (HasError) {
4159+
switch (Type) {
4160+
case HandlerType::INVALID:
4161+
return None;
4162+
case HandlerType::PARAMS:
4163+
// The last parameter of the completion handler is the error param
4164+
return params().back().getPlainType()->lookThroughSingleOptionalType();
4165+
case HandlerType::RESULT:
4166+
assert(
4167+
params().size() == 1 &&
4168+
"Result handler should have the Result type as the only parameter");
4169+
auto ResultType =
4170+
params().back().getPlainType()->getAs<BoundGenericType>();
4171+
auto GenericArgs = ResultType->getGenericArgs();
4172+
assert(GenericArgs.size() == 2 && "Result should have two params");
4173+
// The second (last) generic parameter of the Result type is the error
4174+
// type.
4175+
return GenericArgs.back();
4176+
}
4177+
} else {
4178+
return None;
4179+
}
4180+
}
4181+
41524182
/// The `CallExpr` if the given node is a call to the `Handler`
41534183
CallExpr *getAsHandlerCall(ASTNode Node) const {
41544184
if (!isValid())
@@ -5319,6 +5349,262 @@ class AsyncConverter : private SourceEntityWalker {
53195349
}
53205350
}
53215351
};
5352+
5353+
/// When adding an async alternative method for the function declaration \c FD,
5354+
/// this class tries to create a function body for the legacy function (the one
5355+
/// with a completion handler), which calls the newly converted async function.
5356+
/// There are certain situations in which we fail to create such a body, e.g.
5357+
/// if the completion handler has the signature `(String, Error?) -> Void` in
5358+
/// which case we can't synthesize the result of type \c String in the error
5359+
/// case.
5360+
class LegacyAlternativeBodyCreator {
5361+
/// The old function declaration for which an async alternative has been added
5362+
/// and whose body shall be rewritten to call the newly added async
5363+
/// alternative.
5364+
FuncDecl *FD;
5365+
5366+
/// The description of the completion handler in the old function declaration.
5367+
AsyncHandlerDesc HandlerDesc;
5368+
5369+
std::string Buffer;
5370+
llvm::raw_string_ostream OS;
5371+
5372+
/// Adds the call to the refactored 'async' method without the 'await'
5373+
/// keyword to the output stream.
5374+
void addCallToAsyncMethod() {
5375+
OS << FD->getBaseName() << "(";
5376+
bool FirstParam = true;
5377+
for (auto Param : *FD->getParameters()) {
5378+
if (Param == HandlerDesc.Handler) {
5379+
/// We don't need to pass the completion handler to the async method.
5380+
continue;
5381+
}
5382+
if (!FirstParam) {
5383+
OS << ", ";
5384+
} else {
5385+
FirstParam = false;
5386+
}
5387+
if (!Param->getArgumentName().empty()) {
5388+
OS << Param->getArgumentName() << ": ";
5389+
}
5390+
OS << Param->getParameterName();
5391+
}
5392+
OS << ")";
5393+
}
5394+
5395+
/// If the returned error type is more specialized than \c Error, adds an
5396+
/// 'as! CustomError' cast to the more specialized error type to the output
5397+
/// stream.
5398+
void addCastToCustomErrorTypeIfNecessary() {
5399+
auto ErrorType = *HandlerDesc.getErrorType();
5400+
if (ErrorType->getCanonicalType() !=
5401+
FD->getASTContext().getExceptionType()) {
5402+
OS << " as! ";
5403+
ErrorType->lookThroughSingleOptionalType()->print(OS);
5404+
}
5405+
}
5406+
5407+
/// Adds the \c Index -th parameter to the completion handler.
5408+
/// If \p HasResult is \c true, it is assumed that a variable named 'result'
5409+
/// contains the result returned from the async alternative. If the callback
5410+
/// also takes an error parameter, \c nil passed to the completion handler for
5411+
/// the error.
5412+
/// If \p HasResult is \c false, it is a assumed that a variable named 'error'
5413+
/// contains the error thrown from the async method and 'nil' will be passed
5414+
/// to the completion handler for all result parameters.
5415+
void addCompletionHandlerArgument(size_t Index, bool HasResult) {
5416+
if (HandlerDesc.HasError && Index == HandlerDesc.params().size() - 1) {
5417+
// The error parameter is the last argument of the completion handler.
5418+
if (!HasResult) {
5419+
OS << "error";
5420+
addCastToCustomErrorTypeIfNecessary();
5421+
} else {
5422+
OS << "nil";
5423+
}
5424+
} else {
5425+
if (!HasResult) {
5426+
OS << "nil";
5427+
} else if (HandlerDesc
5428+
.getSuccessParamAsyncReturnType(
5429+
HandlerDesc.params()[Index].getPlainType())
5430+
->isVoid()) {
5431+
// Void return types are not returned by the async function, synthesize
5432+
// a Void instance.
5433+
OS << "()";
5434+
} else if (HandlerDesc.getSuccessParams().size() > 1) {
5435+
// If the async method returns a tuple, we need to pass its elements to
5436+
// the completion handler separately. For example:
5437+
//
5438+
// func foo() async -> (String, Int) {}
5439+
//
5440+
// causes the following legacy body to be created:
5441+
//
5442+
// func foo(completion: (String, Int) -> Void) {
5443+
// async {
5444+
// let result = await foo()
5445+
// completion(result.0, result.1)
5446+
// }
5447+
// }
5448+
OS << "result." << Index;
5449+
} else {
5450+
OS << "result";
5451+
}
5452+
}
5453+
}
5454+
5455+
/// Adds the call to the completion handler. See \c
5456+
/// getCompletionHandlerArgument for how the arguments are synthesized if the
5457+
/// completion handler takes arguments, not a \c Result type.
5458+
void addCallToCompletionHandler(bool HasResult) {
5459+
OS << HandlerDesc.Handler->getParameterName() << "(";
5460+
5461+
// Construct arguments to pass to the completion handler
5462+
switch (HandlerDesc.Type) {
5463+
case HandlerType::INVALID:
5464+
llvm_unreachable("Cannot be rewritten");
5465+
break;
5466+
case HandlerType::PARAMS: {
5467+
for (size_t I = 0; I < HandlerDesc.params().size(); ++I) {
5468+
if (I > 0) {
5469+
OS << ", ";
5470+
}
5471+
addCompletionHandlerArgument(I, HasResult);
5472+
}
5473+
break;
5474+
}
5475+
case HandlerType::RESULT: {
5476+
if (HasResult) {
5477+
OS << ".success(result)";
5478+
} else {
5479+
OS << ".failure(error";
5480+
addCastToCustomErrorTypeIfNecessary();
5481+
OS << ")";
5482+
}
5483+
break;
5484+
}
5485+
}
5486+
OS << ")"; // Close the call to the completion handler
5487+
}
5488+
5489+
/// Adds the result type of the converted async function.
5490+
void addAsyncFuncReturnType() {
5491+
SmallVector<Type, 2> Scratch;
5492+
auto ReturnTypes = HandlerDesc.getAsyncReturnTypes(Scratch);
5493+
if (ReturnTypes.size() > 1) {
5494+
OS << "(";
5495+
}
5496+
5497+
llvm::interleave(
5498+
ReturnTypes, [&](Type Ty) { Ty->print(OS); }, [&]() { OS << ", "; });
5499+
5500+
if (ReturnTypes.size() > 1) {
5501+
OS << ")";
5502+
}
5503+
}
5504+
5505+
/// If the async alternative function is generic, adds the type annotation
5506+
/// to the 'return' variable in the legacy function so that the generic
5507+
/// parameters of the legacy function are passed to the generic function.
5508+
/// For example for
5509+
/// \code
5510+
/// func foo<GenericParam>() async -> GenericParam {}
5511+
/// \endcode
5512+
/// we generate
5513+
/// \code
5514+
/// func foo<GenericParam>(completion: (T) -> Void) {
5515+
/// async {
5516+
/// let result: GenericParam = await foo()
5517+
/// <------------>
5518+
/// completion(result)
5519+
/// }
5520+
/// }
5521+
/// \endcode
5522+
/// This function adds the range marked by \c <----->
5523+
void addResultTypeAnnotationIfNecessary() {
5524+
if (FD->isGeneric()) {
5525+
OS << ": ";
5526+
addAsyncFuncReturnType();
5527+
}
5528+
}
5529+
5530+
public:
5531+
LegacyAlternativeBodyCreator(FuncDecl *FD, AsyncHandlerDesc HandlerDesc)
5532+
: FD(FD), HandlerDesc(HandlerDesc), OS(Buffer) {}
5533+
5534+
bool canRewriteLegacyBody() {
5535+
if (FD == nullptr || FD->getBody() == nullptr) {
5536+
return false;
5537+
}
5538+
if (FD->hasThrows()) {
5539+
assert(!HandlerDesc.isValid() && "We shouldn't have found a handler desc "
5540+
"if the original function throws");
5541+
return false;
5542+
}
5543+
switch (HandlerDesc.Type) {
5544+
case HandlerType::INVALID:
5545+
return false;
5546+
case HandlerType::PARAMS: {
5547+
if (HandlerDesc.HasError) {
5548+
// The non-error parameters must be optional so that we can set them to
5549+
// nil in the error case.
5550+
// The error parameter must be optional so we can set it to nil in the
5551+
// success case.
5552+
// Otherwise we can't synthesize the values to return for these
5553+
// parameters.
5554+
return llvm::all_of(HandlerDesc.params(),
5555+
[](AnyFunctionType::Param Param) -> bool {
5556+
return Param.getPlainType()->isOptional();
5557+
});
5558+
} else {
5559+
return true;
5560+
}
5561+
}
5562+
case HandlerType::RESULT:
5563+
return true;
5564+
}
5565+
}
5566+
5567+
std::string create() {
5568+
assert(Buffer.empty() &&
5569+
"LegacyAlternativeBodyCreator can only be used once");
5570+
assert(canRewriteLegacyBody() &&
5571+
"Cannot create a legacy body if the body can't be rewritten");
5572+
OS << "{\n"; // start function body
5573+
OS << "async {\n";
5574+
if (HandlerDesc.HasError) {
5575+
OS << "do {\n";
5576+
if (!HandlerDesc.willAsyncReturnVoid()) {
5577+
OS << "let result";
5578+
addResultTypeAnnotationIfNecessary();
5579+
OS << " = ";
5580+
}
5581+
OS << "try await ";
5582+
addCallToAsyncMethod();
5583+
OS << "\n";
5584+
addCallToCompletionHandler(/*HasResult=*/true);
5585+
OS << "\n"
5586+
<< "} catch {\n";
5587+
addCallToCompletionHandler(/*HasResult=*/false);
5588+
OS << "\n"
5589+
<< "}\n"; // end catch
5590+
} else {
5591+
if (!HandlerDesc.willAsyncReturnVoid()) {
5592+
OS << "let result";
5593+
addResultTypeAnnotationIfNecessary();
5594+
OS << " = ";
5595+
}
5596+
OS << "await ";
5597+
addCallToAsyncMethod();
5598+
OS << "\n";
5599+
addCallToCompletionHandler(/*HasResult=*/true);
5600+
OS << "\n";
5601+
}
5602+
OS << "}\n"; // end 'async'
5603+
OS << "}\n"; // end function body
5604+
return Buffer;
5605+
}
5606+
};
5607+
53225608
} // namespace asyncrefactorings
53235609

53245610
bool RefactoringActionConvertCallToAsyncAlternative::isApplicable(
@@ -5425,6 +5711,13 @@ bool RefactoringActionAddAsyncAlternative::performChange() {
54255711
EditConsumer.accept(SM, FD->getAttributeInsertionLoc(false),
54265712
"@available(*, deprecated, message: \"Prefer async "
54275713
"alternative instead\")\n");
5714+
LegacyAlternativeBodyCreator LegacyBody(FD, HandlerDesc);
5715+
if (LegacyBody.canRewriteLegacyBody()) {
5716+
EditConsumer.accept(SM,
5717+
Lexer::getCharSourceRangeFromSourceRange(
5718+
SM, FD->getBody()->getSourceRange()),
5719+
LegacyBody.create());
5720+
}
54285721
Converter.insertAfter(FD, EditConsumer);
54295722

54305723
return false;

test/lit.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,7 @@ config.scale_test = make_path(config.swift_utils, 'scale-test')
334334
config.PathSanitizingFileCheck = make_path(config.swift_utils, 'PathSanitizingFileCheck')
335335
config.swift_lib_dir = make_path(config.swift, '..', '..', 'lib')
336336
config.round_trip_syntax_test = make_path(config.swift_utils, 'round-trip-syntax-test')
337+
config.refactor_check_compiles = make_path(config.swift_utils, 'refactor-check-compiles.py')
337338

338339
config.link = lit.util.which('link', config.environment.get('PATH', '')) or \
339340
lit.util.which('lld-link', config.environment.get('PATH', ''))

0 commit comments

Comments
 (0)