@@ -4149,6 +4149,36 @@ struct AsyncHandlerDesc {
4149
4149
return params ();
4150
4150
}
4151
4151
4152
+ // / Get the type of the error that will be thrown by the \c async method or \c
4153
+ // / None if the completion handler doesn't accept an error parameter.
4154
+ // / This may be more specialized than the generic 'Error' type if the
4155
+ // / completion handler of the converted function takes a more specialized
4156
+ // / error type.
4157
+ Optional<swift::Type> getErrorType () const {
4158
+ if (HasError) {
4159
+ switch (Type) {
4160
+ case HandlerType::INVALID:
4161
+ return None;
4162
+ case HandlerType::PARAMS:
4163
+ // The last parameter of the completion handler is the error param
4164
+ return params ().back ().getPlainType ()->lookThroughSingleOptionalType ();
4165
+ case HandlerType::RESULT:
4166
+ assert (
4167
+ params ().size () == 1 &&
4168
+ " Result handler should have the Result type as the only parameter" );
4169
+ auto ResultType =
4170
+ params ().back ().getPlainType ()->getAs <BoundGenericType>();
4171
+ auto GenericArgs = ResultType->getGenericArgs ();
4172
+ assert (GenericArgs.size () == 2 && " Result should have two params" );
4173
+ // The second (last) generic parameter of the Result type is the error
4174
+ // type.
4175
+ return GenericArgs.back ();
4176
+ }
4177
+ } else {
4178
+ return None;
4179
+ }
4180
+ }
4181
+
4152
4182
// / The `CallExpr` if the given node is a call to the `Handler`
4153
4183
CallExpr *getAsHandlerCall (ASTNode Node) const {
4154
4184
if (!isValid ())
@@ -5319,6 +5349,262 @@ class AsyncConverter : private SourceEntityWalker {
5319
5349
}
5320
5350
}
5321
5351
};
5352
+
5353
+ // / When adding an async alternative method for the function declaration \c FD,
5354
+ // / this class tries to create a function body for the legacy function (the one
5355
+ // / with a completion handler), which calls the newly converted async function.
5356
+ // / There are certain situations in which we fail to create such a body, e.g.
5357
+ // / if the completion handler has the signature `(String, Error?) -> Void` in
5358
+ // / which case we can't synthesize the result of type \c String in the error
5359
+ // / case.
5360
+ class LegacyAlternativeBodyCreator {
5361
+ // / The old function declaration for which an async alternative has been added
5362
+ // / and whose body shall be rewritten to call the newly added async
5363
+ // / alternative.
5364
+ FuncDecl *FD;
5365
+
5366
+ // / The description of the completion handler in the old function declaration.
5367
+ AsyncHandlerDesc HandlerDesc;
5368
+
5369
+ std::string Buffer;
5370
+ llvm::raw_string_ostream OS;
5371
+
5372
+ // / Adds the call to the refactored 'async' method without the 'await'
5373
+ // / keyword to the output stream.
5374
+ void addCallToAsyncMethod () {
5375
+ OS << FD->getBaseName () << " (" ;
5376
+ bool FirstParam = true ;
5377
+ for (auto Param : *FD->getParameters ()) {
5378
+ if (Param == HandlerDesc.Handler ) {
5379
+ // / We don't need to pass the completion handler to the async method.
5380
+ continue ;
5381
+ }
5382
+ if (!FirstParam) {
5383
+ OS << " , " ;
5384
+ } else {
5385
+ FirstParam = false ;
5386
+ }
5387
+ if (!Param->getArgumentName ().empty ()) {
5388
+ OS << Param->getArgumentName () << " : " ;
5389
+ }
5390
+ OS << Param->getParameterName ();
5391
+ }
5392
+ OS << " )" ;
5393
+ }
5394
+
5395
+ // / If the returned error type is more specialized than \c Error, adds an
5396
+ // / 'as! CustomError' cast to the more specialized error type to the output
5397
+ // / stream.
5398
+ void addCastToCustomErrorTypeIfNecessary () {
5399
+ auto ErrorType = *HandlerDesc.getErrorType ();
5400
+ if (ErrorType->getCanonicalType () !=
5401
+ FD->getASTContext ().getExceptionType ()) {
5402
+ OS << " as! " ;
5403
+ ErrorType->lookThroughSingleOptionalType ()->print (OS);
5404
+ }
5405
+ }
5406
+
5407
+ // / Adds the \c Index -th parameter to the completion handler.
5408
+ // / If \p HasResult is \c true, it is assumed that a variable named 'result'
5409
+ // / contains the result returned from the async alternative. If the callback
5410
+ // / also takes an error parameter, \c nil passed to the completion handler for
5411
+ // / the error.
5412
+ // / If \p HasResult is \c false, it is a assumed that a variable named 'error'
5413
+ // / contains the error thrown from the async method and 'nil' will be passed
5414
+ // / to the completion handler for all result parameters.
5415
+ void addCompletionHandlerArgument (size_t Index, bool HasResult) {
5416
+ if (HandlerDesc.HasError && Index == HandlerDesc.params ().size () - 1 ) {
5417
+ // The error parameter is the last argument of the completion handler.
5418
+ if (!HasResult) {
5419
+ OS << " error" ;
5420
+ addCastToCustomErrorTypeIfNecessary ();
5421
+ } else {
5422
+ OS << " nil" ;
5423
+ }
5424
+ } else {
5425
+ if (!HasResult) {
5426
+ OS << " nil" ;
5427
+ } else if (HandlerDesc
5428
+ .getSuccessParamAsyncReturnType (
5429
+ HandlerDesc.params ()[Index].getPlainType ())
5430
+ ->isVoid ()) {
5431
+ // Void return types are not returned by the async function, synthesize
5432
+ // a Void instance.
5433
+ OS << " ()" ;
5434
+ } else if (HandlerDesc.getSuccessParams ().size () > 1 ) {
5435
+ // If the async method returns a tuple, we need to pass its elements to
5436
+ // the completion handler separately. For example:
5437
+ //
5438
+ // func foo() async -> (String, Int) {}
5439
+ //
5440
+ // causes the following legacy body to be created:
5441
+ //
5442
+ // func foo(completion: (String, Int) -> Void) {
5443
+ // async {
5444
+ // let result = await foo()
5445
+ // completion(result.0, result.1)
5446
+ // }
5447
+ // }
5448
+ OS << " result." << Index;
5449
+ } else {
5450
+ OS << " result" ;
5451
+ }
5452
+ }
5453
+ }
5454
+
5455
+ // / Adds the call to the completion handler. See \c
5456
+ // / getCompletionHandlerArgument for how the arguments are synthesized if the
5457
+ // / completion handler takes arguments, not a \c Result type.
5458
+ void addCallToCompletionHandler (bool HasResult) {
5459
+ OS << HandlerDesc.Handler ->getParameterName () << " (" ;
5460
+
5461
+ // Construct arguments to pass to the completion handler
5462
+ switch (HandlerDesc.Type ) {
5463
+ case HandlerType::INVALID:
5464
+ llvm_unreachable (" Cannot be rewritten" );
5465
+ break ;
5466
+ case HandlerType::PARAMS: {
5467
+ for (size_t I = 0 ; I < HandlerDesc.params ().size (); ++I) {
5468
+ if (I > 0 ) {
5469
+ OS << " , " ;
5470
+ }
5471
+ addCompletionHandlerArgument (I, HasResult);
5472
+ }
5473
+ break ;
5474
+ }
5475
+ case HandlerType::RESULT: {
5476
+ if (HasResult) {
5477
+ OS << " .success(result)" ;
5478
+ } else {
5479
+ OS << " .failure(error" ;
5480
+ addCastToCustomErrorTypeIfNecessary ();
5481
+ OS << " )" ;
5482
+ }
5483
+ break ;
5484
+ }
5485
+ }
5486
+ OS << " )" ; // Close the call to the completion handler
5487
+ }
5488
+
5489
+ // / Adds the result type of the converted async function.
5490
+ void addAsyncFuncReturnType () {
5491
+ SmallVector<Type, 2 > Scratch;
5492
+ auto ReturnTypes = HandlerDesc.getAsyncReturnTypes (Scratch);
5493
+ if (ReturnTypes.size () > 1 ) {
5494
+ OS << " (" ;
5495
+ }
5496
+
5497
+ llvm::interleave (
5498
+ ReturnTypes, [&](Type Ty) { Ty->print (OS); }, [&]() { OS << " , " ; });
5499
+
5500
+ if (ReturnTypes.size () > 1 ) {
5501
+ OS << " )" ;
5502
+ }
5503
+ }
5504
+
5505
+ // / If the async alternative function is generic, adds the type annotation
5506
+ // / to the 'return' variable in the legacy function so that the generic
5507
+ // / parameters of the legacy function are passed to the generic function.
5508
+ // / For example for
5509
+ // / \code
5510
+ // / func foo<GenericParam>() async -> GenericParam {}
5511
+ // / \endcode
5512
+ // / we generate
5513
+ // / \code
5514
+ // / func foo<GenericParam>(completion: (T) -> Void) {
5515
+ // / async {
5516
+ // / let result: GenericParam = await foo()
5517
+ // / <------------>
5518
+ // / completion(result)
5519
+ // / }
5520
+ // / }
5521
+ // / \endcode
5522
+ // / This function adds the range marked by \c <----->
5523
+ void addResultTypeAnnotationIfNecessary () {
5524
+ if (FD->isGeneric ()) {
5525
+ OS << " : " ;
5526
+ addAsyncFuncReturnType ();
5527
+ }
5528
+ }
5529
+
5530
+ public:
5531
+ LegacyAlternativeBodyCreator (FuncDecl *FD, AsyncHandlerDesc HandlerDesc)
5532
+ : FD(FD), HandlerDesc(HandlerDesc), OS(Buffer) {}
5533
+
5534
+ bool canRewriteLegacyBody () {
5535
+ if (FD == nullptr || FD->getBody () == nullptr ) {
5536
+ return false ;
5537
+ }
5538
+ if (FD->hasThrows ()) {
5539
+ assert (!HandlerDesc.isValid () && " We shouldn't have found a handler desc "
5540
+ " if the original function throws" );
5541
+ return false ;
5542
+ }
5543
+ switch (HandlerDesc.Type ) {
5544
+ case HandlerType::INVALID:
5545
+ return false ;
5546
+ case HandlerType::PARAMS: {
5547
+ if (HandlerDesc.HasError ) {
5548
+ // The non-error parameters must be optional so that we can set them to
5549
+ // nil in the error case.
5550
+ // The error parameter must be optional so we can set it to nil in the
5551
+ // success case.
5552
+ // Otherwise we can't synthesize the values to return for these
5553
+ // parameters.
5554
+ return llvm::all_of (HandlerDesc.params (),
5555
+ [](AnyFunctionType::Param Param) -> bool {
5556
+ return Param.getPlainType ()->isOptional ();
5557
+ });
5558
+ } else {
5559
+ return true ;
5560
+ }
5561
+ }
5562
+ case HandlerType::RESULT:
5563
+ return true ;
5564
+ }
5565
+ }
5566
+
5567
+ std::string create () {
5568
+ assert (Buffer.empty () &&
5569
+ " LegacyAlternativeBodyCreator can only be used once" );
5570
+ assert (canRewriteLegacyBody () &&
5571
+ " Cannot create a legacy body if the body can't be rewritten" );
5572
+ OS << " {\n " ; // start function body
5573
+ OS << " async {\n " ;
5574
+ if (HandlerDesc.HasError ) {
5575
+ OS << " do {\n " ;
5576
+ if (!HandlerDesc.willAsyncReturnVoid ()) {
5577
+ OS << " let result" ;
5578
+ addResultTypeAnnotationIfNecessary ();
5579
+ OS << " = " ;
5580
+ }
5581
+ OS << " try await " ;
5582
+ addCallToAsyncMethod ();
5583
+ OS << " \n " ;
5584
+ addCallToCompletionHandler (/* HasResult=*/ true );
5585
+ OS << " \n "
5586
+ << " } catch {\n " ;
5587
+ addCallToCompletionHandler (/* HasResult=*/ false );
5588
+ OS << " \n "
5589
+ << " }\n " ; // end catch
5590
+ } else {
5591
+ if (!HandlerDesc.willAsyncReturnVoid ()) {
5592
+ OS << " let result" ;
5593
+ addResultTypeAnnotationIfNecessary ();
5594
+ OS << " = " ;
5595
+ }
5596
+ OS << " await " ;
5597
+ addCallToAsyncMethod ();
5598
+ OS << " \n " ;
5599
+ addCallToCompletionHandler (/* HasResult=*/ true );
5600
+ OS << " \n " ;
5601
+ }
5602
+ OS << " }\n " ; // end 'async'
5603
+ OS << " }\n " ; // end function body
5604
+ return Buffer;
5605
+ }
5606
+ };
5607
+
5322
5608
} // namespace asyncrefactorings
5323
5609
5324
5610
bool RefactoringActionConvertCallToAsyncAlternative::isApplicable (
@@ -5425,6 +5711,13 @@ bool RefactoringActionAddAsyncAlternative::performChange() {
5425
5711
EditConsumer.accept (SM, FD->getAttributeInsertionLoc (false ),
5426
5712
" @available(*, deprecated, message: \" Prefer async "
5427
5713
" alternative instead\" )\n " );
5714
+ LegacyAlternativeBodyCreator LegacyBody (FD, HandlerDesc);
5715
+ if (LegacyBody.canRewriteLegacyBody ()) {
5716
+ EditConsumer.accept (SM,
5717
+ Lexer::getCharSourceRangeFromSourceRange (
5718
+ SM, FD->getBody ()->getSourceRange ()),
5719
+ LegacyBody.create ());
5720
+ }
5428
5721
Converter.insertAfter (FD, EditConsumer);
5429
5722
5430
5723
return false ;
0 commit comments