Skip to content

[5.5][Refactoring] When adding an async alternative refactor the old method to call the async method using detach #37203

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -795,6 +795,9 @@ class alignas(1 << TypeAlignInBits) TypeBase {
/// Check if this type is equal to Swift.Bool.
bool isBool();

/// Check if this type is equal to Swift.Optional.
bool isOptional();

/// Check if this type is equal to Builtin.IntN.
bool isBuiltinIntegerType(unsigned bitWidth);

Expand Down
10 changes: 10 additions & 0 deletions lib/AST/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,16 @@ bool TypeBase::isBool() {
return false;
}

/// Check if this type is equal to Swift.Bool.
bool TypeBase::isOptional() {
if (auto generic = getAnyGeneric()) {
if (isa<EnumDecl>(generic)) {
return getASTContext().getOptionalDecl() == generic;
}
}
return false;
}

Type TypeBase::getRValueType() {
// If the type is not an lvalue, this is a no-op.
if (!hasLValueType())
Expand Down
293 changes: 293 additions & 0 deletions lib/IDE/Refactoring.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4158,6 +4158,36 @@ struct AsyncHandlerDesc {
return params();
}

/// Get the type of the error that will be thrown by the \c async method or \c
/// None if the completion handler doesn't accept an error parameter.
/// This may be more specialized than the generic 'Error' type if the
/// completion handler of the converted function takes a more specialized
/// error type.
Optional<swift::Type> getErrorType() const {
if (HasError) {
switch (Type) {
case HandlerType::INVALID:
return None;
case HandlerType::PARAMS:
// The last parameter of the completion handler is the error param
return params().back().getPlainType()->lookThroughSingleOptionalType();
case HandlerType::RESULT:
assert(
params().size() == 1 &&
"Result handler should have the Result type as the only parameter");
auto ResultType =
params().back().getPlainType()->getAs<BoundGenericType>();
auto GenericArgs = ResultType->getGenericArgs();
assert(GenericArgs.size() == 2 && "Result should have two params");
// The second (last) generic parameter of the Result type is the error
// type.
return GenericArgs.back();
}
} else {
return None;
}
}

/// The `CallExpr` if the given node is a call to the `Handler`
CallExpr *getAsHandlerCall(ASTNode Node) const {
if (!isValid())
Expand Down Expand Up @@ -5318,6 +5348,262 @@ class AsyncConverter : private SourceEntityWalker {
}
}
};

/// When adding an async alternative method for the function declaration \c FD,
/// this class tries to create a function body for the legacy function (the one
/// with a completion handler), which calls the newly converted async function.
/// There are certain situations in which we fail to create such a body, e.g.
/// if the completion handler has the signature `(String, Error?) -> Void` in
/// which case we can't synthesize the result of type \c String in the error
/// case.
class LegacyAlternativeBodyCreator {
/// The old function declaration for which an async alternative has been added
/// and whose body shall be rewritten to call the newly added async
/// alternative.
FuncDecl *FD;

/// The description of the completion handler in the old function declaration.
AsyncHandlerDesc HandlerDesc;

std::string Buffer;
llvm::raw_string_ostream OS;

/// Adds the call to the refactored 'async' method without the 'await'
/// keyword to the output stream.
void addCallToAsyncMethod() {
OS << FD->getBaseName() << "(";
bool FirstParam = true;
for (auto Param : *FD->getParameters()) {
if (Param == HandlerDesc.Handler) {
/// We don't need to pass the completion handler to the async method.
continue;
}
if (!FirstParam) {
OS << ", ";
} else {
FirstParam = false;
}
if (!Param->getArgumentName().empty()) {
OS << Param->getArgumentName() << ": ";
}
OS << Param->getParameterName();
}
OS << ")";
}

/// If the returned error type is more specialized than \c Error, adds an
/// 'as! CustomError' cast to the more specialized error type to the output
/// stream.
void addCastToCustomErrorTypeIfNecessary() {
auto ErrorType = *HandlerDesc.getErrorType();
if (ErrorType->getCanonicalType() !=
FD->getASTContext().getExceptionType()) {
OS << " as! ";
ErrorType->lookThroughSingleOptionalType()->print(OS);
}
}

/// Adds the \c Index -th parameter to the completion handler.
/// If \p HasResult is \c true, it is assumed that a variable named 'result'
/// contains the result returned from the async alternative. If the callback
/// also takes an error parameter, \c nil passed to the completion handler for
/// the error.
/// If \p HasResult is \c false, it is a assumed that a variable named 'error'
/// contains the error thrown from the async method and 'nil' will be passed
/// to the completion handler for all result parameters.
void addCompletionHandlerArgument(size_t Index, bool HasResult) {
if (HandlerDesc.HasError && Index == HandlerDesc.params().size() - 1) {
// The error parameter is the last argument of the completion handler.
if (!HasResult) {
OS << "error";
addCastToCustomErrorTypeIfNecessary();
} else {
OS << "nil";
}
} else {
if (!HasResult) {
OS << "nil";
} else if (HandlerDesc
.getSuccessParamAsyncReturnType(
HandlerDesc.params()[Index].getPlainType())
->isVoid()) {
// Void return types are not returned by the async function, synthesize
// a Void instance.
OS << "()";
} else if (HandlerDesc.getSuccessParams().size() > 1) {
// If the async method returns a tuple, we need to pass its elements to
// the completion handler separately. For example:
//
// func foo() async -> (String, Int) {}
//
// causes the following legacy body to be created:
//
// func foo(completion: (String, Int) -> Void) {
// async {
// let result = await foo()
// completion(result.0, result.1)
// }
// }
OS << "result." << Index;
} else {
OS << "result";
}
}
}

/// Adds the call to the completion handler. See \c
/// getCompletionHandlerArgument for how the arguments are synthesized if the
/// completion handler takes arguments, not a \c Result type.
void addCallToCompletionHandler(bool HasResult) {
OS << HandlerDesc.Handler->getParameterName() << "(";

// Construct arguments to pass to the completion handler
switch (HandlerDesc.Type) {
case HandlerType::INVALID:
llvm_unreachable("Cannot be rewritten");
break;
case HandlerType::PARAMS: {
for (size_t I = 0; I < HandlerDesc.params().size(); ++I) {
if (I > 0) {
OS << ", ";
}
addCompletionHandlerArgument(I, HasResult);
}
break;
}
case HandlerType::RESULT: {
if (HasResult) {
OS << ".success(result)";
} else {
OS << ".failure(error";
addCastToCustomErrorTypeIfNecessary();
OS << ")";
}
break;
}
}
OS << ")"; // Close the call to the completion handler
}

/// Adds the result type of the converted async function.
void addAsyncFuncReturnType() {
SmallVector<Type, 2> Scratch;
auto ReturnTypes = HandlerDesc.getAsyncReturnTypes(Scratch);
if (ReturnTypes.size() > 1) {
OS << "(";
}

llvm::interleave(
ReturnTypes, [&](Type Ty) { Ty->print(OS); }, [&]() { OS << ", "; });

if (ReturnTypes.size() > 1) {
OS << ")";
}
}

/// If the async alternative function is generic, adds the type annotation
/// to the 'return' variable in the legacy function so that the generic
/// parameters of the legacy function are passed to the generic function.
/// For example for
/// \code
/// func foo<GenericParam>() async -> GenericParam {}
/// \endcode
/// we generate
/// \code
/// func foo<GenericParam>(completion: (T) -> Void) {
/// async {
/// let result: GenericParam = await foo()
/// <------------>
/// completion(result)
/// }
/// }
/// \endcode
/// This function adds the range marked by \c <----->
void addResultTypeAnnotationIfNecessary() {
if (FD->isGeneric()) {
OS << ": ";
addAsyncFuncReturnType();
}
}

public:
LegacyAlternativeBodyCreator(FuncDecl *FD, AsyncHandlerDesc HandlerDesc)
: FD(FD), HandlerDesc(HandlerDesc), OS(Buffer) {}

bool canRewriteLegacyBody() {
if (FD == nullptr || FD->getBody() == nullptr) {
return false;
}
if (FD->hasThrows()) {
assert(!HandlerDesc.isValid() && "We shouldn't have found a handler desc "
"if the original function throws");
return false;
}
switch (HandlerDesc.Type) {
case HandlerType::INVALID:
return false;
case HandlerType::PARAMS: {
if (HandlerDesc.HasError) {
// The non-error parameters must be optional so that we can set them to
// nil in the error case.
// The error parameter must be optional so we can set it to nil in the
// success case.
// Otherwise we can't synthesize the values to return for these
// parameters.
return llvm::all_of(HandlerDesc.params(),
[](AnyFunctionType::Param Param) -> bool {
return Param.getPlainType()->isOptional();
});
} else {
return true;
}
}
case HandlerType::RESULT:
return true;
}
}

std::string create() {
assert(Buffer.empty() &&
"LegacyAlternativeBodyCreator can only be used once");
assert(canRewriteLegacyBody() &&
"Cannot create a legacy body if the body can't be rewritten");
OS << "{\n"; // start function body
OS << "async {\n";
if (HandlerDesc.HasError) {
OS << "do {\n";
if (!HandlerDesc.willAsyncReturnVoid()) {
OS << "let result";
addResultTypeAnnotationIfNecessary();
OS << " = ";
}
OS << "try await ";
addCallToAsyncMethod();
OS << "\n";
addCallToCompletionHandler(/*HasResult=*/true);
OS << "\n"
<< "} catch {\n";
addCallToCompletionHandler(/*HasResult=*/false);
OS << "\n"
<< "}\n"; // end catch
} else {
if (!HandlerDesc.willAsyncReturnVoid()) {
OS << "let result";
addResultTypeAnnotationIfNecessary();
OS << " = ";
}
OS << "await ";
addCallToAsyncMethod();
OS << "\n";
addCallToCompletionHandler(/*HasResult=*/true);
OS << "\n";
}
OS << "}\n"; // end 'async'
OS << "}\n"; // end function body
return Buffer;
}
};

} // namespace asyncrefactorings

bool RefactoringActionConvertCallToAsyncAlternative::isApplicable(
Expand Down Expand Up @@ -5424,6 +5710,13 @@ bool RefactoringActionAddAsyncAlternative::performChange() {
EditConsumer.accept(SM, FD->getAttributeInsertionLoc(false),
"@available(*, deprecated, message: \"Prefer async "
"alternative instead\")\n");
LegacyAlternativeBodyCreator LegacyBody(FD, HandlerDesc);
if (LegacyBody.canRewriteLegacyBody()) {
EditConsumer.accept(SM,
Lexer::getCharSourceRangeFromSourceRange(
SM, FD->getBody()->getSourceRange()),
LegacyBody.create());
}
Converter.insertAfter(FD, EditConsumer);

return false;
Expand Down
1 change: 1 addition & 0 deletions test/lit.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ config.scale_test = make_path(config.swift_utils, 'scale-test')
config.PathSanitizingFileCheck = make_path(config.swift_utils, 'PathSanitizingFileCheck')
config.swift_lib_dir = make_path(config.swift, '..', '..', 'lib')
config.round_trip_syntax_test = make_path(config.swift_utils, 'round-trip-syntax-test')
config.refactor_check_compiles = make_path(config.swift_utils, 'refactor-check-compiles.py')

config.link = lit.util.which('link', config.environment.get('PATH', '')) or \
lit.util.which('lld-link', config.environment.get('PATH', ''))
Expand Down
Loading