Skip to content

Commit 0365c6a

Browse files
committed
MoveOnlyAddressUtils: Fixes for borrowing switch over address-only types.
Relax some existing pattern matches and add some unhandled instructions to the walkers so that borrowing switches over address-only enums are properly analyzed for incorrect consumption. Add a `[strict]` flag to `mark_unresolved_move_only_value` to indicate a borrow access that should remain a borrow access even if the subject is later stack-promoted from a box.
1 parent f832ba2 commit 0365c6a

15 files changed

+370
-54
lines changed

include/swift/SIL/SILBuilder.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1445,9 +1445,11 @@ class SILBuilder {
14451445

14461446
MarkUnresolvedNonCopyableValueInst *createMarkUnresolvedNonCopyableValueInst(
14471447
SILLocation loc, SILValue src,
1448-
MarkUnresolvedNonCopyableValueInst::CheckKind kind) {
1448+
MarkUnresolvedNonCopyableValueInst::CheckKind kind,
1449+
MarkUnresolvedNonCopyableValueInst::IsStrict_t strict
1450+
= MarkUnresolvedNonCopyableValueInst::IsNotStrict) {
14491451
return insert(new (getModule()) MarkUnresolvedNonCopyableValueInst(
1450-
getSILDebugLocation(loc), src, kind));
1452+
getSILDebugLocation(loc), src, kind, strict));
14511453
}
14521454

14531455
MarkUnresolvedReferenceBindingInst *createMarkUnresolvedReferenceBindingInst(

include/swift/SIL/SILInstruction.h

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8659,15 +8659,31 @@ class MarkUnresolvedNonCopyableValueInst
86598659
/// like class initializers.
86608660
InitableButNotConsumable,
86618661
};
8662+
8663+
/// During SILGen, we have not yet done escape analysis on local variables,
8664+
/// so we conservatively emit them as boxed and let the AllocBoxToStack
8665+
/// pass promote unescaped local variables. As part of this promotion,
8666+
/// non-strict `NoConsumeOrAssign` accesses can be promoted to
8667+
/// `ConsumableAndAssignable` since the variable is locally owned
8668+
/// if it doesn't escape. "Strict" accesses on the other hand preserve
8669+
/// their stricter access constraints. This is useful for representing things
8670+
/// like `borrow` bindings.
8671+
enum IsStrict_t : bool {
8672+
IsNotStrict = false,
8673+
IsStrict = true,
8674+
};
86628675

86638676
private:
86648677
CheckKind kind;
8678+
IsStrict_t strict;
86658679

86668680
MarkUnresolvedNonCopyableValueInst(SILDebugLocation DebugLoc,
8667-
SILValue operand, CheckKind checkKind)
8681+
SILValue operand, CheckKind checkKind,
8682+
IsStrict_t strict = IsNotStrict)
86688683
: UnaryInstructionBase(DebugLoc, operand, operand->getType(),
86698684
operand->getOwnershipKind()),
8670-
kind(checkKind) {
8685+
kind(checkKind),
8686+
strict(strict) {
86718687
assert(operand->getType().isMoveOnly() &&
86728688
"mark_unresolved_non_copyable_value can only take a move only typed "
86738689
"value");
@@ -8689,6 +8705,10 @@ class MarkUnresolvedNonCopyableValueInst
86898705
return true;
86908706
}
86918707
}
8708+
8709+
IsStrict_t isStrict() const {
8710+
return strict;
8711+
}
86928712
};
86938713

86948714
/// A marker instruction that states a given alloc_box or alloc_stack is a

lib/SIL/IR/SILPrinter.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2114,6 +2114,9 @@ class SILPrinter : public SILInstructionVisitor<SILPrinter> {
21142114
void visitMarkUnresolvedNonCopyableValueInst(
21152115
MarkUnresolvedNonCopyableValueInst *I) {
21162116
using CheckKind = MarkUnresolvedNonCopyableValueInst::CheckKind;
2117+
if (I->isStrict()) {
2118+
*this << "[strict] ";
2119+
}
21172120
switch (I->getCheckKind()) {
21182121
case CheckKind::Invalid:
21192122
llvm_unreachable("Invalid?!");

lib/SIL/Parser/ParseSIL.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3521,6 +3521,16 @@ bool SILParser::parseSpecificSILInstruction(SILBuilder &B,
35213521
P.diagnose(InstLoc.getSourceLoc(), diag);
35223522
return true;
35233523
}
3524+
3525+
auto Strict = MarkUnresolvedNonCopyableValueInst::IsNotStrict;
3526+
if (AttrName.equals("strict")) {
3527+
Strict = MarkUnresolvedNonCopyableValueInst::IsStrict;
3528+
if (!parseSILOptional(AttrName, *this)) {
3529+
auto diag = diag::sil_markmustcheck_requires_attribute;
3530+
P.diagnose(InstLoc.getSourceLoc(), diag);
3531+
return true;
3532+
}
3533+
}
35243534

35253535
using CheckKind = MarkUnresolvedNonCopyableValueInst::CheckKind;
35263536
CheckKind CKind =
@@ -3545,7 +3555,8 @@ bool SILParser::parseSpecificSILInstruction(SILBuilder &B,
35453555
if (parseSILDebugLocation(InstLoc, B))
35463556
return true;
35473557

3548-
auto *MVI = B.createMarkUnresolvedNonCopyableValueInst(InstLoc, Val, CKind);
3558+
auto *MVI = B.createMarkUnresolvedNonCopyableValueInst(InstLoc, Val, CKind,
3559+
Strict);
35493560
ResultVal = MVI;
35503561
break;
35513562
}

lib/SIL/Utils/FieldSensitivePrunedLiveness.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,9 @@ SubElementOffset::computeForAddress(SILValue projectionDerivedFromRoot,
202202
// really do not want to abort. Instead, our caller can choose to abort if
203203
// they get back a None. This ensures that we do not abort in cases where we
204204
// just want to emit to the user a "I do not understand" error.
205+
LLVM_DEBUG(llvm::dbgs() << "unhandled projection derived from root:\n";
206+
projectionDerivedFromRoot->print(llvm::dbgs()));
207+
205208
return llvm::None;
206209
}
207210
}

lib/SILGen/SILGenBuilder.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,13 +1089,14 @@ ManagedValue SILGenBuilder::createGuaranteedCopyableToMoveOnlyWrapperValue(
10891089

10901090
ManagedValue SILGenBuilder::createMarkUnresolvedNonCopyableValueInst(
10911091
SILLocation loc, ManagedValue value,
1092-
MarkUnresolvedNonCopyableValueInst::CheckKind kind) {
1092+
MarkUnresolvedNonCopyableValueInst::CheckKind kind,
1093+
MarkUnresolvedNonCopyableValueInst::IsStrict_t strict) {
10931094
assert((value.isPlusOne(SGF) || value.isLValue() ||
10941095
value.getType().isAddress()) &&
10951096
"Argument must be at +1 or be an address!");
10961097
CleanupCloner cloner(*this, value);
10971098
auto *mdi = SILBuilder::createMarkUnresolvedNonCopyableValueInst(
1098-
loc, value.forward(getSILGenFunction()), kind);
1099+
loc, value.forward(getSILGenFunction()), kind, strict);
10991100
return cloner.clone(mdi);
11001101
}
11011102

lib/SILGen/SILGenBuilder.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,9 @@ class SILGenBuilder : public SILBuilder {
492492
using SILBuilder::createMarkUnresolvedNonCopyableValueInst;
493493
ManagedValue createMarkUnresolvedNonCopyableValueInst(
494494
SILLocation loc, ManagedValue value,
495-
MarkUnresolvedNonCopyableValueInst::CheckKind kind);
495+
MarkUnresolvedNonCopyableValueInst::CheckKind kind,
496+
MarkUnresolvedNonCopyableValueInst::IsStrict_t strict
497+
= MarkUnresolvedNonCopyableValueInst::IsNotStrict);
496498

497499
using SILBuilder::emitCopyValueOperation;
498500
ManagedValue emitCopyValueOperation(SILLocation Loc, ManagedValue v);

lib/SILGen/SILGenPattern.cpp

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1420,8 +1420,12 @@ void PatternMatchEmission::bindBorrow(Pattern *pattern, VarDecl *var,
14201420
// Create a notional copy for the borrow checker to use.
14211421
bindValue = bindValue.copy(SGF, pattern);
14221422
}
1423+
// We mark the borrow check as "strict" because we don't want to allow
1424+
// consumes through the binding, even if the original value manages to be
1425+
// stack promoted during AllocBoxToStack or anything like that.
14231426
bindValue = SGF.B.createMarkUnresolvedNonCopyableValueInst(pattern, bindValue,
1424-
MarkUnresolvedNonCopyableValueInst::CheckKind::NoConsumeOrAssign);
1427+
MarkUnresolvedNonCopyableValueInst::CheckKind::NoConsumeOrAssign,
1428+
MarkUnresolvedNonCopyableValueInst::IsStrict);
14251429

14261430
SGF.VarLocs[var] = SILGenFunction::VarLoc::get(bindValue.getValue());
14271431
}
@@ -3198,6 +3202,26 @@ static void switchCaseStmtSuccessCallback(SILGenFunction &SGF,
31983202
SGF.Cleanups.emitBranchAndCleanups(sharedDest, caseBlock, args);
31993203
}
32003204

3205+
class EndAccessCleanup final : public Cleanup {
3206+
SILValue beginAccess;
3207+
public:
3208+
EndAccessCleanup(SILValue beginAccess)
3209+
: beginAccess(beginAccess)
3210+
{}
3211+
3212+
void emit(SILGenFunction &SGF, CleanupLocation loc, ForUnwind_t forUnwind)
3213+
override {
3214+
SGF.B.createEndAccess(loc, beginAccess, /*aborted*/ false);
3215+
}
3216+
3217+
void dump(SILGenFunction &SGF) const override {
3218+
llvm::errs() << "EndAccessCleanup\n";
3219+
if (beginAccess) {
3220+
beginAccess->print(llvm::errs());
3221+
}
3222+
}
3223+
};
3224+
32013225
void SILGenFunction::emitSwitchStmt(SwitchStmt *S) {
32023226
LLVM_DEBUG(llvm::dbgs() << "emitting switch stmt\n";
32033227
S->dump(llvm::dbgs());
@@ -3385,12 +3409,22 @@ void SILGenFunction::emitSwitchStmt(SwitchStmt *S) {
33853409
if (!subjectMV.isPlusZero()) {
33863410
subjectMV = subjectMV.borrow(*this, S);
33873411
}
3388-
if (subjectMV.getType().isAddress() &&
3389-
subjectMV.getType().isLoadable(F)) {
3390-
// Load a borrow if the type is loadable.
3391-
subjectMV = subjectUndergoesFormalAccess
3392-
? B.createFormalAccessLoadBorrow(S, subjectMV)
3393-
: B.createLoadBorrow(S, subjectMV);
3412+
if (subjectMV.getType().isAddress()) {
3413+
if (subjectMV.getType().isLoadable(F)) {
3414+
// Load a borrow if the type is loadable.
3415+
subjectMV = subjectUndergoesFormalAccess
3416+
? B.createFormalAccessLoadBorrow(S, subjectMV)
3417+
: B.createLoadBorrow(S, subjectMV);
3418+
} else {
3419+
// Initiate a read access on the memory, to ensure that even
3420+
// if the underlying memory is mutable or consumable, the pattern
3421+
// match is not allowed to modify it.
3422+
auto access = B.createBeginAccess(S, subjectMV.getValue(),
3423+
SILAccessKind::Read,
3424+
SILAccessEnforcement::Static, false, false);
3425+
Cleanups.pushCleanup<EndAccessCleanup>(access);
3426+
subjectMV = ManagedValue::forBorrowedAddressRValue(access);
3427+
}
33943428
}
33953429
return {subjectMV, CastConsumptionKind::BorrowAlways};
33963430

0 commit comments

Comments
 (0)