Skip to content

Commit b0bade6

Browse files
authored
Merge pull request #37417 from bnbarham/add-completionhandler-attribute
[Refactoring] Add @completionHandlerAsync to sync function
2 parents af4ebb4 + 762337c commit b0bade6

File tree

4 files changed

+142
-22
lines changed

4 files changed

+142
-22
lines changed

lib/IDE/Refactoring.cpp

Lines changed: 66 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4079,7 +4079,7 @@ struct AsyncHandlerDesc {
40794079
HandlerType Type = HandlerType::INVALID;
40804080
bool HasError = false;
40814081

4082-
static AsyncHandlerDesc get(const ValueDecl *Handler, bool ignoreName) {
4082+
static AsyncHandlerDesc get(const ValueDecl *Handler, bool RequireName) {
40834083
AsyncHandlerDesc HandlerDesc;
40844084
if (auto Var = dyn_cast<VarDecl>(Handler)) {
40854085
HandlerDesc.Handler = Var;
@@ -4090,8 +4090,8 @@ struct AsyncHandlerDesc {
40904090
return AsyncHandlerDesc();
40914091
}
40924092

4093-
// Callback must have a completion-like name (if we're not ignoring it)
4094-
if (!ignoreName && !isCompletionHandlerParamName(HandlerDesc.getNameStr()))
4093+
// Callback must have a completion-like name
4094+
if (RequireName && !isCompletionHandlerParamName(HandlerDesc.getNameStr()))
40954095
return AsyncHandlerDesc();
40964096

40974097
// Callback must be a function type and return void. Doesn't need to have
@@ -4340,17 +4340,26 @@ struct AsyncHandlerDesc {
43404340
/// information about that completion handler and its index within the function
43414341
/// declaration.
43424342
struct AsyncHandlerParamDesc : public AsyncHandlerDesc {
4343+
/// The function the completion handler is a parameter of.
4344+
const FuncDecl *Func = nullptr;
43434345
/// The index of the completion handler in the function that declares it.
43444346
int Index = -1;
43454347

43464348
AsyncHandlerParamDesc() : AsyncHandlerDesc() {}
4347-
AsyncHandlerParamDesc(const AsyncHandlerDesc &Handler, int Index)
4348-
: AsyncHandlerDesc(Handler), Index(Index) {}
4349+
AsyncHandlerParamDesc(const AsyncHandlerDesc &Handler, const FuncDecl *Func,
4350+
int Index)
4351+
: AsyncHandlerDesc(Handler), Func(Func), Index(Index) {}
43494352

4350-
static AsyncHandlerParamDesc find(const FuncDecl *FD, bool ignoreName) {
4353+
static AsyncHandlerParamDesc find(const FuncDecl *FD,
4354+
bool RequireAttributeOrName) {
43514355
if (!FD || FD->hasAsync() || FD->hasThrows())
43524356
return AsyncHandlerParamDesc();
43534357

4358+
bool RequireName = RequireAttributeOrName;
4359+
if (RequireAttributeOrName &&
4360+
FD->getAttrs().hasAttribute<CompletionHandlerAsyncAttr>())
4361+
RequireName = false;
4362+
43544363
// Require at least one parameter and void return type
43554364
auto *Params = FD->getParameters();
43564365
if (Params->size() == 0 || !FD->getResultInterfaceType()->isVoid())
@@ -4364,10 +4373,30 @@ struct AsyncHandlerParamDesc : public AsyncHandlerDesc {
43644373
if (Param->isAutoClosure())
43654374
return AsyncHandlerParamDesc();
43664375

4367-
return AsyncHandlerParamDesc(AsyncHandlerDesc::get(Param, ignoreName),
4376+
return AsyncHandlerParamDesc(AsyncHandlerDesc::get(Param, RequireName), FD,
43684377
Index);
43694378
}
43704379

4380+
/// Print the name of the function with the completion handler, without
4381+
/// the completion handler parameter, to \p OS. That is, the name of the
4382+
/// async alternative function.
4383+
void printAsyncFunctionName(llvm::raw_ostream &OS) const {
4384+
if (!Func || Index < 0)
4385+
return;
4386+
4387+
DeclName Name = Func->getName();
4388+
OS << Name.getBaseName();
4389+
4390+
OS << tok::l_paren;
4391+
ArrayRef<Identifier> ArgNames = Name.getArgumentNames();
4392+
for (size_t I = 0; I < ArgNames.size(); ++I) {
4393+
if (I != (size_t)Index) {
4394+
OS << ArgNames[I] << tok::colon;
4395+
}
4396+
}
4397+
OS << tok::r_paren;
4398+
}
4399+
43714400
bool operator==(const AsyncHandlerParamDesc &Other) const {
43724401
return Handler == Other.Handler && Type == Other.Type &&
43734402
HasError == Other.HasError && Index == Other.Index;
@@ -5554,8 +5583,12 @@ class AsyncConverter : private SourceEntityWalker {
55545583
return addCustom(CE->getSourceRange(), [&]() { addHandlerCall(CE); });
55555584

55565585
if (auto *CE = dyn_cast<CallExpr>(E)) {
5586+
// If the refactoring is on the call itself, do not require the callee
5587+
// to have the @completionHandlerAsync attribute or a completion-like
5588+
// name.
55575589
auto HandlerDesc = AsyncHandlerParamDesc::find(
5558-
getUnderlyingFunc(CE->getFn()), StartNode.dyn_cast<Expr *>() == CE);
5590+
getUnderlyingFunc(CE->getFn()),
5591+
/*RequireAttributeOrName=*/StartNode.dyn_cast<Expr *>() != CE);
55595592
if (HandlerDesc.isValid())
55605593
return addCustom(CE->getSourceRange(),
55615594
[&]() { addHoistedCallback(CE, HandlerDesc); });
@@ -5836,8 +5869,8 @@ class AsyncConverter : private SourceEntityWalker {
58365869

58375870
// The completion handler that is called as part of the \p CE call.
58385871
// This will be called once the async function returns.
5839-
auto CompletionHandler = AsyncHandlerDesc::get(CallbackDecl,
5840-
/*ignoreName=*/true);
5872+
auto CompletionHandler =
5873+
AsyncHandlerDesc::get(CallbackDecl, /*RequireAttributeOrName=*/false);
58415874
if (CompletionHandler.isValid()) {
58425875
if (auto CalledFunc = getUnderlyingFunc(CE->getFn())) {
58435876
StringRef HandlerName = Lexer::getCharSourceRangeFromSourceRange(
@@ -6430,8 +6463,8 @@ bool RefactoringActionConvertCallToAsyncAlternative::isApplicable(
64306463
if (!CE)
64316464
return false;
64326465

6433-
auto HandlerDesc = AsyncHandlerParamDesc::find(getUnderlyingFunc(CE->getFn()),
6434-
/*ignoreName=*/true);
6466+
auto HandlerDesc = AsyncHandlerParamDesc::find(
6467+
getUnderlyingFunc(CE->getFn()), /*RequireAttributeOrName=*/false);
64356468
return HandlerDesc.isValid();
64366469
}
64376470

@@ -6488,7 +6521,8 @@ bool RefactoringActionConvertToAsync::performChange() {
64886521
assert(FD &&
64896522
"Should not run performChange when refactoring is not applicable");
64906523

6491-
auto HandlerDesc = AsyncHandlerParamDesc::find(FD, /*ignoreName=*/true);
6524+
auto HandlerDesc =
6525+
AsyncHandlerParamDesc::find(FD, /*RequireAttributeOrName=*/false);
64926526
AsyncConverter Converter(TheFile, SM, DiagEngine, FD, HandlerDesc);
64936527
if (!Converter.convert())
64946528
return true;
@@ -6505,7 +6539,8 @@ bool RefactoringActionAddAsyncAlternative::isApplicable(
65056539
if (!FD)
65066540
return false;
65076541

6508-
auto HandlerDesc = AsyncHandlerParamDesc::find(FD, /*ignoreName=*/true);
6542+
auto HandlerDesc =
6543+
AsyncHandlerParamDesc::find(FD, /*RequireAttributeOrName=*/false);
65096544
return HandlerDesc.isValid();
65106545
}
65116546

@@ -6522,21 +6557,37 @@ bool RefactoringActionAddAsyncAlternative::performChange() {
65226557
assert(FD &&
65236558
"Should not run performChange when refactoring is not applicable");
65246559

6525-
auto HandlerDesc = AsyncHandlerParamDesc::find(FD, /*ignoreName=*/true);
6560+
auto HandlerDesc =
6561+
AsyncHandlerParamDesc::find(FD, /*RequireAttributeOrName=*/false);
65266562
assert(HandlerDesc.isValid() &&
65276563
"Should not run performChange when refactoring is not applicable");
65286564

65296565
AsyncConverter Converter(TheFile, SM, DiagEngine, FD, HandlerDesc);
65306566
if (!Converter.convert())
65316567
return true;
65326568

6569+
// Deprecate the synchronous function
65336570
EditConsumer.accept(SM, FD->getAttributeInsertionLoc(false),
65346571
"@available(*, deprecated, message: \"Prefer async "
65356572
"alternative instead\")\n");
6573+
6574+
if (Ctx.LangOpts.EnableExperimentalConcurrency) {
6575+
// Add an attribute to describe its async alternative
6576+
llvm::SmallString<0> HandlerAttribute;
6577+
llvm::raw_svector_ostream OS(HandlerAttribute);
6578+
OS << "@completionHandlerAsync(\"";
6579+
HandlerDesc.printAsyncFunctionName(OS);
6580+
OS << "\", completionHandlerIndex: " << HandlerDesc.Index << ")\n";
6581+
EditConsumer.accept(SM, FD->getAttributeInsertionLoc(false),
6582+
HandlerAttribute);
6583+
}
6584+
65366585
AsyncConverter LegacyBodyCreator(TheFile, SM, DiagEngine, FD, HandlerDesc);
65376586
if (LegacyBodyCreator.createLegacyBody()) {
65386587
LegacyBodyCreator.replace(FD->getBody(), EditConsumer);
65396588
}
6589+
6590+
// Add the async alternative
65406591
Converter.insertAfter(FD, EditConsumer);
65416592

65426593
return false;
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// RUN: %empty-directory(%t)
2+
3+
// RUN: %refactor-check-compiles -add-async-alternative -dump-text -source-filename %s -pos=%(line+1):1 -enable-experimental-concurrency | %FileCheck -check-prefix=SIMPLE %s
4+
func simple(completion: @escaping (String) -> Void) { }
5+
// SIMPLE: async_attribute_added.swift [[# @LINE-1]]:1 -> [[# @LINE-1]]:1
6+
// SIMPLE-NEXT: @available(*, deprecated, message: "Prefer async alternative instead")
7+
// SIMPLE-EMPTY:
8+
// SIMPLE-NEXT: async_attribute_added.swift [[# @LINE-4]]:1 -> [[# @LINE-4]]:1
9+
// SIMPLE-NEXT: @completionHandlerAsync("simple()", completionHandlerIndex: 0)
10+
// SIMPLE-EMPTY:
11+
// SIMPLE-NEXT: async_attribute_added.swift [[# @LINE-7]]:53 -> [[# @LINE-7]]:56
12+
// SIMPLE-NEXT: {
13+
// SIMPLE-NEXT: async {
14+
// SIMPLE-NEXT: let result = await simple()
15+
// SIMPLE-NEXT: completion(result)
16+
// SIMPLE-NEXT: }
17+
// SIMPLE-NEXT: }
18+
// SIMPLE-EMPTY:
19+
// SIMPLE-NEXT: async_attribute_added.swift [[# @LINE-15]]:56 -> [[# @LINE-15]]:56
20+
// SIMPLE-EMPTY:
21+
// SIMPLE-EMPTY:
22+
// SIMPLE-EMPTY:
23+
// SIMPLE-NEXT: async_attribute_added.swift [[# @LINE-19]]:56 -> [[# @LINE-19]]:56
24+
// SIMPLE-NEXT: func simple() async -> String { }
25+
26+
// RUN: %refactor-check-compiles -add-async-alternative -dump-text -source-filename %s -pos=%(line+1):5 -enable-experimental-concurrency | %FileCheck -check-prefix=OTHER-ARGS %s
27+
func otherArgs(first: Int, second: String, completion: @escaping (String) -> Void) { }
28+
// OTHER-ARGS: @completionHandlerAsync("otherArgs(first:second:)", completionHandlerIndex: 2)
29+
30+
// RUN: %refactor-check-compiles -add-async-alternative -dump-text -source-filename %s -pos=%(line+1):5 -enable-experimental-concurrency | %FileCheck -check-prefix=EMPTY-NAMES %s
31+
func emptyNames(first: Int, _ second: String, completion: @escaping (String) -> Void) { }
32+
// EMPTY-NAMES: @completionHandlerAsync("emptyNames(first:_:)", completionHandlerIndex: 2)
33+
34+
// Not a completion handler named parameter, but should still be converted
35+
// during function conversion since it has been attributed
36+
@completionHandlerAsync("otherName()", completionHandlerIndex: 0)
37+
func otherName(notHandlerName: @escaping (String) -> (Void)) {}
38+
func otherName() async -> String {}
39+
40+
// RUN: %refactor-check-compiles -convert-to-async -dump-text -source-filename %s -pos=%(line+1):5 -enable-experimental-concurrency | %FileCheck -check-prefix=OTHER-CONVERTED %s
41+
func otherStillConverted() {
42+
otherName { str in
43+
print(str)
44+
}
45+
}
46+
// OTHER-CONVERTED: func otherStillConverted() async {
47+
// OTHER-CONVERTED-NEXT: let str = await otherName()
48+
// OTHER-CONVERTED-NEXT: print(str)

tools/swift-refactor/swift-refactor.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,10 @@ static llvm::cl::opt<bool>
118118
IsNonProtocolType("is-non-protocol-type",
119119
llvm::cl::desc("The symbol being renamed is a type and not a protocol"));
120120

121+
static llvm::cl::opt<bool> EnableExperimentalConcurrency(
122+
"enable-experimental-concurrency",
123+
llvm::cl::desc("Whether to enable experimental concurrency or not"));
124+
121125
enum class DumpType {
122126
REWRITTEN,
123127
JSON,
@@ -273,7 +277,9 @@ int main(int argc, char *argv[]) {
273277
Invocation.getLangOptions().AttachCommentsToDecls = true;
274278
Invocation.getLangOptions().CollectParsedToken = true;
275279
Invocation.getLangOptions().BuildSyntaxTree = true;
276-
Invocation.getLangOptions().EnableExperimentalConcurrency = true;
280+
281+
if (options::EnableExperimentalConcurrency)
282+
Invocation.getLangOptions().EnableExperimentalConcurrency = true;
277283

278284
for (auto FileName : options::InputFilenames)
279285
Invocation.getFrontendOptions().InputsAndOutputs.addInputFile(FileName);

utils/refactor-check-compiles.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,17 @@ def parse_args():
2121
formatter_class=argparse.RawDescriptionHelpFormatter,
2222
description="""
2323
A drop-in replacement for a 'swift-refactor -dump-text' call that
24-
1. Checkes that the file still compiles after the refactoring by doing
24+
1. Checks that the file still compiles after the refactoring by doing
2525
'swift-refactor -dump-rewritten' and feeding the result to
2626
'swift-frontend -typecheck'
2727
2. Outputting the result of the 'swift-refactor -dump-text' call
2828
29-
All arguments other than the following will be forwarded to :
29+
All arguments other than the following will be forwarded to
3030
'swift-refactor':
3131
- swift-frontend
3232
- swift-refactor
3333
- temp-dir
34+
- enable-experimental-concurrency (sent to both)
3435
""")
3536

3637
parser.add_argument(
@@ -61,30 +62,44 @@ def parse_args():
6162
'-pos',
6263
help='The position to invoke the refactoring at'
6364
)
65+
parser.add_argument(
66+
'-enable-experimental-concurrency',
67+
action='store_true',
68+
help='''
69+
Whether to enable experimental concurrency in both swift-refactor and
70+
swift-frontend
71+
'''
72+
)
6473

6574
return parser.parse_known_args()
6675

6776

6877
def main():
69-
(args, unknown_args) = parse_args()
78+
(args, extra_refactor_args) = parse_args()
7079
temp_file_name = os.path.split(args.source_filename)[-1] + '.' + \
7180
args.pos.replace(':', '.')
7281
temp_file_path = os.path.join(args.temp_dir, temp_file_name)
82+
83+
extra_frontend_args = []
84+
if args.enable_experimental_concurrency:
85+
extra_refactor_args.append('-enable-experimental-concurrency')
86+
extra_frontend_args.append('-enable-experimental-concurrency')
87+
7388
# FIXME: `refactor-check-compiles` should generate both `-dump-text` and
7489
# `dump-rewritten` from a single `swift-refactor` invocation (SR-14587).
7590
dump_text_output = run_cmd([
7691
args.swift_refactor,
7792
'-dump-text',
7893
'-source-filename', args.source_filename,
7994
'-pos', args.pos
80-
] + unknown_args, desc='producing edit').decode("utf-8")
95+
] + extra_refactor_args, desc='producing edit').decode("utf-8")
8196

8297
dump_rewritten_output = run_cmd([
8398
args.swift_refactor,
8499
'-dump-rewritten',
85100
'-source-filename', args.source_filename,
86101
'-pos', args.pos
87-
] + unknown_args, desc='producing rewritten file')
102+
] + extra_refactor_args, desc='producing rewritten file')
88103
with open(temp_file_path, 'wb') as f:
89104
f.write(dump_rewritten_output)
90105

@@ -93,7 +108,7 @@ def main():
93108
'-typecheck',
94109
temp_file_path,
95110
'-disable-availability-checking'
96-
], desc='checking that rewritten file compiles')
111+
] + extra_frontend_args, desc='checking that rewritten file compiles')
97112
sys.stdout.write(dump_text_output)
98113

99114

0 commit comments

Comments
 (0)