Skip to content

Commit 3a2015f

Browse files
authored
Merge pull request #37203 from ahoppen/pr-5.5/legacy-async-method-refactor
[5.5][Refactoring] When adding an async alternative refactor the old method to call the async method using detach
2 parents aac7f1d + fc23339 commit 3a2015f

File tree

7 files changed

+679
-40
lines changed

7 files changed

+679
-40
lines changed

include/swift/AST/Types.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -795,6 +795,9 @@ class alignas(1 << TypeAlignInBits) TypeBase {
795795
/// Check if this type is equal to Swift.Bool.
796796
bool isBool();
797797

798+
/// Check if this type is equal to Swift.Optional.
799+
bool isOptional();
800+
798801
/// Check if this type is equal to Builtin.IntN.
799802
bool isBuiltinIntegerType(unsigned bitWidth);
800803

lib/AST/Type.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -628,6 +628,16 @@ bool TypeBase::isBool() {
628628
return false;
629629
}
630630

631+
/// Check if this type is equal to Swift.Bool.
632+
bool TypeBase::isOptional() {
633+
if (auto generic = getAnyGeneric()) {
634+
if (isa<EnumDecl>(generic)) {
635+
return getASTContext().getOptionalDecl() == generic;
636+
}
637+
}
638+
return false;
639+
}
640+
631641
Type TypeBase::getRValueType() {
632642
// If the type is not an lvalue, this is a no-op.
633643
if (!hasLValueType())

lib/IDE/Refactoring.cpp

Lines changed: 293 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4158,6 +4158,36 @@ struct AsyncHandlerDesc {
41584158
return params();
41594159
}
41604160

4161+
/// Get the type of the error that will be thrown by the \c async method or \c
4162+
/// None if the completion handler doesn't accept an error parameter.
4163+
/// This may be more specialized than the generic 'Error' type if the
4164+
/// completion handler of the converted function takes a more specialized
4165+
/// error type.
4166+
Optional<swift::Type> getErrorType() const {
4167+
if (HasError) {
4168+
switch (Type) {
4169+
case HandlerType::INVALID:
4170+
return None;
4171+
case HandlerType::PARAMS:
4172+
// The last parameter of the completion handler is the error param
4173+
return params().back().getPlainType()->lookThroughSingleOptionalType();
4174+
case HandlerType::RESULT:
4175+
assert(
4176+
params().size() == 1 &&
4177+
"Result handler should have the Result type as the only parameter");
4178+
auto ResultType =
4179+
params().back().getPlainType()->getAs<BoundGenericType>();
4180+
auto GenericArgs = ResultType->getGenericArgs();
4181+
assert(GenericArgs.size() == 2 && "Result should have two params");
4182+
// The second (last) generic parameter of the Result type is the error
4183+
// type.
4184+
return GenericArgs.back();
4185+
}
4186+
} else {
4187+
return None;
4188+
}
4189+
}
4190+
41614191
/// The `CallExpr` if the given node is a call to the `Handler`
41624192
CallExpr *getAsHandlerCall(ASTNode Node) const {
41634193
if (!isValid())
@@ -5370,6 +5400,262 @@ class AsyncConverter : private SourceEntityWalker {
53705400
}
53715401
}
53725402
};
5403+
5404+
/// When adding an async alternative method for the function declaration \c FD,
5405+
/// this class tries to create a function body for the legacy function (the one
5406+
/// with a completion handler), which calls the newly converted async function.
5407+
/// There are certain situations in which we fail to create such a body, e.g.
5408+
/// if the completion handler has the signature `(String, Error?) -> Void` in
5409+
/// which case we can't synthesize the result of type \c String in the error
5410+
/// case.
5411+
class LegacyAlternativeBodyCreator {
5412+
/// The old function declaration for which an async alternative has been added
5413+
/// and whose body shall be rewritten to call the newly added async
5414+
/// alternative.
5415+
FuncDecl *FD;
5416+
5417+
/// The description of the completion handler in the old function declaration.
5418+
AsyncHandlerDesc HandlerDesc;
5419+
5420+
std::string Buffer;
5421+
llvm::raw_string_ostream OS;
5422+
5423+
/// Adds the call to the refactored 'async' method without the 'await'
5424+
/// keyword to the output stream.
5425+
void addCallToAsyncMethod() {
5426+
OS << FD->getBaseName() << "(";
5427+
bool FirstParam = true;
5428+
for (auto Param : *FD->getParameters()) {
5429+
if (Param == HandlerDesc.Handler) {
5430+
/// We don't need to pass the completion handler to the async method.
5431+
continue;
5432+
}
5433+
if (!FirstParam) {
5434+
OS << ", ";
5435+
} else {
5436+
FirstParam = false;
5437+
}
5438+
if (!Param->getArgumentName().empty()) {
5439+
OS << Param->getArgumentName() << ": ";
5440+
}
5441+
OS << Param->getParameterName();
5442+
}
5443+
OS << ")";
5444+
}
5445+
5446+
/// If the returned error type is more specialized than \c Error, adds an
5447+
/// 'as! CustomError' cast to the more specialized error type to the output
5448+
/// stream.
5449+
void addCastToCustomErrorTypeIfNecessary() {
5450+
auto ErrorType = *HandlerDesc.getErrorType();
5451+
if (ErrorType->getCanonicalType() !=
5452+
FD->getASTContext().getExceptionType()) {
5453+
OS << " as! ";
5454+
ErrorType->lookThroughSingleOptionalType()->print(OS);
5455+
}
5456+
}
5457+
5458+
/// Adds the \c Index -th parameter to the completion handler.
5459+
/// If \p HasResult is \c true, it is assumed that a variable named 'result'
5460+
/// contains the result returned from the async alternative. If the callback
5461+
/// also takes an error parameter, \c nil passed to the completion handler for
5462+
/// the error.
5463+
/// If \p HasResult is \c false, it is a assumed that a variable named 'error'
5464+
/// contains the error thrown from the async method and 'nil' will be passed
5465+
/// to the completion handler for all result parameters.
5466+
void addCompletionHandlerArgument(size_t Index, bool HasResult) {
5467+
if (HandlerDesc.HasError && Index == HandlerDesc.params().size() - 1) {
5468+
// The error parameter is the last argument of the completion handler.
5469+
if (!HasResult) {
5470+
OS << "error";
5471+
addCastToCustomErrorTypeIfNecessary();
5472+
} else {
5473+
OS << "nil";
5474+
}
5475+
} else {
5476+
if (!HasResult) {
5477+
OS << "nil";
5478+
} else if (HandlerDesc
5479+
.getSuccessParamAsyncReturnType(
5480+
HandlerDesc.params()[Index].getPlainType())
5481+
->isVoid()) {
5482+
// Void return types are not returned by the async function, synthesize
5483+
// a Void instance.
5484+
OS << "()";
5485+
} else if (HandlerDesc.getSuccessParams().size() > 1) {
5486+
// If the async method returns a tuple, we need to pass its elements to
5487+
// the completion handler separately. For example:
5488+
//
5489+
// func foo() async -> (String, Int) {}
5490+
//
5491+
// causes the following legacy body to be created:
5492+
//
5493+
// func foo(completion: (String, Int) -> Void) {
5494+
// async {
5495+
// let result = await foo()
5496+
// completion(result.0, result.1)
5497+
// }
5498+
// }
5499+
OS << "result." << Index;
5500+
} else {
5501+
OS << "result";
5502+
}
5503+
}
5504+
}
5505+
5506+
/// Adds the call to the completion handler. See \c
5507+
/// getCompletionHandlerArgument for how the arguments are synthesized if the
5508+
/// completion handler takes arguments, not a \c Result type.
5509+
void addCallToCompletionHandler(bool HasResult) {
5510+
OS << HandlerDesc.Handler->getParameterName() << "(";
5511+
5512+
// Construct arguments to pass to the completion handler
5513+
switch (HandlerDesc.Type) {
5514+
case HandlerType::INVALID:
5515+
llvm_unreachable("Cannot be rewritten");
5516+
break;
5517+
case HandlerType::PARAMS: {
5518+
for (size_t I = 0; I < HandlerDesc.params().size(); ++I) {
5519+
if (I > 0) {
5520+
OS << ", ";
5521+
}
5522+
addCompletionHandlerArgument(I, HasResult);
5523+
}
5524+
break;
5525+
}
5526+
case HandlerType::RESULT: {
5527+
if (HasResult) {
5528+
OS << ".success(result)";
5529+
} else {
5530+
OS << ".failure(error";
5531+
addCastToCustomErrorTypeIfNecessary();
5532+
OS << ")";
5533+
}
5534+
break;
5535+
}
5536+
}
5537+
OS << ")"; // Close the call to the completion handler
5538+
}
5539+
5540+
/// Adds the result type of the converted async function.
5541+
void addAsyncFuncReturnType() {
5542+
SmallVector<Type, 2> Scratch;
5543+
auto ReturnTypes = HandlerDesc.getAsyncReturnTypes(Scratch);
5544+
if (ReturnTypes.size() > 1) {
5545+
OS << "(";
5546+
}
5547+
5548+
llvm::interleave(
5549+
ReturnTypes, [&](Type Ty) { Ty->print(OS); }, [&]() { OS << ", "; });
5550+
5551+
if (ReturnTypes.size() > 1) {
5552+
OS << ")";
5553+
}
5554+
}
5555+
5556+
/// If the async alternative function is generic, adds the type annotation
5557+
/// to the 'return' variable in the legacy function so that the generic
5558+
/// parameters of the legacy function are passed to the generic function.
5559+
/// For example for
5560+
/// \code
5561+
/// func foo<GenericParam>() async -> GenericParam {}
5562+
/// \endcode
5563+
/// we generate
5564+
/// \code
5565+
/// func foo<GenericParam>(completion: (T) -> Void) {
5566+
/// async {
5567+
/// let result: GenericParam = await foo()
5568+
/// <------------>
5569+
/// completion(result)
5570+
/// }
5571+
/// }
5572+
/// \endcode
5573+
/// This function adds the range marked by \c <----->
5574+
void addResultTypeAnnotationIfNecessary() {
5575+
if (FD->isGeneric()) {
5576+
OS << ": ";
5577+
addAsyncFuncReturnType();
5578+
}
5579+
}
5580+
5581+
public:
5582+
LegacyAlternativeBodyCreator(FuncDecl *FD, AsyncHandlerDesc HandlerDesc)
5583+
: FD(FD), HandlerDesc(HandlerDesc), OS(Buffer) {}
5584+
5585+
bool canRewriteLegacyBody() {
5586+
if (FD == nullptr || FD->getBody() == nullptr) {
5587+
return false;
5588+
}
5589+
if (FD->hasThrows()) {
5590+
assert(!HandlerDesc.isValid() && "We shouldn't have found a handler desc "
5591+
"if the original function throws");
5592+
return false;
5593+
}
5594+
switch (HandlerDesc.Type) {
5595+
case HandlerType::INVALID:
5596+
return false;
5597+
case HandlerType::PARAMS: {
5598+
if (HandlerDesc.HasError) {
5599+
// The non-error parameters must be optional so that we can set them to
5600+
// nil in the error case.
5601+
// The error parameter must be optional so we can set it to nil in the
5602+
// success case.
5603+
// Otherwise we can't synthesize the values to return for these
5604+
// parameters.
5605+
return llvm::all_of(HandlerDesc.params(),
5606+
[](AnyFunctionType::Param Param) -> bool {
5607+
return Param.getPlainType()->isOptional();
5608+
});
5609+
} else {
5610+
return true;
5611+
}
5612+
}
5613+
case HandlerType::RESULT:
5614+
return true;
5615+
}
5616+
}
5617+
5618+
std::string create() {
5619+
assert(Buffer.empty() &&
5620+
"LegacyAlternativeBodyCreator can only be used once");
5621+
assert(canRewriteLegacyBody() &&
5622+
"Cannot create a legacy body if the body can't be rewritten");
5623+
OS << "{\n"; // start function body
5624+
OS << "async {\n";
5625+
if (HandlerDesc.HasError) {
5626+
OS << "do {\n";
5627+
if (!HandlerDesc.willAsyncReturnVoid()) {
5628+
OS << "let result";
5629+
addResultTypeAnnotationIfNecessary();
5630+
OS << " = ";
5631+
}
5632+
OS << "try await ";
5633+
addCallToAsyncMethod();
5634+
OS << "\n";
5635+
addCallToCompletionHandler(/*HasResult=*/true);
5636+
OS << "\n"
5637+
<< "} catch {\n";
5638+
addCallToCompletionHandler(/*HasResult=*/false);
5639+
OS << "\n"
5640+
<< "}\n"; // end catch
5641+
} else {
5642+
if (!HandlerDesc.willAsyncReturnVoid()) {
5643+
OS << "let result";
5644+
addResultTypeAnnotationIfNecessary();
5645+
OS << " = ";
5646+
}
5647+
OS << "await ";
5648+
addCallToAsyncMethod();
5649+
OS << "\n";
5650+
addCallToCompletionHandler(/*HasResult=*/true);
5651+
OS << "\n";
5652+
}
5653+
OS << "}\n"; // end 'async'
5654+
OS << "}\n"; // end function body
5655+
return Buffer;
5656+
}
5657+
};
5658+
53735659
} // namespace asyncrefactorings
53745660

53755661
bool RefactoringActionConvertCallToAsyncAlternative::isApplicable(
@@ -5476,6 +5762,13 @@ bool RefactoringActionAddAsyncAlternative::performChange() {
54765762
EditConsumer.accept(SM, FD->getAttributeInsertionLoc(false),
54775763
"@available(*, deprecated, message: \"Prefer async "
54785764
"alternative instead\")\n");
5765+
LegacyAlternativeBodyCreator LegacyBody(FD, HandlerDesc);
5766+
if (LegacyBody.canRewriteLegacyBody()) {
5767+
EditConsumer.accept(SM,
5768+
Lexer::getCharSourceRangeFromSourceRange(
5769+
SM, FD->getBody()->getSourceRange()),
5770+
LegacyBody.create());
5771+
}
54795772
Converter.insertAfter(FD, EditConsumer);
54805773

54815774
return false;

test/lit.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,7 @@ config.scale_test = make_path(config.swift_utils, 'scale-test')
334334
config.PathSanitizingFileCheck = make_path(config.swift_utils, 'PathSanitizingFileCheck')
335335
config.swift_lib_dir = make_path(config.swift, '..', '..', 'lib')
336336
config.round_trip_syntax_test = make_path(config.swift_utils, 'round-trip-syntax-test')
337+
config.refactor_check_compiles = make_path(config.swift_utils, 'refactor-check-compiles.py')
337338

338339
config.link = lit.util.which('link', config.environment.get('PATH', '')) or \
339340
lit.util.which('lld-link', config.environment.get('PATH', ''))

0 commit comments

Comments
 (0)