@@ -4204,6 +4204,13 @@ struct AsyncHandlerDesc {
4204
4204
return params ();
4205
4205
}
4206
4206
4207
+ // / If the completion handler has an Error parameter, return it.
4208
+ Optional<AnyFunctionType::Param> getErrorParam () const {
4209
+ if (HasError && Type == HandlerType::PARAMS)
4210
+ return params ().back ();
4211
+ return None;
4212
+ }
4213
+
4207
4214
// / Get the type of the error that will be thrown by the \c async method or \c
4208
4215
// / None if the completion handler doesn't accept an error parameter.
4209
4216
// / This may be more specialized than the generic 'Error' type if the
@@ -5405,6 +5412,41 @@ class AsyncConverter : private SourceEntityWalker {
5405
5412
return true ;
5406
5413
}
5407
5414
5415
+ // / Creates an async alternative function that forwards onto the completion
5416
+ // / handler function through
5417
+ // / withCheckedContinuation/withCheckedThrowingContinuation.
5418
+ bool createAsyncWrapper () {
5419
+ assert (Buffer.empty () && " AsyncConverter can only be used once" );
5420
+ auto *FD = cast<FuncDecl>(StartNode.get <Decl *>());
5421
+
5422
+ // First add the new async function declaration.
5423
+ addFuncDecl (FD);
5424
+ OS << tok::l_brace << " \n " ;
5425
+
5426
+ // Then add the body.
5427
+ OS << tok::kw_return << " " ;
5428
+ if (TopHandler.HasError )
5429
+ OS << tok::kw_try << " " ;
5430
+
5431
+ OS << " await " ;
5432
+
5433
+ // withChecked[Throwing]Continuation { cont in
5434
+ if (TopHandler.HasError ) {
5435
+ OS << " withCheckedThrowingContinuation" ;
5436
+ } else {
5437
+ OS << " withCheckedContinuation" ;
5438
+ }
5439
+ OS << " " << tok::l_brace << " cont " << tok::kw_in << " \n " ;
5440
+
5441
+ // fnWithHandler(args...) { ... }
5442
+ auto ClosureStr = getAsyncWrapperCompletionClosure (" cont" , TopHandler);
5443
+ addForwardingCallTo (FD, TopHandler, /* HandlerReplacement*/ ClosureStr);
5444
+
5445
+ OS << tok::r_brace << " \n " ; // end continuation closure
5446
+ OS << tok::r_brace << " \n " ; // end function body
5447
+ return true ;
5448
+ }
5449
+
5408
5450
void replace (ASTNode Node, SourceEditConsumer &EditConsumer,
5409
5451
SourceLoc StartOverride = SourceLoc()) {
5410
5452
SourceRange Range = Node.getSourceRange ();
@@ -5454,6 +5496,116 @@ class AsyncConverter : private SourceEntityWalker {
5454
5496
OS << tok::r_paren;
5455
5497
}
5456
5498
5499
+ // / Retrieve the completion handler closure argument for an async wrapper
5500
+ // / function.
5501
+ std::string
5502
+ getAsyncWrapperCompletionClosure (StringRef ContName,
5503
+ const AsyncHandlerParamDesc &HandlerDesc) {
5504
+ std::string OutputStr;
5505
+ llvm::raw_string_ostream OS (OutputStr);
5506
+
5507
+ OS << " " << tok::l_brace; // start closure
5508
+
5509
+ // Prepare parameter names for the closure.
5510
+ auto SuccessParams = HandlerDesc.getSuccessParams ();
5511
+ SmallVector<SmallString<4 >, 2 > SuccessParamNames;
5512
+ for (auto idx : indices (SuccessParams)) {
5513
+ SuccessParamNames.emplace_back (" res" );
5514
+
5515
+ // If we have multiple success params, number them e.g res1, res2...
5516
+ if (SuccessParams.size () > 1 )
5517
+ SuccessParamNames.back ().append (std::to_string (idx + 1 ));
5518
+ }
5519
+ Optional<SmallString<4 >> ErrName;
5520
+ if (HandlerDesc.getErrorParam ())
5521
+ ErrName.emplace (" err" );
5522
+
5523
+ auto HasAnyParams = !SuccessParamNames.empty () || ErrName;
5524
+ if (HasAnyParams)
5525
+ OS << " " ;
5526
+
5527
+ // res1, res2
5528
+ llvm::interleave (
5529
+ SuccessParamNames, [&](auto Name) { OS << Name; },
5530
+ [&]() { OS << tok::comma << " " ; });
5531
+
5532
+ // , err
5533
+ if (ErrName) {
5534
+ if (!SuccessParamNames.empty ())
5535
+ OS << tok::comma << " " ;
5536
+
5537
+ OS << *ErrName;
5538
+ }
5539
+ if (HasAnyParams)
5540
+ OS << " " << tok::kw_in;
5541
+
5542
+ OS << " \n " ;
5543
+
5544
+ // The closure body.
5545
+ switch (HandlerDesc.Type ) {
5546
+ case HandlerType::PARAMS: {
5547
+ // For a (Success?, Error?) -> Void handler, we do an if let on the error.
5548
+ if (ErrName) {
5549
+ // if let err = err {
5550
+ OS << tok::kw_if << " " << tok::kw_let << " " ;
5551
+ OS << *ErrName << " " << tok::equal << " " << *ErrName << " " ;
5552
+ OS << tok::l_brace << " \n " ;
5553
+
5554
+ // cont.resume(throwing: err)
5555
+ OS << ContName << tok::period << " resume" << tok::l_paren;
5556
+ OS << " throwing" << tok::colon << " " << *ErrName;
5557
+ OS << tok::r_paren << " \n " ;
5558
+
5559
+ // return }
5560
+ OS << tok::kw_return << " \n " ;
5561
+ OS << tok::r_brace << " \n " ;
5562
+ }
5563
+
5564
+ // If we have any success params that we need to unwrap, insert a guard.
5565
+ for (auto Idx : indices (SuccessParamNames)) {
5566
+ auto &Name = SuccessParamNames[Idx];
5567
+ auto ParamTy = SuccessParams[Idx].getParameterType ();
5568
+ if (!HandlerDesc.shouldUnwrap (ParamTy))
5569
+ continue ;
5570
+
5571
+ // guard let res = res else {
5572
+ OS << tok::kw_guard << " " << tok::kw_let << " " ;
5573
+ OS << Name << " " << tok::equal << " " << Name << " " << tok::kw_else;
5574
+ OS << " " << tok::l_brace << " \n " ;
5575
+
5576
+ // fatalError(...)
5577
+ OS << " fatalError" << tok::l_paren;
5578
+ OS << " \" Expected non-nil success param '" << Name;
5579
+ OS << " ' for nil error\" " ;
5580
+ OS << tok::r_paren << " \n " ;
5581
+
5582
+ // End guard.
5583
+ OS << tok::r_brace << " \n " ;
5584
+ }
5585
+
5586
+ // cont.resume(returning: (res1, res2, ...))
5587
+ OS << ContName << tok::period << " resume" << tok::l_paren;
5588
+ OS << " returning" << tok::colon << " " ;
5589
+ addTupleOf (llvm::makeArrayRef (SuccessParamNames), OS,
5590
+ [&](auto Ref) { OS << Ref; });
5591
+ OS << tok::r_paren << " \n " ;
5592
+ break ;
5593
+ }
5594
+ case HandlerType::RESULT: {
5595
+ // cont.resume(with: res)
5596
+ assert (SuccessParamNames.size () == 1 );
5597
+ OS << ContName << tok::period << " resume" << tok::l_paren;
5598
+ OS << " with" << tok::colon << " " << SuccessParamNames[0 ];
5599
+ OS << tok::r_paren << " \n " ;
5600
+ break ;
5601
+ }
5602
+ case HandlerType::INVALID:
5603
+ llvm_unreachable (" Should not have an invalid handler here" );
5604
+ }
5605
+
5606
+ OS << tok::r_brace << " \n " ; // end closure
5607
+ return OutputStr;
5608
+ }
5457
5609
5458
5610
// / Retrieves the location for the start of a comment attached to the token
5459
5611
// / at the provided location, or the location itself if there is no comment.
@@ -6480,6 +6632,24 @@ class AsyncConverter : private SourceEntityWalker {
6480
6632
}
6481
6633
};
6482
6634
6635
+ // / Adds an attribute to describe a completion handler function's async
6636
+ // / alternative if necessary.
6637
+ void addCompletionHandlerAsyncAttrIfNeccessary (
6638
+ ASTContext &Ctx, const FuncDecl *FD,
6639
+ const AsyncHandlerParamDesc &HandlerDesc,
6640
+ SourceEditConsumer &EditConsumer) {
6641
+ if (!Ctx.LangOpts .EnableExperimentalConcurrency )
6642
+ return ;
6643
+
6644
+ llvm::SmallString<0 > HandlerAttribute;
6645
+ llvm::raw_svector_ostream OS (HandlerAttribute);
6646
+ OS << " @completionHandlerAsync(\" " ;
6647
+ HandlerDesc.printAsyncFunctionName (OS);
6648
+ OS << " \" , completionHandlerIndex: " << HandlerDesc.Index << " )\n " ;
6649
+ EditConsumer.accept (Ctx.SourceMgr , FD->getAttributeInsertionLoc (false ),
6650
+ HandlerAttribute);
6651
+ }
6652
+
6483
6653
} // namespace asyncrefactorings
6484
6654
6485
6655
bool RefactoringActionConvertCallToAsyncAlternative::isApplicable (
@@ -6601,16 +6771,7 @@ bool RefactoringActionAddAsyncAlternative::performChange() {
6601
6771
" @available(*, deprecated, message: \" Prefer async "
6602
6772
" alternative instead\" )\n " );
6603
6773
6604
- if (Ctx.LangOpts .EnableExperimentalConcurrency ) {
6605
- // Add an attribute to describe its async alternative
6606
- llvm::SmallString<0 > HandlerAttribute;
6607
- llvm::raw_svector_ostream OS (HandlerAttribute);
6608
- OS << " @completionHandlerAsync(\" " ;
6609
- HandlerDesc.printAsyncFunctionName (OS);
6610
- OS << " \" , completionHandlerIndex: " << HandlerDesc.Index << " )\n " ;
6611
- EditConsumer.accept (SM, FD->getAttributeInsertionLoc (false ),
6612
- HandlerAttribute);
6613
- }
6774
+ addCompletionHandlerAsyncAttrIfNeccessary (Ctx, FD, HandlerDesc, EditConsumer);
6614
6775
6615
6776
AsyncConverter LegacyBodyCreator (TheFile, SM, DiagEngine, FD, HandlerDesc);
6616
6777
if (LegacyBodyCreator.createLegacyBody ()) {
@@ -6622,6 +6783,43 @@ bool RefactoringActionAddAsyncAlternative::performChange() {
6622
6783
6623
6784
return false ;
6624
6785
}
6786
+
6787
+ bool RefactoringActionAddAsyncWrapper::isApplicable (
6788
+ const ResolvedCursorInfo &CursorInfo, DiagnosticEngine &Diag) {
6789
+ using namespace asyncrefactorings ;
6790
+
6791
+ auto *FD = findFunction (CursorInfo);
6792
+ if (!FD)
6793
+ return false ;
6794
+
6795
+ auto HandlerDesc =
6796
+ AsyncHandlerParamDesc::find (FD, /* RequireAttributeOrName=*/ false );
6797
+ return HandlerDesc.isValid ();
6798
+ }
6799
+
6800
+ bool RefactoringActionAddAsyncWrapper::performChange () {
6801
+ using namespace asyncrefactorings ;
6802
+
6803
+ auto *FD = findFunction (CursorInfo);
6804
+ assert (FD &&
6805
+ " Should not run performChange when refactoring is not applicable" );
6806
+
6807
+ auto HandlerDesc =
6808
+ AsyncHandlerParamDesc::find (FD, /* RequireAttributeOrName=*/ false );
6809
+ assert (HandlerDesc.isValid () &&
6810
+ " Should not run performChange when refactoring is not applicable" );
6811
+
6812
+ AsyncConverter Converter (TheFile, SM, DiagEngine, FD, HandlerDesc);
6813
+ if (!Converter.createAsyncWrapper ())
6814
+ return true ;
6815
+
6816
+ addCompletionHandlerAsyncAttrIfNeccessary (Ctx, FD, HandlerDesc, EditConsumer);
6817
+
6818
+ // Add the async wrapper.
6819
+ Converter.insertAfter (FD, EditConsumer);
6820
+ return false ;
6821
+ }
6822
+
6625
6823
} // end of anonymous namespace
6626
6824
6627
6825
StringRef swift::ide::
0 commit comments