Skip to content

[5.5][Refactoring] Add @completionHandlerAsync to sync function #37418

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 66 additions & 15 deletions lib/IDE/Refactoring.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4088,7 +4088,7 @@ struct AsyncHandlerDesc {
HandlerType Type = HandlerType::INVALID;
bool HasError = false;

static AsyncHandlerDesc get(const ValueDecl *Handler, bool ignoreName) {
static AsyncHandlerDesc get(const ValueDecl *Handler, bool RequireName) {
AsyncHandlerDesc HandlerDesc;
if (auto Var = dyn_cast<VarDecl>(Handler)) {
HandlerDesc.Handler = Var;
Expand All @@ -4099,8 +4099,8 @@ struct AsyncHandlerDesc {
return AsyncHandlerDesc();
}

// Callback must have a completion-like name (if we're not ignoring it)
if (!ignoreName && !isCompletionHandlerParamName(HandlerDesc.getNameStr()))
// Callback must have a completion-like name
if (RequireName && !isCompletionHandlerParamName(HandlerDesc.getNameStr()))
return AsyncHandlerDesc();

// Callback must be a function type and return void. Doesn't need to have
Expand Down Expand Up @@ -4349,17 +4349,26 @@ struct AsyncHandlerDesc {
/// information about that completion handler and its index within the function
/// declaration.
struct AsyncHandlerParamDesc : public AsyncHandlerDesc {
/// The function the completion handler is a parameter of.
const FuncDecl *Func = nullptr;
/// The index of the completion handler in the function that declares it.
int Index = -1;

AsyncHandlerParamDesc() : AsyncHandlerDesc() {}
AsyncHandlerParamDesc(const AsyncHandlerDesc &Handler, int Index)
: AsyncHandlerDesc(Handler), Index(Index) {}
AsyncHandlerParamDesc(const AsyncHandlerDesc &Handler, const FuncDecl *Func,
int Index)
: AsyncHandlerDesc(Handler), Func(Func), Index(Index) {}

static AsyncHandlerParamDesc find(const FuncDecl *FD, bool ignoreName) {
static AsyncHandlerParamDesc find(const FuncDecl *FD,
bool RequireAttributeOrName) {
if (!FD || FD->hasAsync() || FD->hasThrows())
return AsyncHandlerParamDesc();

bool RequireName = RequireAttributeOrName;
if (RequireAttributeOrName &&
FD->getAttrs().hasAttribute<CompletionHandlerAsyncAttr>())
RequireName = false;

// Require at least one parameter and void return type
auto *Params = FD->getParameters();
if (Params->size() == 0 || !FD->getResultInterfaceType()->isVoid())
Expand All @@ -4373,10 +4382,30 @@ struct AsyncHandlerParamDesc : public AsyncHandlerDesc {
if (Param->isAutoClosure())
return AsyncHandlerParamDesc();

return AsyncHandlerParamDesc(AsyncHandlerDesc::get(Param, ignoreName),
return AsyncHandlerParamDesc(AsyncHandlerDesc::get(Param, RequireName), FD,
Index);
}

/// Print the name of the function with the completion handler, without
/// the completion handler parameter, to \p OS. That is, the name of the
/// async alternative function.
void printAsyncFunctionName(llvm::raw_ostream &OS) const {
if (!Func || Index < 0)
return;

DeclName Name = Func->getName();
OS << Name.getBaseName();

OS << tok::l_paren;
ArrayRef<Identifier> ArgNames = Name.getArgumentNames();
for (size_t I = 0; I < ArgNames.size(); ++I) {
if (I != (size_t)Index) {
OS << ArgNames[I] << tok::colon;
}
}
OS << tok::r_paren;
}

bool operator==(const AsyncHandlerParamDesc &Other) const {
return Handler == Other.Handler && Type == Other.Type &&
HasError == Other.HasError && Index == Other.Index;
Expand Down Expand Up @@ -5554,8 +5583,12 @@ class AsyncConverter : private SourceEntityWalker {
return addCustom(CE->getSourceRange(), [&]() { addHandlerCall(CE); });

if (auto *CE = dyn_cast<CallExpr>(E)) {
// If the refactoring is on the call itself, do not require the callee
// to have the @completionHandlerAsync attribute or a completion-like
// name.
auto HandlerDesc = AsyncHandlerParamDesc::find(
getUnderlyingFunc(CE->getFn()), StartNode.dyn_cast<Expr *>() == CE);
getUnderlyingFunc(CE->getFn()),
/*RequireAttributeOrName=*/StartNode.dyn_cast<Expr *>() != CE);
if (HandlerDesc.isValid())
return addCustom(CE->getSourceRange(),
[&]() { addHoistedCallback(CE, HandlerDesc); });
Expand Down Expand Up @@ -5826,8 +5859,8 @@ class AsyncConverter : private SourceEntityWalker {

// The completion handler that is called as part of the \p CE call.
// This will be called once the async function returns.
auto CompletionHandler = AsyncHandlerDesc::get(CallbackDecl,
/*ignoreName=*/true);
auto CompletionHandler =
AsyncHandlerDesc::get(CallbackDecl, /*RequireAttributeOrName=*/false);
if (CompletionHandler.isValid()) {
if (auto CalledFunc = getUnderlyingFunc(CE->getFn())) {
StringRef HandlerName = Lexer::getCharSourceRangeFromSourceRange(
Expand Down Expand Up @@ -6420,8 +6453,8 @@ bool RefactoringActionConvertCallToAsyncAlternative::isApplicable(
if (!CE)
return false;

auto HandlerDesc = AsyncHandlerParamDesc::find(getUnderlyingFunc(CE->getFn()),
/*ignoreName=*/true);
auto HandlerDesc = AsyncHandlerParamDesc::find(
getUnderlyingFunc(CE->getFn()), /*RequireAttributeOrName=*/false);
return HandlerDesc.isValid();
}

Expand Down Expand Up @@ -6478,7 +6511,8 @@ bool RefactoringActionConvertToAsync::performChange() {
assert(FD &&
"Should not run performChange when refactoring is not applicable");

auto HandlerDesc = AsyncHandlerParamDesc::find(FD, /*ignoreName=*/true);
auto HandlerDesc =
AsyncHandlerParamDesc::find(FD, /*RequireAttributeOrName=*/false);
AsyncConverter Converter(TheFile, SM, DiagEngine, FD, HandlerDesc);
if (!Converter.convert())
return true;
Expand All @@ -6495,7 +6529,8 @@ bool RefactoringActionAddAsyncAlternative::isApplicable(
if (!FD)
return false;

auto HandlerDesc = AsyncHandlerParamDesc::find(FD, /*ignoreName=*/true);
auto HandlerDesc =
AsyncHandlerParamDesc::find(FD, /*RequireAttributeOrName=*/false);
return HandlerDesc.isValid();
}

Expand All @@ -6512,21 +6547,37 @@ bool RefactoringActionAddAsyncAlternative::performChange() {
assert(FD &&
"Should not run performChange when refactoring is not applicable");

auto HandlerDesc = AsyncHandlerParamDesc::find(FD, /*ignoreName=*/true);
auto HandlerDesc =
AsyncHandlerParamDesc::find(FD, /*RequireAttributeOrName=*/false);
assert(HandlerDesc.isValid() &&
"Should not run performChange when refactoring is not applicable");

AsyncConverter Converter(TheFile, SM, DiagEngine, FD, HandlerDesc);
if (!Converter.convert())
return true;

// Deprecate the synchronous function
EditConsumer.accept(SM, FD->getAttributeInsertionLoc(false),
"@available(*, deprecated, message: \"Prefer async "
"alternative instead\")\n");

if (Ctx.LangOpts.EnableExperimentalConcurrency) {
// Add an attribute to describe its async alternative
llvm::SmallString<0> HandlerAttribute;
llvm::raw_svector_ostream OS(HandlerAttribute);
OS << "@completionHandlerAsync(\"";
HandlerDesc.printAsyncFunctionName(OS);
OS << "\", completionHandlerIndex: " << HandlerDesc.Index << ")\n";
EditConsumer.accept(SM, FD->getAttributeInsertionLoc(false),
HandlerAttribute);
}

AsyncConverter LegacyBodyCreator(TheFile, SM, DiagEngine, FD, HandlerDesc);
if (LegacyBodyCreator.createLegacyBody()) {
LegacyBodyCreator.replace(FD->getBody(), EditConsumer);
}

// Add the async alternative
Converter.insertAfter(FD, EditConsumer);

return false;
Expand Down
48 changes: 48 additions & 0 deletions test/refactoring/ConvertAsync/async_attribute_added.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// RUN: %empty-directory(%t)

// RUN: %refactor-check-compiles -add-async-alternative -dump-text -source-filename %s -pos=%(line+1):1 -enable-experimental-concurrency | %FileCheck -check-prefix=SIMPLE %s
func simple(completion: @escaping (String) -> Void) { }
// SIMPLE: async_attribute_added.swift [[# @LINE-1]]:1 -> [[# @LINE-1]]:1
// SIMPLE-NEXT: @available(*, deprecated, message: "Prefer async alternative instead")
// SIMPLE-EMPTY:
// SIMPLE-NEXT: async_attribute_added.swift [[# @LINE-4]]:1 -> [[# @LINE-4]]:1
// SIMPLE-NEXT: @completionHandlerAsync("simple()", completionHandlerIndex: 0)
// SIMPLE-EMPTY:
// SIMPLE-NEXT: async_attribute_added.swift [[# @LINE-7]]:53 -> [[# @LINE-7]]:56
// SIMPLE-NEXT: {
// SIMPLE-NEXT: async {
// SIMPLE-NEXT: let result = await simple()
// SIMPLE-NEXT: completion(result)
// SIMPLE-NEXT: }
// SIMPLE-NEXT: }
// SIMPLE-EMPTY:
// SIMPLE-NEXT: async_attribute_added.swift [[# @LINE-15]]:56 -> [[# @LINE-15]]:56
// SIMPLE-EMPTY:
// SIMPLE-EMPTY:
// SIMPLE-EMPTY:
// SIMPLE-NEXT: async_attribute_added.swift [[# @LINE-19]]:56 -> [[# @LINE-19]]:56
// SIMPLE-NEXT: func simple() async -> String { }

// RUN: %refactor-check-compiles -add-async-alternative -dump-text -source-filename %s -pos=%(line+1):5 -enable-experimental-concurrency | %FileCheck -check-prefix=OTHER-ARGS %s
func otherArgs(first: Int, second: String, completion: @escaping (String) -> Void) { }
// OTHER-ARGS: @completionHandlerAsync("otherArgs(first:second:)", completionHandlerIndex: 2)

// RUN: %refactor-check-compiles -add-async-alternative -dump-text -source-filename %s -pos=%(line+1):5 -enable-experimental-concurrency | %FileCheck -check-prefix=EMPTY-NAMES %s
func emptyNames(first: Int, _ second: String, completion: @escaping (String) -> Void) { }
// EMPTY-NAMES: @completionHandlerAsync("emptyNames(first:_:)", completionHandlerIndex: 2)

// Not a completion handler named parameter, but should still be converted
// during function conversion since it has been attributed
@completionHandlerAsync("otherName()", completionHandlerIndex: 0)
func otherName(notHandlerName: @escaping (String) -> (Void)) {}
func otherName() async -> String {}

// RUN: %refactor-check-compiles -convert-to-async -dump-text -source-filename %s -pos=%(line+1):5 -enable-experimental-concurrency | %FileCheck -check-prefix=OTHER-CONVERTED %s
func otherStillConverted() {
otherName { str in
print(str)
}
}
// OTHER-CONVERTED: func otherStillConverted() async {
// OTHER-CONVERTED-NEXT: let str = await otherName()
// OTHER-CONVERTED-NEXT: print(str)
8 changes: 7 additions & 1 deletion tools/swift-refactor/swift-refactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ static llvm::cl::opt<bool>
IsNonProtocolType("is-non-protocol-type",
llvm::cl::desc("The symbol being renamed is a type and not a protocol"));

static llvm::cl::opt<bool> EnableExperimentalConcurrency(
"enable-experimental-concurrency",
llvm::cl::desc("Whether to enable experimental concurrency or not"));

enum class DumpType {
REWRITTEN,
JSON,
Expand Down Expand Up @@ -273,7 +277,9 @@ int main(int argc, char *argv[]) {
Invocation.getLangOptions().AttachCommentsToDecls = true;
Invocation.getLangOptions().CollectParsedToken = true;
Invocation.getLangOptions().BuildSyntaxTree = true;
Invocation.getLangOptions().EnableExperimentalConcurrency = true;

if (options::EnableExperimentalConcurrency)
Invocation.getLangOptions().EnableExperimentalConcurrency = true;

for (auto FileName : options::InputFilenames)
Invocation.getFrontendOptions().InputsAndOutputs.addInputFile(FileName);
Expand Down
27 changes: 21 additions & 6 deletions utils/refactor-check-compiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,17 @@ def parse_args():
formatter_class=argparse.RawDescriptionHelpFormatter,
description="""
A drop-in replacement for a 'swift-refactor -dump-text' call that
1. Checkes that the file still compiles after the refactoring by doing
1. Checks that the file still compiles after the refactoring by doing
'swift-refactor -dump-rewritten' and feeding the result to
'swift-frontend -typecheck'
2. Outputting the result of the 'swift-refactor -dump-text' call

All arguments other than the following will be forwarded to :
All arguments other than the following will be forwarded to
'swift-refactor':
- swift-frontend
- swift-refactor
- temp-dir
- enable-experimental-concurrency (sent to both)
""")

parser.add_argument(
Expand Down Expand Up @@ -61,30 +62,44 @@ def parse_args():
'-pos',
help='The position to invoke the refactoring at'
)
parser.add_argument(
'-enable-experimental-concurrency',
action='store_true',
help='''
Whether to enable experimental concurrency in both swift-refactor and
swift-frontend
'''
)

return parser.parse_known_args()


def main():
(args, unknown_args) = parse_args()
(args, extra_refactor_args) = parse_args()
temp_file_name = os.path.split(args.source_filename)[-1] + '.' + \
args.pos.replace(':', '.')
temp_file_path = os.path.join(args.temp_dir, temp_file_name)

extra_frontend_args = []
if args.enable_experimental_concurrency:
extra_refactor_args.append('-enable-experimental-concurrency')
extra_frontend_args.append('-enable-experimental-concurrency')

# FIXME: `refactor-check-compiles` should generate both `-dump-text` and
# `dump-rewritten` from a single `swift-refactor` invocation (SR-14587).
dump_text_output = run_cmd([
args.swift_refactor,
'-dump-text',
'-source-filename', args.source_filename,
'-pos', args.pos
] + unknown_args, desc='producing edit')
] + extra_refactor_args, desc='producing edit').decode("utf-8")

dump_rewritten_output = run_cmd([
args.swift_refactor,
'-dump-rewritten',
'-source-filename', args.source_filename,
'-pos', args.pos
] + unknown_args, desc='producing rewritten file')
] + extra_refactor_args, desc='producing rewritten file')
with open(temp_file_path, 'wb') as f:
f.write(dump_rewritten_output)

Expand All @@ -93,7 +108,7 @@ def main():
'-typecheck',
temp_file_path,
'-disable-availability-checking'
], desc='checking that rewritten file compiles')
] + extra_frontend_args, desc='checking that rewritten file compiles')
sys.stdout.write(dump_text_output)


Expand Down