|
23 | 23 | #include "llvm/Support/CommandLine.h"
|
24 | 24 | #include "llvm/Support/raw_ostream.h"
|
25 | 25 |
|
| 26 | +#include "Callee.h" |
26 | 27 | #include "Explosion.h"
|
27 | 28 | #include "IRGenDebugInfo.h"
|
28 | 29 | #include "IRGenFunction.h"
|
@@ -511,3 +512,263 @@ llvm::Value *IRGenFunction::alignUpToMaximumAlignment(llvm::Type *sizeTy, llvm::
|
511 | 512 | auto *invertedMask = Builder.CreateNot(alignMask);
|
512 | 513 | return Builder.CreateAnd(Builder.CreateAdd(val, alignMask), invertedMask);
|
513 | 514 | }
|
| 515 | + |
| 516 | +/// Returns the current task \p currTask as an UnsafeContinuation at +1. |
| 517 | +static llvm::Value *unsafeContinuationFromTask(IRGenFunction &IGF, |
| 518 | + SILType unsafeContinuationTy, |
| 519 | + llvm::Value *currTask) { |
| 520 | + auto &IGM = IGF.IGM; |
| 521 | + auto &Builder = IGF.Builder; |
| 522 | + |
| 523 | + auto &rawPonterTI = IGM.getRawPointerTypeInfo(); |
| 524 | + auto object = |
| 525 | + Builder.CreateBitOrPointerCast(currTask, rawPonterTI.getStorageType()); |
| 526 | + |
| 527 | + // Wrap the native object in the UnsafeContinuation struct. |
| 528 | + // struct UnsafeContinuation<T> { |
| 529 | + // let _continuation : Builtin.RawPointer |
| 530 | + // } |
| 531 | + auto &unsafeContinuationTI = |
| 532 | + cast<LoadableTypeInfo>(IGF.getTypeInfo(unsafeContinuationTy)); |
| 533 | + auto unsafeContinuationStructTy = |
| 534 | + cast<llvm::StructType>(unsafeContinuationTI.getStorageType()); |
| 535 | + auto fieldTy = |
| 536 | + cast<llvm::StructType>(unsafeContinuationStructTy->getElementType(0)); |
| 537 | + auto reference = |
| 538 | + Builder.CreateBitOrPointerCast(object, fieldTy->getElementType(0)); |
| 539 | + auto field = |
| 540 | + Builder.CreateInsertValue(llvm::UndefValue::get(fieldTy), reference, 0); |
| 541 | + auto unsafeContinuation = Builder.CreateInsertValue( |
| 542 | + llvm::UndefValue::get(unsafeContinuationStructTy), field, 0); |
| 543 | + |
| 544 | + return unsafeContinuation; |
| 545 | +} |
| 546 | + |
| 547 | +void IRGenFunction::emitGetAsyncContinuation(SILType unsafeContinuationTy, |
| 548 | + StackAddress resultAddr, |
| 549 | + Explosion &out) { |
| 550 | + // Create the continuation. |
| 551 | + // void current_sil_function(AsyncTask *currTask, Executor *currExecutor, |
| 552 | + // AsyncContext *currCtxt) { |
| 553 | + // |
| 554 | + // A continuation is the current AsyncTask 'currTask' with: |
| 555 | + // currTask->ResumeTask = @llvm.coro.async.resume(); |
| 556 | + // currTask->ResumeContext = &continuation_context; |
| 557 | + // |
| 558 | + // Where: |
| 559 | + // |
| 560 | + // struct { |
| 561 | + // AsyncContext *resumeCtxt; |
| 562 | + // void *awaitSynchronization; |
| 563 | + // SwiftError *errResult; |
| 564 | + // union { |
| 565 | + // IndirectResult *result; |
| 566 | + // DirectResult *result; |
| 567 | + // }; |
| 568 | + // ExecutorRef *resumeExecutor; |
| 569 | + // } continuation_context; // local variable of current_sil_function |
| 570 | + // |
| 571 | + // continuation_context.resumeCtxt = currCtxt; |
| 572 | + // continuation_context.errResult = nulllptr; |
| 573 | + // continuation_context.result = ... // local alloca. |
| 574 | + // continuation_context.resumeExecutor = .. // current executor |
| 575 | + |
| 576 | + auto currTask = getAsyncTask(); |
| 577 | + auto unsafeContinuation = |
| 578 | + unsafeContinuationFromTask(*this, unsafeContinuationTy, currTask); |
| 579 | + |
| 580 | + // Create and setup the continuation context for UnsafeContinuation<T>. |
| 581 | + // continuation_context.resumeCtxt = currCtxt; |
| 582 | + // continuation_context.errResult = nulllptr; |
| 583 | + // continuation_context.result = ... // local alloca T |
| 584 | + auto pointerAlignment = IGM.getPointerAlignment(); |
| 585 | + auto continuationContext = |
| 586 | + createAlloca(IGM.AsyncContinuationContextTy, pointerAlignment); |
| 587 | + AsyncCoroutineCurrentContinuationContext = continuationContext.getAddress(); |
| 588 | + // TODO: add lifetime with matching lifetime in await_async_continuation |
| 589 | + auto contResumeAddr = |
| 590 | + Builder.CreateStructGEP(continuationContext.getAddress(), 0); |
| 591 | + Builder.CreateStore(getAsyncContext(), |
| 592 | + Address(contResumeAddr, pointerAlignment)); |
| 593 | + auto contErrResultAddr = |
| 594 | + Builder.CreateStructGEP(continuationContext.getAddress(), 2); |
| 595 | + Builder.CreateStore( |
| 596 | + llvm::Constant::getNullValue( |
| 597 | + contErrResultAddr->getType()->getPointerElementType()), |
| 598 | + Address(contErrResultAddr, pointerAlignment)); |
| 599 | + auto contResultAddr = |
| 600 | + Builder.CreateStructGEP(continuationContext.getAddress(), 3); |
| 601 | + if (!resultAddr.getAddress().isValid()) { |
| 602 | + assert(unsafeContinuationTy.getASTType() |
| 603 | + ->castTo<BoundGenericType>() |
| 604 | + ->getGenericArgs() |
| 605 | + .size() == 1 && |
| 606 | + "expect UnsafeContinuation<T> to have one generic arg"); |
| 607 | + auto resultTy = IGM.getLoweredType(unsafeContinuationTy.getASTType() |
| 608 | + ->castTo<BoundGenericType>() |
| 609 | + ->getGenericArgs()[0] |
| 610 | + ->getCanonicalType()); |
| 611 | + auto &resultTI = getTypeInfo(resultTy); |
| 612 | + auto resultAddr = |
| 613 | + resultTI.allocateStack(*this, resultTy, "async.continuation.result"); |
| 614 | + Builder.CreateStore(Builder.CreateBitOrPointerCast( |
| 615 | + resultAddr.getAddress().getAddress(), |
| 616 | + contResultAddr->getType()->getPointerElementType()), |
| 617 | + Address(contResultAddr, pointerAlignment)); |
| 618 | + } else { |
| 619 | + Builder.CreateStore(Builder.CreateBitOrPointerCast( |
| 620 | + resultAddr.getAddress().getAddress(), |
| 621 | + contResultAddr->getType()->getPointerElementType()), |
| 622 | + Address(contResultAddr, pointerAlignment)); |
| 623 | + } |
| 624 | + // continuation_context.resumeExecutor = // current executor |
| 625 | + auto contExecutorRefAddr = |
| 626 | + Builder.CreateStructGEP(continuationContext.getAddress(), 4); |
| 627 | + Builder.CreateStore( |
| 628 | + Builder.CreateBitOrPointerCast( |
| 629 | + getAsyncExecutor(), |
| 630 | + contExecutorRefAddr->getType()->getPointerElementType()), |
| 631 | + Address(contExecutorRefAddr, pointerAlignment)); |
| 632 | + |
| 633 | + // Fill the current task (i.e the continuation) with the continuation |
| 634 | + // information. |
| 635 | + // currTask->ResumeTask = @llvm.coro.async.resume(); |
| 636 | + assert(currTask->getType() == IGM.SwiftTaskPtrTy); |
| 637 | + auto currTaskResumeTaskAddr = Builder.CreateStructGEP(currTask,3); |
| 638 | + auto coroResume = |
| 639 | + Builder.CreateIntrinsicCall(llvm::Intrinsic::coro_async_resume, {}); |
| 640 | + |
| 641 | + assert(AsyncCoroutineCurrentResume == nullptr && |
| 642 | + "Don't support nested get_async_continuation"); |
| 643 | + AsyncCoroutineCurrentResume = coroResume; |
| 644 | + Builder.CreateStore( |
| 645 | + Builder.CreateBitOrPointerCast(coroResume, IGM.FunctionPtrTy), |
| 646 | + Address(currTaskResumeTaskAddr, pointerAlignment)); |
| 647 | + // currTask->ResumeContext = &continuation_context; |
| 648 | + auto currTaskResumeCtxtAddr = Builder.CreateStructGEP(currTask, 4); |
| 649 | + Builder.CreateStore( |
| 650 | + Builder.CreateBitOrPointerCast(continuationContext.getAddress(), |
| 651 | + IGM.SwiftContextPtrTy), |
| 652 | + Address(currTaskResumeCtxtAddr, pointerAlignment)); |
| 653 | + |
| 654 | + // Publish all the writes. |
| 655 | + // continuation_context.awaitSynchronization =(atomic release) nullptr; |
| 656 | + auto contAwaitSyncAddr = |
| 657 | + Builder.CreateStructGEP(continuationContext.getAddress(), 1); |
| 658 | + auto null = llvm::ConstantInt::get( |
| 659 | + contAwaitSyncAddr->getType()->getPointerElementType(), 0); |
| 660 | + auto atomicStore = |
| 661 | + Builder.CreateStore(null, Address(contAwaitSyncAddr, pointerAlignment)); |
| 662 | + atomicStore->setAtomic(llvm::AtomicOrdering::Release, |
| 663 | + llvm::SyncScope::System); |
| 664 | + out.add(unsafeContinuation); |
| 665 | +} |
| 666 | + |
| 667 | +void IRGenFunction::emitAwaitAsyncContinuation( |
| 668 | + SILType unsafeContinuationTy, bool isIndirectResult, |
| 669 | + Explosion &outDirectResult, llvm::BasicBlock *&normalBB, |
| 670 | + llvm::PHINode *&optionalErrorResult, llvm::BasicBlock *&optionalErrorBB) { |
| 671 | + assert(AsyncCoroutineCurrentContinuationContext && "no active continuation"); |
| 672 | + auto pointerAlignment = IGM.getPointerAlignment(); |
| 673 | + |
| 674 | + // First check whether the await reached this point first. Meaning we still |
| 675 | + // have to wait for the continuation result. If the await reaches first we |
| 676 | + // abort the control flow here (resuming the continuation will execute the |
| 677 | + // remaining control flow). |
| 678 | + auto contAwaitSyncAddr = |
| 679 | + Builder.CreateStructGEP(AsyncCoroutineCurrentContinuationContext, 1); |
| 680 | + auto null = llvm::ConstantInt::get( |
| 681 | + contAwaitSyncAddr->getType()->getPointerElementType(), 0); |
| 682 | + auto one = llvm::ConstantInt::get( |
| 683 | + contAwaitSyncAddr->getType()->getPointerElementType(), 1); |
| 684 | + auto results = Builder.CreateAtomicCmpXchg( |
| 685 | + contAwaitSyncAddr, null, one, |
| 686 | + llvm::AtomicOrdering::Release /*success ordering*/, |
| 687 | + llvm::AtomicOrdering::Acquire /* failure ordering */, |
| 688 | + llvm::SyncScope::System); |
| 689 | + auto firstAtAwait = Builder.CreateExtractValue(results, 1); |
| 690 | + auto contBB = createBasicBlock("await.async.maybe.resume"); |
| 691 | + auto abortBB = createBasicBlock("await.async.abort"); |
| 692 | + Builder.CreateCondBr(firstAtAwait, abortBB, contBB); |
| 693 | + Builder.emitBlock(abortBB); |
| 694 | + { |
| 695 | + // We are first to the sync point. Abort. The continuation's result is not |
| 696 | + // available yet. |
| 697 | + emitCoroutineOrAsyncExit(); |
| 698 | + } |
| 699 | + |
| 700 | + auto contBB2 = createBasicBlock("await.async.resume"); |
| 701 | + Builder.emitBlock(contBB); |
| 702 | + { |
| 703 | + // Setup the suspend point. |
| 704 | + SmallVector<llvm::Value *, 8> arguments; |
| 705 | + arguments.push_back(AsyncCoroutineCurrentResume); |
| 706 | + auto resumeProjFn = getOrCreateResumePrjFn(); |
| 707 | + arguments.push_back( |
| 708 | + Builder.CreateBitOrPointerCast(resumeProjFn, IGM.Int8PtrTy)); |
| 709 | + // The dispatch function just calls the resume point. |
| 710 | + auto resumeFnPtr = |
| 711 | + getFunctionPointerForResumeIntrinsic(AsyncCoroutineCurrentResume); |
| 712 | + arguments.push_back(Builder.CreateBitOrPointerCast( |
| 713 | + createAsyncDispatchFn(resumeFnPtr, |
| 714 | + {IGM.Int8PtrTy, IGM.Int8PtrTy, IGM.Int8PtrTy}), |
| 715 | + IGM.Int8PtrTy)); |
| 716 | + arguments.push_back(AsyncCoroutineCurrentResume); |
| 717 | + arguments.push_back( |
| 718 | + Builder.CreateBitOrPointerCast(getAsyncTask(), IGM.Int8PtrTy)); |
| 719 | + arguments.push_back( |
| 720 | + Builder.CreateBitOrPointerCast(getAsyncExecutor(), IGM.Int8PtrTy)); |
| 721 | + arguments.push_back(Builder.CreateBitOrPointerCast( |
| 722 | + AsyncCoroutineCurrentContinuationContext, IGM.Int8PtrTy)); |
| 723 | + Builder.CreateIntrinsicCall(llvm::Intrinsic::coro_suspend_async, arguments); |
| 724 | + |
| 725 | + auto results = Builder.CreateAtomicCmpXchg( |
| 726 | + contAwaitSyncAddr, null, one, |
| 727 | + llvm::AtomicOrdering::Release /*success ordering*/, |
| 728 | + llvm::AtomicOrdering::Acquire /* failure ordering */, |
| 729 | + llvm::SyncScope::System); |
| 730 | + // Again, are we first at the wait (can only reach that state after |
| 731 | + // continuation.resume/abort is called)? If so abort to wait for the end of |
| 732 | + // the await point to be reached. |
| 733 | + auto firstAtAwait = Builder.CreateExtractValue(results, 1); |
| 734 | + Builder.CreateCondBr(firstAtAwait, abortBB, contBB2); |
| 735 | + } |
| 736 | + |
| 737 | + Builder.emitBlock(contBB2); |
| 738 | + auto contBB3 = createBasicBlock("await.async.normal"); |
| 739 | + if (optionalErrorBB) { |
| 740 | + auto contErrResultAddr = Address( |
| 741 | + Builder.CreateStructGEP(AsyncCoroutineCurrentContinuationContext, 2), |
| 742 | + pointerAlignment); |
| 743 | + auto errorRes = Builder.CreateLoad(contErrResultAddr); |
| 744 | + auto nullError = llvm::Constant::getNullValue(errorRes->getType()); |
| 745 | + auto hasError = Builder.CreateICmpNE(errorRes, nullError); |
| 746 | + optionalErrorResult->addIncoming(errorRes, Builder.GetInsertBlock()); |
| 747 | + Builder.CreateCondBr(hasError, optionalErrorBB, contBB3); |
| 748 | + } else { |
| 749 | + Builder.CreateBr(contBB3); |
| 750 | + } |
| 751 | + |
| 752 | + Builder.emitBlock(contBB3); |
| 753 | + if (!isIndirectResult) { |
| 754 | + auto contResultAddrAddr = |
| 755 | + Builder.CreateStructGEP(AsyncCoroutineCurrentContinuationContext, 3); |
| 756 | + auto resultAddrVal = |
| 757 | + Builder.CreateLoad(Address(contResultAddrAddr, pointerAlignment)); |
| 758 | + // Take the result. |
| 759 | + auto resultTy = IGM.getLoweredType(unsafeContinuationTy.getASTType() |
| 760 | + ->castTo<BoundGenericType>() |
| 761 | + ->getGenericArgs()[0] |
| 762 | + ->getCanonicalType()); |
| 763 | + auto &resultTI = cast<LoadableTypeInfo>(getTypeInfo(resultTy)); |
| 764 | + auto resultStorageTy = resultTI.getStorageType(); |
| 765 | + auto resultAddr = |
| 766 | + Address(Builder.CreateBitOrPointerCast(resultAddrVal, |
| 767 | + resultStorageTy->getPointerTo()), |
| 768 | + resultTI.getFixedAlignment()); |
| 769 | + resultTI.loadAsTake(*this, resultAddr, outDirectResult); |
| 770 | + } |
| 771 | + Builder.CreateBr(normalBB); |
| 772 | + AsyncCoroutineCurrentResume = nullptr; |
| 773 | + AsyncCoroutineCurrentContinuationContext = nullptr; |
| 774 | +} |
0 commit comments