@@ -4158,6 +4158,36 @@ struct AsyncHandlerDesc {
4158
4158
return params ();
4159
4159
}
4160
4160
4161
+ // / Get the type of the error that will be thrown by the \c async method or \c
4162
+ // / None if the completion handler doesn't accept an error parameter.
4163
+ // / This may be more specialized than the generic 'Error' type if the
4164
+ // / completion handler of the converted function takes a more specialized
4165
+ // / error type.
4166
+ Optional<swift::Type> getErrorType () const {
4167
+ if (HasError) {
4168
+ switch (Type) {
4169
+ case HandlerType::INVALID:
4170
+ return None;
4171
+ case HandlerType::PARAMS:
4172
+ // The last parameter of the completion handler is the error param
4173
+ return params ().back ().getPlainType ()->lookThroughSingleOptionalType ();
4174
+ case HandlerType::RESULT:
4175
+ assert (
4176
+ params ().size () == 1 &&
4177
+ " Result handler should have the Result type as the only parameter" );
4178
+ auto ResultType =
4179
+ params ().back ().getPlainType ()->getAs <BoundGenericType>();
4180
+ auto GenericArgs = ResultType->getGenericArgs ();
4181
+ assert (GenericArgs.size () == 2 && " Result should have two params" );
4182
+ // The second (last) generic parameter of the Result type is the error
4183
+ // type.
4184
+ return GenericArgs.back ();
4185
+ }
4186
+ } else {
4187
+ return None;
4188
+ }
4189
+ }
4190
+
4161
4191
// / The `CallExpr` if the given node is a call to the `Handler`
4162
4192
CallExpr *getAsHandlerCall (ASTNode Node) const {
4163
4193
if (!isValid ())
@@ -5318,6 +5348,262 @@ class AsyncConverter : private SourceEntityWalker {
5318
5348
}
5319
5349
}
5320
5350
};
5351
+
5352
+ // / When adding an async alternative method for the function declaration \c FD,
5353
+ // / this class tries to create a function body for the legacy function (the one
5354
+ // / with a completion handler), which calls the newly converted async function.
5355
+ // / There are certain situations in which we fail to create such a body, e.g.
5356
+ // / if the completion handler has the signature `(String, Error?) -> Void` in
5357
+ // / which case we can't synthesize the result of type \c String in the error
5358
+ // / case.
5359
+ class LegacyAlternativeBodyCreator {
5360
+ // / The old function declaration for which an async alternative has been added
5361
+ // / and whose body shall be rewritten to call the newly added async
5362
+ // / alternative.
5363
+ FuncDecl *FD;
5364
+
5365
+ // / The description of the completion handler in the old function declaration.
5366
+ AsyncHandlerDesc HandlerDesc;
5367
+
5368
+ std::string Buffer;
5369
+ llvm::raw_string_ostream OS;
5370
+
5371
+ // / Adds the call to the refactored 'async' method without the 'await'
5372
+ // / keyword to the output stream.
5373
+ void addCallToAsyncMethod () {
5374
+ OS << FD->getBaseName () << " (" ;
5375
+ bool FirstParam = true ;
5376
+ for (auto Param : *FD->getParameters ()) {
5377
+ if (Param == HandlerDesc.Handler ) {
5378
+ // / We don't need to pass the completion handler to the async method.
5379
+ continue ;
5380
+ }
5381
+ if (!FirstParam) {
5382
+ OS << " , " ;
5383
+ } else {
5384
+ FirstParam = false ;
5385
+ }
5386
+ if (!Param->getArgumentName ().empty ()) {
5387
+ OS << Param->getArgumentName () << " : " ;
5388
+ }
5389
+ OS << Param->getParameterName ();
5390
+ }
5391
+ OS << " )" ;
5392
+ }
5393
+
5394
+ // / If the returned error type is more specialized than \c Error, adds an
5395
+ // / 'as! CustomError' cast to the more specialized error type to the output
5396
+ // / stream.
5397
+ void addCastToCustomErrorTypeIfNecessary () {
5398
+ auto ErrorType = *HandlerDesc.getErrorType ();
5399
+ if (ErrorType->getCanonicalType () !=
5400
+ FD->getASTContext ().getExceptionType ()) {
5401
+ OS << " as! " ;
5402
+ ErrorType->lookThroughSingleOptionalType ()->print (OS);
5403
+ }
5404
+ }
5405
+
5406
+ // / Adds the \c Index -th parameter to the completion handler.
5407
+ // / If \p HasResult is \c true, it is assumed that a variable named 'result'
5408
+ // / contains the result returned from the async alternative. If the callback
5409
+ // / also takes an error parameter, \c nil passed to the completion handler for
5410
+ // / the error.
5411
+ // / If \p HasResult is \c false, it is a assumed that a variable named 'error'
5412
+ // / contains the error thrown from the async method and 'nil' will be passed
5413
+ // / to the completion handler for all result parameters.
5414
+ void addCompletionHandlerArgument (size_t Index, bool HasResult) {
5415
+ if (HandlerDesc.HasError && Index == HandlerDesc.params ().size () - 1 ) {
5416
+ // The error parameter is the last argument of the completion handler.
5417
+ if (!HasResult) {
5418
+ OS << " error" ;
5419
+ addCastToCustomErrorTypeIfNecessary ();
5420
+ } else {
5421
+ OS << " nil" ;
5422
+ }
5423
+ } else {
5424
+ if (!HasResult) {
5425
+ OS << " nil" ;
5426
+ } else if (HandlerDesc
5427
+ .getSuccessParamAsyncReturnType (
5428
+ HandlerDesc.params ()[Index].getPlainType ())
5429
+ ->isVoid ()) {
5430
+ // Void return types are not returned by the async function, synthesize
5431
+ // a Void instance.
5432
+ OS << " ()" ;
5433
+ } else if (HandlerDesc.getSuccessParams ().size () > 1 ) {
5434
+ // If the async method returns a tuple, we need to pass its elements to
5435
+ // the completion handler separately. For example:
5436
+ //
5437
+ // func foo() async -> (String, Int) {}
5438
+ //
5439
+ // causes the following legacy body to be created:
5440
+ //
5441
+ // func foo(completion: (String, Int) -> Void) {
5442
+ // async {
5443
+ // let result = await foo()
5444
+ // completion(result.0, result.1)
5445
+ // }
5446
+ // }
5447
+ OS << " result." << Index;
5448
+ } else {
5449
+ OS << " result" ;
5450
+ }
5451
+ }
5452
+ }
5453
+
5454
+ // / Adds the call to the completion handler. See \c
5455
+ // / getCompletionHandlerArgument for how the arguments are synthesized if the
5456
+ // / completion handler takes arguments, not a \c Result type.
5457
+ void addCallToCompletionHandler (bool HasResult) {
5458
+ OS << HandlerDesc.Handler ->getParameterName () << " (" ;
5459
+
5460
+ // Construct arguments to pass to the completion handler
5461
+ switch (HandlerDesc.Type ) {
5462
+ case HandlerType::INVALID:
5463
+ llvm_unreachable (" Cannot be rewritten" );
5464
+ break ;
5465
+ case HandlerType::PARAMS: {
5466
+ for (size_t I = 0 ; I < HandlerDesc.params ().size (); ++I) {
5467
+ if (I > 0 ) {
5468
+ OS << " , " ;
5469
+ }
5470
+ addCompletionHandlerArgument (I, HasResult);
5471
+ }
5472
+ break ;
5473
+ }
5474
+ case HandlerType::RESULT: {
5475
+ if (HasResult) {
5476
+ OS << " .success(result)" ;
5477
+ } else {
5478
+ OS << " .failure(error" ;
5479
+ addCastToCustomErrorTypeIfNecessary ();
5480
+ OS << " )" ;
5481
+ }
5482
+ break ;
5483
+ }
5484
+ }
5485
+ OS << " )" ; // Close the call to the completion handler
5486
+ }
5487
+
5488
+ // / Adds the result type of the converted async function.
5489
+ void addAsyncFuncReturnType () {
5490
+ SmallVector<Type, 2 > Scratch;
5491
+ auto ReturnTypes = HandlerDesc.getAsyncReturnTypes (Scratch);
5492
+ if (ReturnTypes.size () > 1 ) {
5493
+ OS << " (" ;
5494
+ }
5495
+
5496
+ llvm::interleave (
5497
+ ReturnTypes, [&](Type Ty) { Ty->print (OS); }, [&]() { OS << " , " ; });
5498
+
5499
+ if (ReturnTypes.size () > 1 ) {
5500
+ OS << " )" ;
5501
+ }
5502
+ }
5503
+
5504
+ // / If the async alternative function is generic, adds the type annotation
5505
+ // / to the 'return' variable in the legacy function so that the generic
5506
+ // / parameters of the legacy function are passed to the generic function.
5507
+ // / For example for
5508
+ // / \code
5509
+ // / func foo<GenericParam>() async -> GenericParam {}
5510
+ // / \endcode
5511
+ // / we generate
5512
+ // / \code
5513
+ // / func foo<GenericParam>(completion: (T) -> Void) {
5514
+ // / async {
5515
+ // / let result: GenericParam = await foo()
5516
+ // / <------------>
5517
+ // / completion(result)
5518
+ // / }
5519
+ // / }
5520
+ // / \endcode
5521
+ // / This function adds the range marked by \c <----->
5522
+ void addResultTypeAnnotationIfNecessary () {
5523
+ if (FD->isGeneric ()) {
5524
+ OS << " : " ;
5525
+ addAsyncFuncReturnType ();
5526
+ }
5527
+ }
5528
+
5529
+ public:
5530
+ LegacyAlternativeBodyCreator (FuncDecl *FD, AsyncHandlerDesc HandlerDesc)
5531
+ : FD(FD), HandlerDesc(HandlerDesc), OS(Buffer) {}
5532
+
5533
+ bool canRewriteLegacyBody () {
5534
+ if (FD == nullptr || FD->getBody () == nullptr ) {
5535
+ return false ;
5536
+ }
5537
+ if (FD->hasThrows ()) {
5538
+ assert (!HandlerDesc.isValid () && " We shouldn't have found a handler desc "
5539
+ " if the original function throws" );
5540
+ return false ;
5541
+ }
5542
+ switch (HandlerDesc.Type ) {
5543
+ case HandlerType::INVALID:
5544
+ return false ;
5545
+ case HandlerType::PARAMS: {
5546
+ if (HandlerDesc.HasError ) {
5547
+ // The non-error parameters must be optional so that we can set them to
5548
+ // nil in the error case.
5549
+ // The error parameter must be optional so we can set it to nil in the
5550
+ // success case.
5551
+ // Otherwise we can't synthesize the values to return for these
5552
+ // parameters.
5553
+ return llvm::all_of (HandlerDesc.params (),
5554
+ [](AnyFunctionType::Param Param) -> bool {
5555
+ return Param.getPlainType ()->isOptional ();
5556
+ });
5557
+ } else {
5558
+ return true ;
5559
+ }
5560
+ }
5561
+ case HandlerType::RESULT:
5562
+ return true ;
5563
+ }
5564
+ }
5565
+
5566
+ std::string create () {
5567
+ assert (Buffer.empty () &&
5568
+ " LegacyAlternativeBodyCreator can only be used once" );
5569
+ assert (canRewriteLegacyBody () &&
5570
+ " Cannot create a legacy body if the body can't be rewritten" );
5571
+ OS << " {\n " ; // start function body
5572
+ OS << " async {\n " ;
5573
+ if (HandlerDesc.HasError ) {
5574
+ OS << " do {\n " ;
5575
+ if (!HandlerDesc.willAsyncReturnVoid ()) {
5576
+ OS << " let result" ;
5577
+ addResultTypeAnnotationIfNecessary ();
5578
+ OS << " = " ;
5579
+ }
5580
+ OS << " try await " ;
5581
+ addCallToAsyncMethod ();
5582
+ OS << " \n " ;
5583
+ addCallToCompletionHandler (/* HasResult=*/ true );
5584
+ OS << " \n "
5585
+ << " } catch {\n " ;
5586
+ addCallToCompletionHandler (/* HasResult=*/ false );
5587
+ OS << " \n "
5588
+ << " }\n " ; // end catch
5589
+ } else {
5590
+ if (!HandlerDesc.willAsyncReturnVoid ()) {
5591
+ OS << " let result" ;
5592
+ addResultTypeAnnotationIfNecessary ();
5593
+ OS << " = " ;
5594
+ }
5595
+ OS << " await " ;
5596
+ addCallToAsyncMethod ();
5597
+ OS << " \n " ;
5598
+ addCallToCompletionHandler (/* HasResult=*/ true );
5599
+ OS << " \n " ;
5600
+ }
5601
+ OS << " }\n " ; // end 'async'
5602
+ OS << " }\n " ; // end function body
5603
+ return Buffer;
5604
+ }
5605
+ };
5606
+
5321
5607
} // namespace asyncrefactorings
5322
5608
5323
5609
bool RefactoringActionConvertCallToAsyncAlternative::isApplicable (
@@ -5424,6 +5710,13 @@ bool RefactoringActionAddAsyncAlternative::performChange() {
5424
5710
EditConsumer.accept (SM, FD->getAttributeInsertionLoc (false ),
5425
5711
" @available(*, deprecated, message: \" Prefer async "
5426
5712
" alternative instead\" )\n " );
5713
+ LegacyAlternativeBodyCreator LegacyBody (FD, HandlerDesc);
5714
+ if (LegacyBody.canRewriteLegacyBody ()) {
5715
+ EditConsumer.accept (SM,
5716
+ Lexer::getCharSourceRangeFromSourceRange (
5717
+ SM, FD->getBody ()->getSourceRange ()),
5718
+ LegacyBody.create ());
5719
+ }
5427
5720
Converter.insertAfter (FD, EditConsumer);
5428
5721
5429
5722
return false ;
0 commit comments