Skip to content

Commit 3c0bb38

Browse files
committed
[Refactoring] Add @completionHandlerAsync to sync function
When adding an async alternative, add the @completionHandlerAsync attribute to the sync function. Check for this attribute in addition to the name check, ie. convert a call if the callee has either @completionHandlerAsync or a name that is completion-handler-like name. The addition of the attribute is currently gated behind the experimental concurrency flag. Resolves rdar://77486504
1 parent b765eec commit 3c0bb38

File tree

4 files changed

+142
-22
lines changed

4 files changed

+142
-22
lines changed

lib/IDE/Refactoring.cpp

Lines changed: 66 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4088,7 +4088,7 @@ struct AsyncHandlerDesc {
40884088
HandlerType Type = HandlerType::INVALID;
40894089
bool HasError = false;
40904090

4091-
static AsyncHandlerDesc get(const ValueDecl *Handler, bool ignoreName) {
4091+
static AsyncHandlerDesc get(const ValueDecl *Handler, bool RequireName) {
40924092
AsyncHandlerDesc HandlerDesc;
40934093
if (auto Var = dyn_cast<VarDecl>(Handler)) {
40944094
HandlerDesc.Handler = Var;
@@ -4099,8 +4099,8 @@ struct AsyncHandlerDesc {
40994099
return AsyncHandlerDesc();
41004100
}
41014101

4102-
// Callback must have a completion-like name (if we're not ignoring it)
4103-
if (!ignoreName && !isCompletionHandlerParamName(HandlerDesc.getNameStr()))
4102+
// Callback must have a completion-like name
4103+
if (RequireName && !isCompletionHandlerParamName(HandlerDesc.getNameStr()))
41044104
return AsyncHandlerDesc();
41054105

41064106
// Callback must be a function type and return void. Doesn't need to have
@@ -4349,17 +4349,26 @@ struct AsyncHandlerDesc {
43494349
/// information about that completion handler and its index within the function
43504350
/// declaration.
43514351
struct AsyncHandlerParamDesc : public AsyncHandlerDesc {
4352+
/// The function the completion handler is a parameter of.
4353+
const FuncDecl *Func = nullptr;
43524354
/// The index of the completion handler in the function that declares it.
43534355
int Index = -1;
43544356

43554357
AsyncHandlerParamDesc() : AsyncHandlerDesc() {}
4356-
AsyncHandlerParamDesc(const AsyncHandlerDesc &Handler, int Index)
4357-
: AsyncHandlerDesc(Handler), Index(Index) {}
4358+
AsyncHandlerParamDesc(const AsyncHandlerDesc &Handler, const FuncDecl *Func,
4359+
int Index)
4360+
: AsyncHandlerDesc(Handler), Func(Func), Index(Index) {}
43584361

4359-
static AsyncHandlerParamDesc find(const FuncDecl *FD, bool ignoreName) {
4362+
static AsyncHandlerParamDesc find(const FuncDecl *FD,
4363+
bool RequireAttributeOrName) {
43604364
if (!FD || FD->hasAsync() || FD->hasThrows())
43614365
return AsyncHandlerParamDesc();
43624366

4367+
bool RequireName = RequireAttributeOrName;
4368+
if (RequireAttributeOrName &&
4369+
FD->getAttrs().hasAttribute<CompletionHandlerAsyncAttr>())
4370+
RequireName = false;
4371+
43634372
// Require at least one parameter and void return type
43644373
auto *Params = FD->getParameters();
43654374
if (Params->size() == 0 || !FD->getResultInterfaceType()->isVoid())
@@ -4373,10 +4382,30 @@ struct AsyncHandlerParamDesc : public AsyncHandlerDesc {
43734382
if (Param->isAutoClosure())
43744383
return AsyncHandlerParamDesc();
43754384

4376-
return AsyncHandlerParamDesc(AsyncHandlerDesc::get(Param, ignoreName),
4385+
return AsyncHandlerParamDesc(AsyncHandlerDesc::get(Param, RequireName), FD,
43774386
Index);
43784387
}
43794388

4389+
/// Print the name of the function with the completion handler, without
4390+
/// the completion handler parameter, to \p OS. That is, the name of the
4391+
/// async alternative function.
4392+
void printAsyncFunctionName(llvm::raw_ostream &OS) const {
4393+
if (!Func || Index < 0)
4394+
return;
4395+
4396+
DeclName Name = Func->getName();
4397+
OS << Name.getBaseName();
4398+
4399+
OS << tok::l_paren;
4400+
ArrayRef<Identifier> ArgNames = Name.getArgumentNames();
4401+
for (size_t I = 0; I < ArgNames.size(); ++I) {
4402+
if (I != (size_t)Index) {
4403+
OS << ArgNames[I] << tok::colon;
4404+
}
4405+
}
4406+
OS << tok::r_paren;
4407+
}
4408+
43804409
bool operator==(const AsyncHandlerParamDesc &Other) const {
43814410
return Handler == Other.Handler && Type == Other.Type &&
43824411
HasError == Other.HasError && Index == Other.Index;
@@ -5554,8 +5583,12 @@ class AsyncConverter : private SourceEntityWalker {
55545583
return addCustom(CE->getSourceRange(), [&]() { addHandlerCall(CE); });
55555584

55565585
if (auto *CE = dyn_cast<CallExpr>(E)) {
5586+
// If the refactoring is on the call itself, do not require the callee
5587+
// to have the @completionHandlerAsync attribute or a completion-like
5588+
// name.
55575589
auto HandlerDesc = AsyncHandlerParamDesc::find(
5558-
getUnderlyingFunc(CE->getFn()), StartNode.dyn_cast<Expr *>() == CE);
5590+
getUnderlyingFunc(CE->getFn()),
5591+
/*RequireAttributeOrName=*/StartNode.dyn_cast<Expr *>() != CE);
55595592
if (HandlerDesc.isValid())
55605593
return addCustom(CE->getSourceRange(),
55615594
[&]() { addHoistedCallback(CE, HandlerDesc); });
@@ -5826,8 +5859,8 @@ class AsyncConverter : private SourceEntityWalker {
58265859

58275860
// The completion handler that is called as part of the \p CE call.
58285861
// This will be called once the async function returns.
5829-
auto CompletionHandler = AsyncHandlerDesc::get(CallbackDecl,
5830-
/*ignoreName=*/true);
5862+
auto CompletionHandler =
5863+
AsyncHandlerDesc::get(CallbackDecl, /*RequireAttributeOrName=*/false);
58315864
if (CompletionHandler.isValid()) {
58325865
if (auto CalledFunc = getUnderlyingFunc(CE->getFn())) {
58335866
StringRef HandlerName = Lexer::getCharSourceRangeFromSourceRange(
@@ -6420,8 +6453,8 @@ bool RefactoringActionConvertCallToAsyncAlternative::isApplicable(
64206453
if (!CE)
64216454
return false;
64226455

6423-
auto HandlerDesc = AsyncHandlerParamDesc::find(getUnderlyingFunc(CE->getFn()),
6424-
/*ignoreName=*/true);
6456+
auto HandlerDesc = AsyncHandlerParamDesc::find(
6457+
getUnderlyingFunc(CE->getFn()), /*RequireAttributeOrName=*/false);
64256458
return HandlerDesc.isValid();
64266459
}
64276460

@@ -6478,7 +6511,8 @@ bool RefactoringActionConvertToAsync::performChange() {
64786511
assert(FD &&
64796512
"Should not run performChange when refactoring is not applicable");
64806513

6481-
auto HandlerDesc = AsyncHandlerParamDesc::find(FD, /*ignoreName=*/true);
6514+
auto HandlerDesc =
6515+
AsyncHandlerParamDesc::find(FD, /*RequireAttributeOrName=*/false);
64826516
AsyncConverter Converter(TheFile, SM, DiagEngine, FD, HandlerDesc);
64836517
if (!Converter.convert())
64846518
return true;
@@ -6495,7 +6529,8 @@ bool RefactoringActionAddAsyncAlternative::isApplicable(
64956529
if (!FD)
64966530
return false;
64976531

6498-
auto HandlerDesc = AsyncHandlerParamDesc::find(FD, /*ignoreName=*/true);
6532+
auto HandlerDesc =
6533+
AsyncHandlerParamDesc::find(FD, /*RequireAttributeOrName=*/false);
64996534
return HandlerDesc.isValid();
65006535
}
65016536

@@ -6512,21 +6547,37 @@ bool RefactoringActionAddAsyncAlternative::performChange() {
65126547
assert(FD &&
65136548
"Should not run performChange when refactoring is not applicable");
65146549

6515-
auto HandlerDesc = AsyncHandlerParamDesc::find(FD, /*ignoreName=*/true);
6550+
auto HandlerDesc =
6551+
AsyncHandlerParamDesc::find(FD, /*RequireAttributeOrName=*/false);
65166552
assert(HandlerDesc.isValid() &&
65176553
"Should not run performChange when refactoring is not applicable");
65186554

65196555
AsyncConverter Converter(TheFile, SM, DiagEngine, FD, HandlerDesc);
65206556
if (!Converter.convert())
65216557
return true;
65226558

6559+
// Deprecate the synchronous function
65236560
EditConsumer.accept(SM, FD->getAttributeInsertionLoc(false),
65246561
"@available(*, deprecated, message: \"Prefer async "
65256562
"alternative instead\")\n");
6563+
6564+
if (Ctx.LangOpts.EnableExperimentalConcurrency) {
6565+
// Add an attribute to describe its async alternative
6566+
llvm::SmallString<0> HandlerAttribute;
6567+
llvm::raw_svector_ostream OS(HandlerAttribute);
6568+
OS << "@completionHandlerAsync(\"";
6569+
HandlerDesc.printAsyncFunctionName(OS);
6570+
OS << "\", completionHandlerIndex: " << HandlerDesc.Index << ")\n";
6571+
EditConsumer.accept(SM, FD->getAttributeInsertionLoc(false),
6572+
HandlerAttribute);
6573+
}
6574+
65266575
AsyncConverter LegacyBodyCreator(TheFile, SM, DiagEngine, FD, HandlerDesc);
65276576
if (LegacyBodyCreator.createLegacyBody()) {
65286577
LegacyBodyCreator.replace(FD->getBody(), EditConsumer);
65296578
}
6579+
6580+
// Add the async alternative
65306581
Converter.insertAfter(FD, EditConsumer);
65316582

65326583
return false;
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// RUN: %empty-directory(%t)
2+
3+
// RUN: %refactor-check-compiles -add-async-alternative -dump-text -source-filename %s -pos=%(line+1):1 -enable-experimental-concurrency | %FileCheck -check-prefix=SIMPLE %s
4+
func simple(completion: @escaping (String) -> Void) { }
5+
// SIMPLE: async_attribute_added.swift [[# @LINE-1]]:1 -> [[# @LINE-1]]:1
6+
// SIMPLE-NEXT: @available(*, deprecated, message: "Prefer async alternative instead")
7+
// SIMPLE-EMPTY:
8+
// SIMPLE-NEXT: async_attribute_added.swift [[# @LINE-4]]:1 -> [[# @LINE-4]]:1
9+
// SIMPLE-NEXT: @completionHandlerAsync("simple()", completionHandlerIndex: 0)
10+
// SIMPLE-EMPTY:
11+
// SIMPLE-NEXT: async_attribute_added.swift [[# @LINE-7]]:53 -> [[# @LINE-7]]:56
12+
// SIMPLE-NEXT: {
13+
// SIMPLE-NEXT: async {
14+
// SIMPLE-NEXT: let result = await simple()
15+
// SIMPLE-NEXT: completion(result)
16+
// SIMPLE-NEXT: }
17+
// SIMPLE-NEXT: }
18+
// SIMPLE-EMPTY:
19+
// SIMPLE-NEXT: async_attribute_added.swift [[# @LINE-15]]:56 -> [[# @LINE-15]]:56
20+
// SIMPLE-EMPTY:
21+
// SIMPLE-EMPTY:
22+
// SIMPLE-EMPTY:
23+
// SIMPLE-NEXT: async_attribute_added.swift [[# @LINE-19]]:56 -> [[# @LINE-19]]:56
24+
// SIMPLE-NEXT: func simple() async -> String { }
25+
26+
// RUN: %refactor-check-compiles -add-async-alternative -dump-text -source-filename %s -pos=%(line+1):5 -enable-experimental-concurrency | %FileCheck -check-prefix=OTHER-ARGS %s
27+
func otherArgs(first: Int, second: String, completion: @escaping (String) -> Void) { }
28+
// OTHER-ARGS: @completionHandlerAsync("otherArgs(first:second:)", completionHandlerIndex: 2)
29+
30+
// RUN: %refactor-check-compiles -add-async-alternative -dump-text -source-filename %s -pos=%(line+1):5 -enable-experimental-concurrency | %FileCheck -check-prefix=EMPTY-NAMES %s
31+
func emptyNames(first: Int, _ second: String, completion: @escaping (String) -> Void) { }
32+
// EMPTY-NAMES: @completionHandlerAsync("emptyNames(first:_:)", completionHandlerIndex: 2)
33+
34+
// Not a completion handler named parameter, but should still be converted
35+
// during function conversion since it has been attributed
36+
@completionHandlerAsync("otherName()", completionHandlerIndex: 0)
37+
func otherName(notHandlerName: @escaping (String) -> (Void)) {}
38+
func otherName() async -> String {}
39+
40+
// RUN: %refactor-check-compiles -convert-to-async -dump-text -source-filename %s -pos=%(line+1):5 -enable-experimental-concurrency | %FileCheck -check-prefix=OTHER-CONVERTED %s
41+
func otherStillConverted() {
42+
otherName { str in
43+
print(str)
44+
}
45+
}
46+
// OTHER-CONVERTED: func otherStillConverted() async {
47+
// OTHER-CONVERTED-NEXT: let str = await otherName()
48+
// OTHER-CONVERTED-NEXT: print(str)

tools/swift-refactor/swift-refactor.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,10 @@ static llvm::cl::opt<bool>
118118
IsNonProtocolType("is-non-protocol-type",
119119
llvm::cl::desc("The symbol being renamed is a type and not a protocol"));
120120

121+
static llvm::cl::opt<bool> EnableExperimentalConcurrency(
122+
"enable-experimental-concurrency",
123+
llvm::cl::desc("Whether to enable experimental concurrency or not"));
124+
121125
enum class DumpType {
122126
REWRITTEN,
123127
JSON,
@@ -273,7 +277,9 @@ int main(int argc, char *argv[]) {
273277
Invocation.getLangOptions().AttachCommentsToDecls = true;
274278
Invocation.getLangOptions().CollectParsedToken = true;
275279
Invocation.getLangOptions().BuildSyntaxTree = true;
276-
Invocation.getLangOptions().EnableExperimentalConcurrency = true;
280+
281+
if (options::EnableExperimentalConcurrency)
282+
Invocation.getLangOptions().EnableExperimentalConcurrency = true;
277283

278284
for (auto FileName : options::InputFilenames)
279285
Invocation.getFrontendOptions().InputsAndOutputs.addInputFile(FileName);

utils/refactor-check-compiles.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,17 @@ def parse_args():
2121
formatter_class=argparse.RawDescriptionHelpFormatter,
2222
description="""
2323
A drop-in replacement for a 'swift-refactor -dump-text' call that
24-
1. Checkes that the file still compiles after the refactoring by doing
24+
1. Checks that the file still compiles after the refactoring by doing
2525
'swift-refactor -dump-rewritten' and feeding the result to
2626
'swift-frontend -typecheck'
2727
2. Outputting the result of the 'swift-refactor -dump-text' call
2828
29-
All arguments other than the following will be forwarded to :
29+
All arguments other than the following will be forwarded to
3030
'swift-refactor':
3131
- swift-frontend
3232
- swift-refactor
3333
- temp-dir
34+
- enable-experimental-concurrency (sent to both)
3435
""")
3536

3637
parser.add_argument(
@@ -61,30 +62,44 @@ def parse_args():
6162
'-pos',
6263
help='The position to invoke the refactoring at'
6364
)
65+
parser.add_argument(
66+
'-enable-experimental-concurrency',
67+
action='store_true',
68+
help='''
69+
Whether to enable experimental concurrency in both swift-refactor and
70+
swift-frontend
71+
'''
72+
)
6473

6574
return parser.parse_known_args()
6675

6776

6877
def main():
69-
(args, unknown_args) = parse_args()
78+
(args, extra_refactor_args) = parse_args()
7079
temp_file_name = os.path.split(args.source_filename)[-1] + '.' + \
7180
args.pos.replace(':', '.')
7281
temp_file_path = os.path.join(args.temp_dir, temp_file_name)
82+
83+
extra_frontend_args = []
84+
if args.enable_experimental_concurrency:
85+
extra_refactor_args.append('-enable-experimental-concurrency')
86+
extra_frontend_args.append('-enable-experimental-concurrency')
87+
7388
# FIXME: `refactor-check-compiles` should generate both `-dump-text` and
7489
# `dump-rewritten` from a single `swift-refactor` invocation (SR-14587).
7590
dump_text_output = run_cmd([
7691
args.swift_refactor,
7792
'-dump-text',
7893
'-source-filename', args.source_filename,
7994
'-pos', args.pos
80-
] + unknown_args, desc='producing edit').decode("utf-8")
95+
] + extra_refactor_args, desc='producing edit').decode("utf-8")
8196

8297
dump_rewritten_output = run_cmd([
8398
args.swift_refactor,
8499
'-dump-rewritten',
85100
'-source-filename', args.source_filename,
86101
'-pos', args.pos
87-
] + unknown_args, desc='producing rewritten file')
102+
] + extra_refactor_args, desc='producing rewritten file')
88103
with open(temp_file_path, 'wb') as f:
89104
f.write(dump_rewritten_output)
90105

@@ -93,7 +108,7 @@ def main():
93108
'-typecheck',
94109
temp_file_path,
95110
'-disable-availability-checking'
96-
], desc='checking that rewritten file compiles')
111+
] + extra_frontend_args, desc='checking that rewritten file compiles')
97112
sys.stdout.write(dump_text_output)
98113

99114

0 commit comments

Comments
 (0)