Skip to content

Commit 3f5278a

Browse files
asavonicasl
authored andcommitted
Partial apply for coroutines
The patch adds lowering of partial_apply instructions for coroutines. This pattern seems to trigger a lot of type mismatch errors in IRGen, because coroutine functions are not substituted in the same way as regular functions (see the patch 07f03bd "Use pattern substitutions to consistently abstract yields" for more details). The odd type conversions in the patch are related to this issue, and these should be checked carefully. Perhaps it is better to enable substitutions for coroutine functions instead (at least for some cases). Other than that, lowering of partial_apply for coroutines is straightforward: we generate another coroutine that captures arguments passed to the partial_apply instructions. It calls the original coroutine for yields (first return) and yields the resulting values. Then it calls the original function's continuation for return or unwind, and forwards them to the caller as well. After IRGen, LLVM's Coroutine pass transforms the generated coroutine (along with all other coroutines) and eliminates llvm.coro.* intrinsics. LIT tests check LLVM IR after this transformation.
1 parent b55651f commit 3f5278a

File tree

4 files changed

+2214
-16
lines changed

4 files changed

+2214
-16
lines changed

lib/IRGen/GenDecl.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6163,6 +6163,7 @@ IRGenModule::getAddrOfContinuationPrototype(CanSILFunctionType fnType) {
61636163
llvm::Function *&entry = GlobalFuncs[entity];
61646164
if (entry) return entry;
61656165

6166+
GenericContextScope scope(*this, fnType->getInvocationGenericSignature());
61666167
auto signature = Signature::forCoroutineContinuation(*this, fnType);
61676168
LinkInfo link = LinkInfo::get(*this, entity, NotForDefinition);
61686169
entry = createFunction(*this, link, signature);

lib/IRGen/GenFunc.cpp

Lines changed: 264 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
#include "llvm/Support/Debug.h"
9292

9393
#include "BitPatternBuilder.h"
94+
#include "CallEmission.h"
9495
#include "Callee.h"
9596
#include "ConstantBuilder.h"
9697
#include "EnumPayload.h"
@@ -109,11 +110,11 @@
109110
#include "HeapTypeInfo.h"
110111
#include "IRGenDebugInfo.h"
111112
#include "IRGenFunction.h"
113+
#include "IRGenMangler.h"
112114
#include "IRGenModule.h"
113115
#include "IndirectTypeInfo.h"
114116
#include "ScalarPairTypeInfo.h"
115117
#include "Signature.h"
116-
#include "IRGenMangler.h"
117118

118119
using namespace swift;
119120
using namespace irgen;
@@ -1119,11 +1120,11 @@ class PartialApplicationForwarderEmission {
11191120
virtual void addDynamicFunctionContext(Explosion &explosion) = 0;
11201121
virtual void addDynamicFunctionPointer(Explosion &explosion) = 0;
11211122

1122-
void addSelf(Explosion &explosion) { addArgument(explosion); }
1123-
void addWitnessSelfMetadata(llvm::Value *value) {
1123+
virtual void addSelf(Explosion &explosion) { addArgument(explosion); }
1124+
virtual void addWitnessSelfMetadata(llvm::Value *value) {
11241125
addArgument(value);
11251126
}
1126-
void addWitnessSelfWitnessTable(llvm::Value *value) {
1127+
virtual void addWitnessSelfWitnessTable(llvm::Value *value) {
11271128
addArgument(value);
11281129
}
11291130
virtual void forwardErrorResult() = 0;
@@ -1136,6 +1137,14 @@ class PartialApplicationForwarderEmission {
11361137
virtual void end(){};
11371138
virtual ~PartialApplicationForwarderEmission() {}
11381139
};
1140+
1141+
static Size getYieldOnceCoroutineBufferSize(IRGenModule &IGM) {
1142+
return NumWords_YieldOnceBuffer * IGM.getPointerSize();
1143+
}
1144+
static Alignment getYieldOnceCoroutineBufferAlignment(IRGenModule &IGM) {
1145+
return IGM.getPointerAlignment();
1146+
}
1147+
11391148
class SyncPartialApplicationForwarderEmission
11401149
: public PartialApplicationForwarderEmission {
11411150
using super = PartialApplicationForwarderEmission;
@@ -1422,6 +1431,239 @@ class AsyncPartialApplicationForwarderEmission
14221431
super::end();
14231432
}
14241433
};
1434+
1435+
class CoroPartialApplicationForwarderEmission
1436+
: public PartialApplicationForwarderEmission {
1437+
using super = PartialApplicationForwarderEmission;
1438+
1439+
private:
1440+
llvm::Value *Self;
1441+
llvm::Value *FirstData;
1442+
llvm::Value *SecondData;
1443+
WitnessMetadata Witness;
1444+
1445+
public:
1446+
CoroPartialApplicationForwarderEmission(
1447+
IRGenModule &IGM, IRGenFunction &subIGF, llvm::Function *fwd,
1448+
const std::optional<FunctionPointer> &staticFnPtr, bool calleeHasContext,
1449+
const Signature &origSig, CanSILFunctionType origType,
1450+
CanSILFunctionType substType, CanSILFunctionType outType,
1451+
SubstitutionMap subs, HeapLayout const *layout,
1452+
ArrayRef<ParameterConvention> conventions)
1453+
: PartialApplicationForwarderEmission(
1454+
IGM, subIGF, fwd, staticFnPtr, calleeHasContext, origSig, origType,
1455+
substType, outType, subs, layout, conventions),
1456+
Self(nullptr), FirstData(nullptr), SecondData(nullptr) {}
1457+
1458+
void begin() override {
1459+
auto prototype = subIGF.IGM.getOpaquePtr(
1460+
subIGF.IGM.getAddrOfContinuationPrototype(origType));
1461+
1462+
// Use malloc and free as our allocator.
1463+
auto allocFn = subIGF.IGM.getOpaquePtr(subIGF.IGM.getMallocFn());
1464+
auto deallocFn = subIGF.IGM.getOpaquePtr(subIGF.IGM.getFreeFn());
1465+
1466+
// Call the right 'llvm.coro.id.retcon' variant.
1467+
llvm::Value *buffer = origParams.claimNext();
1468+
llvm::Value *id = subIGF.Builder.CreateIntrinsicCall(
1469+
llvm::Intrinsic::coro_id_retcon_once,
1470+
{llvm::ConstantInt::get(
1471+
subIGF.IGM.Int32Ty,
1472+
getYieldOnceCoroutineBufferSize(subIGF.IGM).getValue()),
1473+
llvm::ConstantInt::get(
1474+
subIGF.IGM.Int32Ty,
1475+
getYieldOnceCoroutineBufferAlignment(subIGF.IGM).getValue()),
1476+
buffer, prototype, allocFn, deallocFn});
1477+
1478+
// Call 'llvm.coro.begin', just for consistency with the normal pattern.
1479+
// This serves as a handle that we can pass around to other intrinsics.
1480+
auto hdl = subIGF.Builder.CreateIntrinsicCall(
1481+
llvm::Intrinsic::coro_begin,
1482+
{id, llvm::ConstantPointerNull::get(subIGF.IGM.Int8PtrTy)});
1483+
1484+
// Set the coroutine handle; this also flags that is a coroutine so that
1485+
// e.g. dynamic allocas use the right code generation.
1486+
subIGF.setCoroutineHandle(hdl);
1487+
1488+
auto *pt = subIGF.Builder.IRBuilderBase::CreateAlloca(
1489+
subIGF.IGM.Int1Ty,
1490+
/*array size*/ nullptr, "earliest insert point");
1491+
subIGF.setEarliestInsertionPoint(pt);
1492+
}
1493+
1494+
void gatherArgumentsFromApply() override {
1495+
super::gatherArgumentsFromApply(false);
1496+
}
1497+
llvm::Value *getDynamicFunctionPointer() override {
1498+
llvm::Value *Ret = SecondData;
1499+
SecondData = nullptr;
1500+
return Ret;
1501+
}
1502+
llvm::Value *getDynamicFunctionContext() override {
1503+
llvm::Value *Ret = FirstData;
1504+
FirstData = nullptr;
1505+
return Ret;
1506+
}
1507+
void addDynamicFunctionContext(Explosion &explosion) override {
1508+
assert(!Self && "context value overrides 'self'");
1509+
FirstData = explosion.claimNext();
1510+
}
1511+
void addDynamicFunctionPointer(Explosion &explosion) override {
1512+
SecondData = explosion.claimNext();
1513+
}
1514+
void addSelf(Explosion &explosion) override {
1515+
assert(!FirstData && "'self' overrides another context value");
1516+
if (!hasSelfContextParameter(origType)) {
1517+
// witness methods can be declared on types that are not classes. Pass
1518+
// such "self" argument as a plain argument.
1519+
addArgument(explosion);
1520+
return;
1521+
}
1522+
Self = explosion.claimNext();
1523+
FirstData = Self;
1524+
}
1525+
1526+
void addWitnessSelfMetadata(llvm::Value *value) override {
1527+
Witness.SelfMetadata = value;
1528+
}
1529+
1530+
void addWitnessSelfWitnessTable(llvm::Value *value) override {
1531+
Witness.SelfWitnessTable = value;
1532+
}
1533+
1534+
void forwardErrorResult() override {
1535+
bool isTypedError = origConv.isTypedError();
1536+
SILType origErrorTy =
1537+
origConv.getSILErrorType(subIGF.IGM.getMaximalTypeExpansionContext());
1538+
auto errorAlignment =
1539+
isTypedError ? subIGF.IGM.getPointerAlignment()
1540+
: cast<FixedTypeInfo>(subIGF.getTypeInfo(origErrorTy))
1541+
.getFixedAlignment();
1542+
auto errorStorageType =
1543+
isTypedError ? IGM.Int8PtrTy
1544+
: cast<FixedTypeInfo>(subIGF.getTypeInfo(origErrorTy))
1545+
.getStorageType();
1546+
llvm::Value *errorResultPtr = origParams.claimNext();
1547+
subIGF.setCallerErrorResultSlot(
1548+
Address(errorResultPtr, errorStorageType, errorAlignment));
1549+
}
1550+
1551+
Explosion callCoroutine(FunctionPointer &fnPtr) {
1552+
Callee callee({origType, substType, subs}, fnPtr, FirstData, SecondData);
1553+
1554+
std::unique_ptr<CallEmission> emitSuspend =
1555+
getCallEmission(subIGF, Self, std::move(callee));
1556+
1557+
emitSuspend->begin();
1558+
emitSuspend->setArgs(args, /*isOutlined=*/false, &Witness);
1559+
Explosion yieldedValues;
1560+
emitSuspend->emitToExplosion(yieldedValues, /*isOutlined=*/false);
1561+
emitSuspend->end();
1562+
emitSuspend->claimTemporaries().destroyAll(subIGF);
1563+
1564+
if (origConv.getSILResultType(subIGF.IGM.getMaximalTypeExpansionContext())
1565+
.hasTypeParameter()) {
1566+
1567+
ArrayRef<llvm::Value *> yieldValues = yieldedValues.claimAll();
1568+
ArrayRef<llvm::Type *> retTypes =
1569+
cast<llvm::StructType>(fwd->getReturnType())->elements();
1570+
Explosion yieldCoerced;
1571+
assert(yieldValues.size() == retTypes.size() &&
1572+
"mismatch between return types of the wrapper and the callee");
1573+
for (unsigned i = 0; i < yieldValues.size(); ++i) {
1574+
llvm::Value *v = yieldValues[i];
1575+
if (v->getType() != retTypes[i]) {
1576+
v = subIGF.coerceValue(v, retTypes[i], subIGF.IGM.DataLayout);
1577+
}
1578+
yieldCoerced.add(v);
1579+
}
1580+
return yieldCoerced;
1581+
}
1582+
1583+
return yieldedValues;
1584+
}
1585+
1586+
llvm::CallInst *createCall(FunctionPointer &fnPtr) override {
1587+
/// Call the wrapped coroutine
1588+
///
1589+
Address calleeBuf = emitAllocYieldOnceCoroutineBuffer(subIGF);
1590+
llvm::Value *calleeHandle = calleeBuf.getAddress();
1591+
args.insert(0, calleeHandle);
1592+
Explosion yieldedValues = callCoroutine(fnPtr);
1593+
1594+
/// Get the continuation function pointer
1595+
///
1596+
PointerAuthInfo newAuthInfo =
1597+
fnPtr.getAuthInfo().getCorrespondingCodeAuthInfo();
1598+
FunctionPointer contFn = FunctionPointer::createSigned(
1599+
FunctionPointer::Kind::Function, yieldedValues.claimNext(), newAuthInfo,
1600+
Signature::forCoroutineContinuation(subIGF.IGM, origType));
1601+
1602+
/// Forward the remaining yields of the wrapped coroutine
1603+
///
1604+
llvm::Value *condUnwind = emitYield(subIGF, substType, yieldedValues);
1605+
1606+
llvm::BasicBlock *unwindBB = subIGF.createBasicBlock("unwind");
1607+
llvm::BasicBlock *resumeBB = subIGF.createBasicBlock("resume");
1608+
llvm::BasicBlock *cleanupBB = subIGF.createBasicBlock("cleanup");
1609+
subIGF.CurFn->insert(subIGF.CurFn->end(), unwindBB);
1610+
subIGF.CurFn->insert(subIGF.CurFn->end(), resumeBB);
1611+
subIGF.CurFn->insert(subIGF.CurFn->end(), cleanupBB);
1612+
subIGF.Builder.CreateCondBr(condUnwind, unwindBB, resumeBB);
1613+
1614+
/// Call for the results
1615+
///
1616+
subIGF.Builder.SetInsertPoint(resumeBB);
1617+
1618+
auto isResume = llvm::ConstantInt::get(IGM.Int1Ty, /*isAbort*/ false);
1619+
auto *call = subIGF.Builder.CreateCall(contFn, {calleeHandle, isResume});
1620+
1621+
/// Emit coro_end for results and forward them
1622+
///
1623+
llvm::Type *callTy = call->getType();
1624+
llvm::Value *noneToken =
1625+
llvm::ConstantTokenNone::get(subIGF.Builder.getContext());
1626+
llvm::Value *resultToken = nullptr;
1627+
if (callTy->isVoidTy()) {
1628+
resultToken = noneToken;
1629+
} else if (llvm::StructType *sty = dyn_cast<llvm::StructType>(callTy)) {
1630+
Explosion splitCall;
1631+
subIGF.emitAllExtractValues(call, sty, splitCall);
1632+
resultToken = subIGF.Builder.CreateIntrinsicCall(
1633+
llvm::Intrinsic::coro_end_results, splitCall.claimAll());
1634+
} else {
1635+
resultToken = subIGF.Builder.CreateIntrinsicCall(
1636+
llvm::Intrinsic::coro_end_results, call);
1637+
}
1638+
1639+
llvm::Value *fwdHandle = subIGF.getCoroutineHandle();
1640+
subIGF.Builder.CreateIntrinsicCall(llvm::Intrinsic::coro_end,
1641+
{fwdHandle, isResume, resultToken});
1642+
subIGF.Builder.CreateBr(cleanupBB);
1643+
1644+
/// Emit coro_end for unwind
1645+
///
1646+
subIGF.Builder.SetInsertPoint(unwindBB);
1647+
auto isUnwind = llvm::ConstantInt::get(IGM.Int1Ty, /*isAbort*/ true);
1648+
subIGF.Builder.CreateCall(contFn, {calleeHandle, isUnwind});
1649+
subIGF.Builder.CreateIntrinsicCall(llvm::Intrinsic::coro_end,
1650+
{fwdHandle, isUnwind, noneToken});
1651+
subIGF.Builder.CreateBr(cleanupBB);
1652+
1653+
subIGF.Builder.SetInsertPoint(cleanupBB);
1654+
emitDeallocYieldOnceCoroutineBuffer(subIGF, calleeBuf);
1655+
llvm::Instruction *cleanupPt = subIGF.Builder.CreateUnreachable();
1656+
subIGF.Builder.SetInsertPoint(cleanupPt);
1657+
1658+
return nullptr;
1659+
}
1660+
1661+
void createReturn(llvm::CallInst *call) override {
1662+
// Do nothing, yield/return/unwind blocks are already created in createCall.
1663+
}
1664+
void end() override { super::end(); }
1665+
};
1666+
14251667
std::unique_ptr<PartialApplicationForwarderEmission>
14261668
getPartialApplicationForwarderEmission(
14271669
IRGenModule &IGM, IRGenFunction &subIGF, llvm::Function *fwd,
@@ -1434,6 +1676,11 @@ getPartialApplicationForwarderEmission(
14341676
return std::make_unique<AsyncPartialApplicationForwarderEmission>(
14351677
IGM, subIGF, fwd, staticFnPtr, calleeHasContext, origSig, origType,
14361678
substType, outType, subs, layout, conventions);
1679+
} else if (origType->isCoroutine()) {
1680+
return std::make_unique<CoroPartialApplicationForwarderEmission>(
1681+
IGM, subIGF, fwd, staticFnPtr, calleeHasContext, origSig, origType,
1682+
substType, outType, subs, layout, conventions);
1683+
14371684
} else {
14381685
return std::make_unique<SyncPartialApplicationForwarderEmission>(
14391686
IGM, subIGF, fwd, staticFnPtr, calleeHasContext, origSig, origType,
@@ -1611,8 +1858,6 @@ static llvm::Value *emitPartialApplicationForwarder(
16111858
ti.isSingleSwiftRetainablePointer(ResilienceExpansion::Maximal))
16121859
ref = subIGF.coerceValue(rawData, ti.getStorageType(),
16131860
subIGF.IGM.DataLayout);
1614-
else
1615-
ref = subIGF.Builder.CreateBitCast(rawData, ti.getStorageType());
16161861
param.add(ref);
16171862
bindPolymorphicParameter(subIGF, origType, substType, param, paramI);
16181863
(void)param.claimAll();
@@ -1687,7 +1932,7 @@ static llvm::Value *emitPartialApplicationForwarder(
16871932
auto argIndex = emission->getCurrentArgumentIndex();
16881933
if (haveContextArgument)
16891934
argIndex += polyArgs.size();
1690-
if (origType->isAsync())
1935+
if (origType->isAsync() || origType->isCoroutine())
16911936
argIndex += 1;
16921937

16931938
llvm::Type *expectedArgTy = origSig.getType()->getParamType(argIndex);
@@ -1712,10 +1957,15 @@ static llvm::Value *emitPartialApplicationForwarder(
17121957
} else {
17131958
argValue = subIGF.Builder.CreateBitCast(rawData, expectedArgTy);
17141959
}
1715-
emission->addArgument(argValue);
1960+
if (haveContextArgument) {
1961+
Explosion e;
1962+
e.add(argValue);
1963+
emission->addDynamicFunctionContext(e);
1964+
} else
1965+
emission->addArgument(argValue);
17161966

1717-
// If there's a data pointer required, grab it and load out the
1718-
// extra, previously-curried parameters.
1967+
// If there's a data pointer required, grab it and load out the
1968+
// extra, previously-curried parameters.
17191969
} else {
17201970
unsigned origParamI = outType->getParameters().size();
17211971
unsigned extraFieldIndex = 0;
@@ -1975,8 +2225,8 @@ static llvm::Value *emitPartialApplicationForwarder(
19752225

19762226
llvm::CallInst *call = emission->createCall(fnPtr);
19772227

1978-
if (!origType->isAsync() && addressesToDeallocate.empty() && !needsAllocas &&
1979-
(!consumesContext || !dependsOnContextLifetime))
2228+
if (!origType->isAsync() && !origType->isCoroutine() && addressesToDeallocate.empty() &&
2229+
!needsAllocas && (!consumesContext || !dependsOnContextLifetime))
19802230
call->setTailCall();
19812231

19822232
// Deallocate everything we allocated above.
@@ -2030,7 +2280,7 @@ std::optional<StackAddress> irgen::emitFunctionPartialApplication(
20302280
bool considerParameterSources = true;
20312281
for (auto param : params) {
20322282
SILType argType = IGF.IGM.silConv.getSILType(
2033-
param, origType, IGF.IGM.getMaximalTypeExpansionContext());
2283+
param, substType, IGF.IGM.getMaximalTypeExpansionContext());
20342284
auto argLoweringTy = getArgumentLoweringType(argType.getASTType(), param,
20352285
outType->isNoEscape());
20362286
auto &ti = IGF.getTypeInfoForLowered(argLoweringTy);
@@ -2043,7 +2293,7 @@ std::optional<StackAddress> irgen::emitFunctionPartialApplication(
20432293

20442294
auto addParam = [&](SILParameterInfo param) {
20452295
SILType argType = IGF.IGM.silConv.getSILType(
2046-
param, origType, IGF.IGM.getMaximalTypeExpansionContext());
2296+
param, substType, IGF.IGM.getMaximalTypeExpansionContext());
20472297

20482298
auto argLoweringTy = getArgumentLoweringType(argType.getASTType(), param,
20492299
outType->isNoEscape());

lib/IRGen/IRGenSIL.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6867,8 +6867,12 @@ void IRGenSILFunction::visitConvertFunctionInst(swift::ConvertFunctionInst *i) {
68676867
fnType->getRepresentation() != SILFunctionType::Representation::Block) {
68686868
auto *fn = temp.claimNext();
68696869
Explosion res;
6870-
auto sig = IGM.getSignature(fnType);
6871-
res.add(Builder.CreateBitCast(fn, sig.getType()->getPointerTo()));
6870+
auto &fnTI = IGM.getTypeInfoForLowered(fnType);
6871+
auto &fnNative = fnTI.nativeReturnValueSchema(IGM);
6872+
llvm::Value *newFn =
6873+
Builder.CreateBitCast(fn, fnNative.getExpandedType(IGM));
6874+
extractScalarResults(*this, newFn->getType(), newFn, res);
6875+
68726876
setLoweredExplosion(i, res);
68736877
return;
68746878
}

0 commit comments

Comments
 (0)