Skip to content

Commit aa5b505

Browse files
authored
Allow normal function results of @yield_once coroutines (#69843)
This adds SIL-level support and LLVM codegen for normal results of a coroutine. The main user of this will be autodiff as VJP of a coroutine must be a coroutine itself (in order to produce the yielded result) and return a pullback closure as a normal result. For now only direct results are supported, but this seems to be enough for autodiff purposes.
1 parent 01bce39 commit aa5b505

File tree

71 files changed

+571
-295
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

71 files changed

+571
-295
lines changed

docs/SIL.rst

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6067,6 +6067,14 @@ executing the ``begin_apply``) were being "called" by the ``yield``:
60676067
or move the value from that position before ending or aborting the
60686068
coroutine.
60696069

6070+
A coroutine optionally may produce normal results. These do not have
6071+
``@yields`` annotation in the result type tuple.
6072+
::
6073+
(%float, %token) = begin_apply %0() : $@yield_once () -> (@yields Float, Int)
6074+
6075+
Normal results of a coroutine are produced by the corresponding ``end_apply``
6076+
instruction.
6077+
60706078
A ``begin_apply`` must be uniquely either ended or aborted before
60716079
exiting the function or looping to an earlier portion of the function.
60726080

@@ -6096,9 +6104,9 @@ end_apply
60966104
`````````
60976105
::
60986106

6099-
sil-instruction ::= 'end_apply' sil-value
6107+
sil-instruction ::= 'end_apply' sil-value 'as' sil-type
61006108

6101-
end_apply %token
6109+
end_apply %token as $()
61026110

61036111
Ends the given coroutine activation, which is currently suspended at
61046112
a ``yield`` instruction. Transfers control to the coroutine and takes
@@ -6108,8 +6116,8 @@ when the coroutine reaches a ``return`` instruction.
61086116
The operand must always be the token result of a ``begin_apply``
61096117
instruction, which is why it need not specify a type.
61106118

6111-
``end_apply`` currently has no instruction results. If coroutines were
6112-
allowed to have normal results, they would be producted by ``end_apply``.
6119+
The result of ``end_apply`` is the normal result of the coroutine function (the
6120+
operand of the ``return`` instruction)."
61136121

61146122
When throwing coroutines are supported, there will need to be a
61156123
``try_end_apply`` instruction.

include/swift/AST/Types.h

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4729,24 +4729,27 @@ class SILFunctionType final
47294729
using Representation = SILExtInfoBuilder::Representation;
47304730

47314731
private:
4732-
unsigned NumParameters;
4732+
unsigned NumParameters = 0;
47334733

4734-
// These are *normal* results if this is not a coroutine and *yield* results
4735-
// otherwise.
4736-
unsigned NumAnyResults; // Not including the ErrorResult.
4737-
unsigned NumAnyIndirectFormalResults; // Subset of NumAnyResults.
4738-
unsigned NumPackResults; // Subset of NumAnyIndirectFormalResults.
4734+
// These are *normal* results
4735+
unsigned NumAnyResults = 0; // Not including the ErrorResult.
4736+
unsigned NumAnyIndirectFormalResults = 0; // Subset of NumAnyResults.
4737+
unsigned NumPackResults = 0; // Subset of NumAnyIndirectFormalResults.
4738+
// These are *yield* results
4739+
unsigned NumAnyYieldResults = 0; // Not including the ErrorResult.
4740+
unsigned NumAnyIndirectFormalYieldResults = 0; // Subset of NumAnyYieldResults.
4741+
unsigned NumPackYieldResults = 0; // Subset of NumAnyIndirectFormalYieldResults.
47394742

47404743
// [NOTE: SILFunctionType-layout]
47414744
// The layout of a SILFunctionType in memory is:
47424745
// SILFunctionType
47434746
// SILParameterInfo[NumParameters]
4744-
// SILResultInfo[isCoroutine() ? 0 : NumAnyResults]
4747+
// SILResultInfo[NumAnyResults]
47454748
// SILResultInfo? // if hasErrorResult()
4746-
// SILYieldInfo[isCoroutine() ? NumAnyResults : 0]
4749+
// SILYieldInfo[NumAnyYieldResults]
47474750
// SubstitutionMap[HasPatternSubs + HasInvocationSubs]
4748-
// CanType? // if !isCoro && NumAnyResults > 1, formal result cache
4749-
// CanType? // if !isCoro && NumAnyResults > 1, all result cache
4751+
// CanType? // if NumAnyResults > 1, formal result cache
4752+
// CanType? // if NumAnyResults > 1, all result cache
47504753

47514754
CanGenericSignature InvocationGenericSig;
47524755
ProtocolConformanceRef WitnessMethodConformance;
@@ -4785,7 +4788,7 @@ class SILFunctionType final
47854788

47864789
/// Do we have slots for caches of the normal-result tuple type?
47874790
bool hasResultCache() const {
4788-
return NumAnyResults > 1 && !isCoroutine();
4791+
return NumAnyResults > 1;
47894792
}
47904793

47914794
CanType &getMutableFormalResultsCache() const {
@@ -4873,14 +4876,14 @@ class SILFunctionType final
48734876
ArrayRef<SILYieldInfo> getYields() const {
48744877
return const_cast<SILFunctionType *>(this)->getMutableYields();
48754878
}
4876-
unsigned getNumYields() const { return isCoroutine() ? NumAnyResults : 0; }
4879+
unsigned getNumYields() const { return NumAnyYieldResults; }
48774880

48784881
/// Return the array of all result information. This may contain inter-mingled
48794882
/// direct and indirect results.
48804883
ArrayRef<SILResultInfo> getResults() const {
48814884
return const_cast<SILFunctionType *>(this)->getMutableResults();
48824885
}
4883-
unsigned getNumResults() const { return isCoroutine() ? 0 : NumAnyResults; }
4886+
unsigned getNumResults() const { return NumAnyResults; }
48844887

48854888
ArrayRef<SILResultInfo> getResultsWithError() const {
48864889
return const_cast<SILFunctionType *>(this)->getMutableResultsWithError();
@@ -4917,17 +4920,17 @@ class SILFunctionType final
49174920
// indirect property, not the SIL indirect property, should be consulted to
49184921
// determine whether function reabstraction is necessary.
49194922
unsigned getNumIndirectFormalResults() const {
4920-
return isCoroutine() ? 0 : NumAnyIndirectFormalResults;
4923+
return NumAnyIndirectFormalResults;
49214924
}
49224925
/// Does this function have any formally indirect results?
49234926
bool hasIndirectFormalResults() const {
49244927
return getNumIndirectFormalResults() != 0;
49254928
}
49264929
unsigned getNumDirectFormalResults() const {
4927-
return isCoroutine() ? 0 : NumAnyResults - NumAnyIndirectFormalResults;
4930+
return NumAnyResults - NumAnyIndirectFormalResults;
49284931
}
49294932
unsigned getNumPackResults() const {
4930-
return isCoroutine() ? 0 : NumPackResults;
4933+
return NumPackResults;
49314934
}
49324935
bool hasIndirectErrorResult() const {
49334936
return hasErrorResult() && getErrorResult().isFormalIndirect();
@@ -4985,17 +4988,17 @@ class SILFunctionType final
49854988
TypeExpansionContext expansion);
49864989

49874990
unsigned getNumIndirectFormalYields() const {
4988-
return isCoroutine() ? NumAnyIndirectFormalResults : 0;
4991+
return NumAnyIndirectFormalYieldResults;
49894992
}
49904993
/// Does this function have any formally indirect yields?
49914994
bool hasIndirectFormalYields() const {
49924995
return getNumIndirectFormalYields() != 0;
49934996
}
49944997
unsigned getNumDirectFormalYields() const {
4995-
return isCoroutine() ? NumAnyResults - NumAnyIndirectFormalResults : 0;
4998+
return NumAnyYieldResults - NumAnyIndirectFormalYieldResults;
49964999
}
49975000
unsigned getNumPackYields() const {
4998-
return isCoroutine() ? NumPackResults : 0;
5001+
return NumPackYieldResults;
49995002
}
50005003

50015004
struct IndirectFormalYieldFilter {

include/swift/SIL/SILBuilder.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -576,11 +576,11 @@ class SILBuilder {
576576
beginApply));
577577
}
578578

579-
EndApplyInst *createEndApply(SILLocation loc, SILValue beginApply) {
579+
EndApplyInst *createEndApply(SILLocation loc, SILValue beginApply, SILType ResultType) {
580580
return insert(new (getModule()) EndApplyInst(getSILDebugLocation(loc),
581-
beginApply));
581+
beginApply, ResultType));
582582
}
583-
583+
584584
BuiltinInst *createBuiltin(SILLocation Loc, Identifier Name, SILType ResultTy,
585585
SubstitutionMap Subs,
586586
ArrayRef<SILValue> Args) {

include/swift/SIL/SILCloner.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1076,7 +1076,8 @@ SILCloner<ImplClass>::visitEndApplyInst(EndApplyInst *Inst) {
10761076
getBuilder().setCurrentDebugScope(getOpScope(Inst->getDebugScope()));
10771077
recordClonedInstruction(
10781078
Inst, getBuilder().createEndApply(getOpLocation(Inst->getLoc()),
1079-
getOpValue(Inst->getOperand())));
1079+
getOpValue(Inst->getOperand()),
1080+
getOpType(Inst->getType())));
10801081
}
10811082

10821083
template<typename ImplClass>

include/swift/SIL/SILInstruction.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3200,11 +3200,12 @@ class AbortApplyInst
32003200
/// normally.
32013201
class EndApplyInst
32023202
: public UnaryInstructionBase<SILInstructionKind::EndApplyInst,
3203-
NonValueInstruction> {
3203+
SingleValueInstruction> {
32043204
friend SILBuilder;
32053205

3206-
EndApplyInst(SILDebugLocation debugLoc, SILValue beginApplyToken)
3207-
: UnaryInstructionBase(debugLoc, beginApplyToken) {
3206+
EndApplyInst(SILDebugLocation debugLoc, SILValue beginApplyToken,
3207+
SILType Ty)
3208+
: UnaryInstructionBase(debugLoc, beginApplyToken, Ty) {
32083209
assert(isaResultOf<BeginApplyInst>(beginApplyToken) &&
32093210
isaResultOf<BeginApplyInst>(beginApplyToken)->isBeginApplyToken());
32103211
}

include/swift/SIL/SILNodes.def

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,8 @@ ABSTRACT_VALUE_AND_INST(SingleValueInstruction, ValueBase, SILInstruction)
568568
SingleValueInstruction, MayHaveSideEffects, MayRelease)
569569
SINGLE_VALUE_INST(PartialApplyInst, partial_apply,
570570
SingleValueInstruction, MayHaveSideEffects, DoesNotRelease)
571+
SINGLE_VALUE_INST(EndApplyInst, end_apply,
572+
SILInstruction, MayHaveSideEffects, MayRelease)
571573

572574
// Metatypes
573575
SINGLE_VALUE_INST(MetatypeInst, metatype,
@@ -871,8 +873,6 @@ NON_VALUE_INST(UncheckedRefCastAddrInst, unchecked_ref_cast_addr,
871873
SILInstruction, MayHaveSideEffects, DoesNotRelease)
872874
NON_VALUE_INST(AllocGlobalInst, alloc_global,
873875
SILInstruction, MayHaveSideEffects, DoesNotRelease)
874-
NON_VALUE_INST(EndApplyInst, end_apply,
875-
SILInstruction, MayHaveSideEffects, MayRelease)
876876
NON_VALUE_INST(AbortApplyInst, abort_apply,
877877
SILInstruction, MayHaveSideEffects, MayRelease)
878878
NON_VALUE_INST(PackElementSetInst, pack_element_set,

lib/AST/ASTContext.cpp

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4611,29 +4611,29 @@ SILFunctionType::SILFunctionType(
46114611
!ext.getLifetimeDependenceInfo().empty();
46124612
Bits.SILFunctionType.CoroutineKind = unsigned(coroutineKind);
46134613
NumParameters = params.size();
4614-
if (coroutineKind == SILCoroutineKind::None) {
4615-
assert(yields.empty());
4616-
NumAnyResults = normalResults.size();
4617-
NumAnyIndirectFormalResults = 0;
4618-
NumPackResults = 0;
4619-
for (auto &resultInfo : normalResults) {
4620-
if (resultInfo.isFormalIndirect())
4621-
NumAnyIndirectFormalResults++;
4622-
if (resultInfo.isPack())
4623-
NumPackResults++;
4624-
}
4625-
memcpy(getMutableResults().data(), normalResults.data(),
4626-
normalResults.size() * sizeof(SILResultInfo));
4627-
} else {
4628-
assert(normalResults.empty());
4629-
NumAnyResults = yields.size();
4630-
NumAnyIndirectFormalResults = 0;
4614+
assert((coroutineKind == SILCoroutineKind::None && yields.empty()) ||
4615+
coroutineKind != SILCoroutineKind::None);
4616+
4617+
NumAnyResults = normalResults.size();
4618+
NumAnyIndirectFormalResults = 0;
4619+
NumPackResults = 0;
4620+
for (auto &resultInfo : normalResults) {
4621+
if (resultInfo.isFormalIndirect())
4622+
NumAnyIndirectFormalResults++;
4623+
if (resultInfo.isPack())
4624+
NumPackResults++;
4625+
}
4626+
memcpy(getMutableResults().data(), normalResults.data(),
4627+
normalResults.size() * sizeof(SILResultInfo));
4628+
if (coroutineKind != SILCoroutineKind::None) {
4629+
NumAnyYieldResults = yields.size();
4630+
NumAnyIndirectFormalYieldResults = 0;
46314631
NumPackResults = 0;
46324632
for (auto &yieldInfo : yields) {
46334633
if (yieldInfo.isFormalIndirect())
4634-
NumAnyIndirectFormalResults++;
4634+
NumAnyIndirectFormalYieldResults++;
46354635
if (yieldInfo.isPack())
4636-
NumPackResults++;
4636+
NumPackYieldResults++;
46374637
}
46384638
memcpy(getMutableYields().data(), yields.data(),
46394639
yields.size() * sizeof(SILYieldInfo));
@@ -4805,7 +4805,6 @@ CanSILFunctionType SILFunctionType::get(
48054805
llvm::Optional<SILResultInfo> errorResult, SubstitutionMap patternSubs,
48064806
SubstitutionMap invocationSubs, const ASTContext &ctx,
48074807
ProtocolConformanceRef witnessMethodConformance) {
4808-
assert(coroutineKind == SILCoroutineKind::None || normalResults.empty());
48094808
assert(coroutineKind != SILCoroutineKind::None || yields.empty());
48104809
assert(!ext.isPseudogeneric() || genericSig ||
48114810
coroutineKind != SILCoroutineKind::None);

lib/IRGen/GenCall.cpp

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -656,24 +656,34 @@ namespace {
656656
}
657657

658658
void SignatureExpansion::expandCoroutineResult(bool forContinuation) {
659-
assert(FnType->getNumResults() == 0 &&
660-
"having both normal and yield results is currently unsupported");
661-
662659
// The return type may be different for the ramp function vs. the
663660
// continuations.
664661
if (forContinuation) {
665662
switch (FnType->getCoroutineKind()) {
666663
case SILCoroutineKind::None:
667664
llvm_unreachable("should have been filtered out before here");
668665

669-
// Yield-once coroutines just return void from the continuation.
670-
case SILCoroutineKind::YieldOnce:
671-
ResultIRType = IGM.VoidTy;
666+
// Yield-once coroutines may optionaly return a value from the continuation.
667+
case SILCoroutineKind::YieldOnce: {
668+
auto fnConv = getSILFuncConventions();
669+
670+
assert(fnConv.getNumIndirectSILResults() == 0);
671+
// Ensure that no parameters were added before to correctly record their ABI
672+
// details.
673+
assert(ParamIRTypes.empty());
674+
675+
// Expand the direct result.
676+
const TypeInfo *directResultTypeInfo;
677+
std::tie(ResultIRType, directResultTypeInfo) = expandDirectResult();
678+
672679
return;
680+
}
673681

674682
// Yield-many coroutines yield the same types from the continuation
675683
// as they do from the ramp function.
676684
case SILCoroutineKind::YieldMany:
685+
assert(FnType->getNumResults() == 0 &&
686+
"having both normal and yield results is currently unsupported");
677687
break;
678688
}
679689
}
@@ -5803,6 +5813,53 @@ void irgen::emitAsyncReturn(IRGenFunction &IGF, AsyncContextLayout &asyncLayout,
58035813
emitAsyncReturn(IGF, asyncLayout, fnType, nativeResults);
58045814
}
58055815

5816+
void irgen::emitYieldOnceCoroutineResult(IRGenFunction &IGF, Explosion &result,
5817+
SILType funcResultType, SILType returnResultType) {
5818+
auto &Builder = IGF.Builder;
5819+
auto &IGM = IGF.IGM;
5820+
5821+
// Create coroutine exit block and branch to it.
5822+
auto coroEndBB = IGF.createBasicBlock("coro.end.normal");
5823+
IGF.setCoroutineExitBlock(coroEndBB);
5824+
Builder.CreateBr(coroEndBB);
5825+
5826+
// Emit the block.
5827+
Builder.emitBlock(coroEndBB);
5828+
auto handle = IGF.getCoroutineHandle();
5829+
5830+
llvm::Value *resultToken = nullptr;
5831+
if (result.empty()) {
5832+
assert(IGM.getTypeInfo(returnResultType)
5833+
.nativeReturnValueSchema(IGM)
5834+
.empty() &&
5835+
"Empty explosion must match the native calling convention");
5836+
// No results: just use none token
5837+
resultToken = llvm::ConstantTokenNone::get(Builder.getContext());
5838+
} else {
5839+
// Capture results via `coro_end_results` intrinsic
5840+
result = IGF.coerceValueTo(returnResultType, result, funcResultType);
5841+
auto &nativeSchema =
5842+
IGM.getTypeInfo(funcResultType).nativeReturnValueSchema(IGM);
5843+
assert(!nativeSchema.requiresIndirect());
5844+
5845+
Explosion native = nativeSchema.mapIntoNative(IGM, IGF, result,
5846+
funcResultType,
5847+
false /* isOutlined */);
5848+
SmallVector<llvm::Value *, 1> args;
5849+
for (unsigned i = 0, e = native.size(); i != e; ++i)
5850+
args.push_back(native.claimNext());
5851+
5852+
resultToken =
5853+
Builder.CreateIntrinsicCall(llvm::Intrinsic::coro_end_results, args);
5854+
}
5855+
5856+
Builder.CreateIntrinsicCall(llvm::Intrinsic::coro_end,
5857+
{handle,
5858+
/*is unwind*/ Builder.getFalse(),
5859+
resultToken});
5860+
Builder.CreateUnreachable();
5861+
}
5862+
58065863
FunctionPointer
58075864
IRGenFunction::getFunctionPointerForResumeIntrinsic(llvm::Value *resume) {
58085865
auto *fnTy = llvm::FunctionType::get(

lib/IRGen/GenCall.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,8 @@ namespace irgen {
266266
SILType funcResultTypeInContext,
267267
CanSILFunctionType fnType, Explosion &result,
268268
Explosion &error);
269+
void emitYieldOnceCoroutineResult(IRGenFunction &IGF, Explosion &result,
270+
SILType funcResultType, SILType returnResultType);
269271

270272
Address emitAutoDiffCreateLinearMapContextWithType(
271273
IRGenFunction &IGF, llvm::Value *topLevelSubcontextMetatype);

lib/IRGen/IRGenFunction.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -709,7 +709,7 @@ void IRGenFunction::emitAwaitAsyncContinuation(
709709
// because the continuation result is not available yet. When the
710710
// continuation is later resumed, the task will get scheduled
711711
// starting from the suspension point.
712-
emitCoroutineOrAsyncExit();
712+
emitCoroutineOrAsyncExit(false);
713713
}
714714

715715
Builder.emitBlock(contBB);

lib/IRGen/IRGenFunction.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,16 @@ class IRGenFunction {
155155
CoroutineHandle = handle;
156156
}
157157

158+
llvm::BasicBlock *getCoroutineExitBlock() const {
159+
return CoroutineExitBlock;
160+
}
161+
162+
void setCoroutineExitBlock(llvm::BasicBlock *block) {
163+
assert(CoroutineExitBlock == nullptr && "already set exit BB");
164+
assert(block != nullptr && "setting a null exit BB");
165+
CoroutineExitBlock = block;
166+
}
167+
158168
llvm::Value *getAsyncTask();
159169
llvm::Value *getAsyncContext();
160170
void storeCurrentAsyncContext(llvm::Value *context);
@@ -236,7 +246,7 @@ class IRGenFunction {
236246
bool callsAnyAlwaysInlineThunksWithForeignExceptionTraps = false;
237247

238248
public:
239-
void emitCoroutineOrAsyncExit();
249+
void emitCoroutineOrAsyncExit(bool isUnwind);
240250

241251
//--- Helper methods -----------------------------------------------------------
242252
public:

0 commit comments

Comments
 (0)