@@ -4158,6 +4158,36 @@ struct AsyncHandlerDesc {
4158
4158
return params ();
4159
4159
}
4160
4160
4161
+ // / Get the type of the error that will be thrown by the \c async method or \c
4162
+ // / None if the completion handler doesn't accept an error parameter.
4163
+ // / This may be more specialized than the generic 'Error' type if the
4164
+ // / completion handler of the converted function takes a more specialized
4165
+ // / error type.
4166
+ Optional<swift::Type> getErrorType () const {
4167
+ if (HasError) {
4168
+ switch (Type) {
4169
+ case HandlerType::INVALID:
4170
+ return None;
4171
+ case HandlerType::PARAMS:
4172
+ // The last parameter of the completion handler is the error param
4173
+ return params ().back ().getPlainType ()->lookThroughSingleOptionalType ();
4174
+ case HandlerType::RESULT:
4175
+ assert (
4176
+ params ().size () == 1 &&
4177
+ " Result handler should have the Result type as the only parameter" );
4178
+ auto ResultType =
4179
+ params ().back ().getPlainType ()->getAs <BoundGenericType>();
4180
+ auto GenericArgs = ResultType->getGenericArgs ();
4181
+ assert (GenericArgs.size () == 2 && " Result should have two params" );
4182
+ // The second (last) generic parameter of the Result type is the error
4183
+ // type.
4184
+ return GenericArgs.back ();
4185
+ }
4186
+ } else {
4187
+ return None;
4188
+ }
4189
+ }
4190
+
4161
4191
// / The `CallExpr` if the given node is a call to the `Handler`
4162
4192
CallExpr *getAsHandlerCall (ASTNode Node) const {
4163
4193
if (!isValid ())
@@ -5370,6 +5400,262 @@ class AsyncConverter : private SourceEntityWalker {
5370
5400
}
5371
5401
}
5372
5402
};
5403
+
5404
+ // / When adding an async alternative method for the function declaration \c FD,
5405
+ // / this class tries to create a function body for the legacy function (the one
5406
+ // / with a completion handler), which calls the newly converted async function.
5407
+ // / There are certain situations in which we fail to create such a body, e.g.
5408
+ // / if the completion handler has the signature `(String, Error?) -> Void` in
5409
+ // / which case we can't synthesize the result of type \c String in the error
5410
+ // / case.
5411
+ class LegacyAlternativeBodyCreator {
5412
+ // / The old function declaration for which an async alternative has been added
5413
+ // / and whose body shall be rewritten to call the newly added async
5414
+ // / alternative.
5415
+ FuncDecl *FD;
5416
+
5417
+ // / The description of the completion handler in the old function declaration.
5418
+ AsyncHandlerDesc HandlerDesc;
5419
+
5420
+ std::string Buffer;
5421
+ llvm::raw_string_ostream OS;
5422
+
5423
+ // / Adds the call to the refactored 'async' method without the 'await'
5424
+ // / keyword to the output stream.
5425
+ void addCallToAsyncMethod () {
5426
+ OS << FD->getBaseName () << " (" ;
5427
+ bool FirstParam = true ;
5428
+ for (auto Param : *FD->getParameters ()) {
5429
+ if (Param == HandlerDesc.Handler ) {
5430
+ // / We don't need to pass the completion handler to the async method.
5431
+ continue ;
5432
+ }
5433
+ if (!FirstParam) {
5434
+ OS << " , " ;
5435
+ } else {
5436
+ FirstParam = false ;
5437
+ }
5438
+ if (!Param->getArgumentName ().empty ()) {
5439
+ OS << Param->getArgumentName () << " : " ;
5440
+ }
5441
+ OS << Param->getParameterName ();
5442
+ }
5443
+ OS << " )" ;
5444
+ }
5445
+
5446
+ // / If the returned error type is more specialized than \c Error, adds an
5447
+ // / 'as! CustomError' cast to the more specialized error type to the output
5448
+ // / stream.
5449
+ void addCastToCustomErrorTypeIfNecessary () {
5450
+ auto ErrorType = *HandlerDesc.getErrorType ();
5451
+ if (ErrorType->getCanonicalType () !=
5452
+ FD->getASTContext ().getExceptionType ()) {
5453
+ OS << " as! " ;
5454
+ ErrorType->lookThroughSingleOptionalType ()->print (OS);
5455
+ }
5456
+ }
5457
+
5458
+ // / Adds the \c Index -th parameter to the completion handler.
5459
+ // / If \p HasResult is \c true, it is assumed that a variable named 'result'
5460
+ // / contains the result returned from the async alternative. If the callback
5461
+ // / also takes an error parameter, \c nil passed to the completion handler for
5462
+ // / the error.
5463
+ // / If \p HasResult is \c false, it is a assumed that a variable named 'error'
5464
+ // / contains the error thrown from the async method and 'nil' will be passed
5465
+ // / to the completion handler for all result parameters.
5466
+ void addCompletionHandlerArgument (size_t Index, bool HasResult) {
5467
+ if (HandlerDesc.HasError && Index == HandlerDesc.params ().size () - 1 ) {
5468
+ // The error parameter is the last argument of the completion handler.
5469
+ if (!HasResult) {
5470
+ OS << " error" ;
5471
+ addCastToCustomErrorTypeIfNecessary ();
5472
+ } else {
5473
+ OS << " nil" ;
5474
+ }
5475
+ } else {
5476
+ if (!HasResult) {
5477
+ OS << " nil" ;
5478
+ } else if (HandlerDesc
5479
+ .getSuccessParamAsyncReturnType (
5480
+ HandlerDesc.params ()[Index].getPlainType ())
5481
+ ->isVoid ()) {
5482
+ // Void return types are not returned by the async function, synthesize
5483
+ // a Void instance.
5484
+ OS << " ()" ;
5485
+ } else if (HandlerDesc.getSuccessParams ().size () > 1 ) {
5486
+ // If the async method returns a tuple, we need to pass its elements to
5487
+ // the completion handler separately. For example:
5488
+ //
5489
+ // func foo() async -> (String, Int) {}
5490
+ //
5491
+ // causes the following legacy body to be created:
5492
+ //
5493
+ // func foo(completion: (String, Int) -> Void) {
5494
+ // async {
5495
+ // let result = await foo()
5496
+ // completion(result.0, result.1)
5497
+ // }
5498
+ // }
5499
+ OS << " result." << Index;
5500
+ } else {
5501
+ OS << " result" ;
5502
+ }
5503
+ }
5504
+ }
5505
+
5506
+ // / Adds the call to the completion handler. See \c
5507
+ // / getCompletionHandlerArgument for how the arguments are synthesized if the
5508
+ // / completion handler takes arguments, not a \c Result type.
5509
+ void addCallToCompletionHandler (bool HasResult) {
5510
+ OS << HandlerDesc.Handler ->getParameterName () << " (" ;
5511
+
5512
+ // Construct arguments to pass to the completion handler
5513
+ switch (HandlerDesc.Type ) {
5514
+ case HandlerType::INVALID:
5515
+ llvm_unreachable (" Cannot be rewritten" );
5516
+ break ;
5517
+ case HandlerType::PARAMS: {
5518
+ for (size_t I = 0 ; I < HandlerDesc.params ().size (); ++I) {
5519
+ if (I > 0 ) {
5520
+ OS << " , " ;
5521
+ }
5522
+ addCompletionHandlerArgument (I, HasResult);
5523
+ }
5524
+ break ;
5525
+ }
5526
+ case HandlerType::RESULT: {
5527
+ if (HasResult) {
5528
+ OS << " .success(result)" ;
5529
+ } else {
5530
+ OS << " .failure(error" ;
5531
+ addCastToCustomErrorTypeIfNecessary ();
5532
+ OS << " )" ;
5533
+ }
5534
+ break ;
5535
+ }
5536
+ }
5537
+ OS << " )" ; // Close the call to the completion handler
5538
+ }
5539
+
5540
+ // / Adds the result type of the converted async function.
5541
+ void addAsyncFuncReturnType () {
5542
+ SmallVector<Type, 2 > Scratch;
5543
+ auto ReturnTypes = HandlerDesc.getAsyncReturnTypes (Scratch);
5544
+ if (ReturnTypes.size () > 1 ) {
5545
+ OS << " (" ;
5546
+ }
5547
+
5548
+ llvm::interleave (
5549
+ ReturnTypes, [&](Type Ty) { Ty->print (OS); }, [&]() { OS << " , " ; });
5550
+
5551
+ if (ReturnTypes.size () > 1 ) {
5552
+ OS << " )" ;
5553
+ }
5554
+ }
5555
+
5556
+ // / If the async alternative function is generic, adds the type annotation
5557
+ // / to the 'return' variable in the legacy function so that the generic
5558
+ // / parameters of the legacy function are passed to the generic function.
5559
+ // / For example for
5560
+ // / \code
5561
+ // / func foo<GenericParam>() async -> GenericParam {}
5562
+ // / \endcode
5563
+ // / we generate
5564
+ // / \code
5565
+ // / func foo<GenericParam>(completion: (T) -> Void) {
5566
+ // / async {
5567
+ // / let result: GenericParam = await foo()
5568
+ // / <------------>
5569
+ // / completion(result)
5570
+ // / }
5571
+ // / }
5572
+ // / \endcode
5573
+ // / This function adds the range marked by \c <----->
5574
+ void addResultTypeAnnotationIfNecessary () {
5575
+ if (FD->isGeneric ()) {
5576
+ OS << " : " ;
5577
+ addAsyncFuncReturnType ();
5578
+ }
5579
+ }
5580
+
5581
+ public:
5582
+ LegacyAlternativeBodyCreator (FuncDecl *FD, AsyncHandlerDesc HandlerDesc)
5583
+ : FD(FD), HandlerDesc(HandlerDesc), OS(Buffer) {}
5584
+
5585
+ bool canRewriteLegacyBody () {
5586
+ if (FD == nullptr || FD->getBody () == nullptr ) {
5587
+ return false ;
5588
+ }
5589
+ if (FD->hasThrows ()) {
5590
+ assert (!HandlerDesc.isValid () && " We shouldn't have found a handler desc "
5591
+ " if the original function throws" );
5592
+ return false ;
5593
+ }
5594
+ switch (HandlerDesc.Type ) {
5595
+ case HandlerType::INVALID:
5596
+ return false ;
5597
+ case HandlerType::PARAMS: {
5598
+ if (HandlerDesc.HasError ) {
5599
+ // The non-error parameters must be optional so that we can set them to
5600
+ // nil in the error case.
5601
+ // The error parameter must be optional so we can set it to nil in the
5602
+ // success case.
5603
+ // Otherwise we can't synthesize the values to return for these
5604
+ // parameters.
5605
+ return llvm::all_of (HandlerDesc.params (),
5606
+ [](AnyFunctionType::Param Param) -> bool {
5607
+ return Param.getPlainType ()->isOptional ();
5608
+ });
5609
+ } else {
5610
+ return true ;
5611
+ }
5612
+ }
5613
+ case HandlerType::RESULT:
5614
+ return true ;
5615
+ }
5616
+ }
5617
+
5618
+ std::string create () {
5619
+ assert (Buffer.empty () &&
5620
+ " LegacyAlternativeBodyCreator can only be used once" );
5621
+ assert (canRewriteLegacyBody () &&
5622
+ " Cannot create a legacy body if the body can't be rewritten" );
5623
+ OS << " {\n " ; // start function body
5624
+ OS << " async {\n " ;
5625
+ if (HandlerDesc.HasError ) {
5626
+ OS << " do {\n " ;
5627
+ if (!HandlerDesc.willAsyncReturnVoid ()) {
5628
+ OS << " let result" ;
5629
+ addResultTypeAnnotationIfNecessary ();
5630
+ OS << " = " ;
5631
+ }
5632
+ OS << " try await " ;
5633
+ addCallToAsyncMethod ();
5634
+ OS << " \n " ;
5635
+ addCallToCompletionHandler (/* HasResult=*/ true );
5636
+ OS << " \n "
5637
+ << " } catch {\n " ;
5638
+ addCallToCompletionHandler (/* HasResult=*/ false );
5639
+ OS << " \n "
5640
+ << " }\n " ; // end catch
5641
+ } else {
5642
+ if (!HandlerDesc.willAsyncReturnVoid ()) {
5643
+ OS << " let result" ;
5644
+ addResultTypeAnnotationIfNecessary ();
5645
+ OS << " = " ;
5646
+ }
5647
+ OS << " await " ;
5648
+ addCallToAsyncMethod ();
5649
+ OS << " \n " ;
5650
+ addCallToCompletionHandler (/* HasResult=*/ true );
5651
+ OS << " \n " ;
5652
+ }
5653
+ OS << " }\n " ; // end 'async'
5654
+ OS << " }\n " ; // end function body
5655
+ return Buffer;
5656
+ }
5657
+ };
5658
+
5373
5659
} // namespace asyncrefactorings
5374
5660
5375
5661
bool RefactoringActionConvertCallToAsyncAlternative::isApplicable (
@@ -5476,6 +5762,13 @@ bool RefactoringActionAddAsyncAlternative::performChange() {
5476
5762
EditConsumer.accept (SM, FD->getAttributeInsertionLoc (false ),
5477
5763
" @available(*, deprecated, message: \" Prefer async "
5478
5764
" alternative instead\" )\n " );
5765
+ LegacyAlternativeBodyCreator LegacyBody (FD, HandlerDesc);
5766
+ if (LegacyBody.canRewriteLegacyBody ()) {
5767
+ EditConsumer.accept (SM,
5768
+ Lexer::getCharSourceRangeFromSourceRange (
5769
+ SM, FD->getBody ()->getSourceRange ()),
5770
+ LegacyBody.create ());
5771
+ }
5479
5772
Converter.insertAfter (FD, EditConsumer);
5480
5773
5481
5774
return false ;
0 commit comments