Skip to content

Commit 8ddbc8b

Browse files
Merge pull request #72144 from nate-chandler/partial-consumption/20240306/1
[NoncopyablePartialConsumption] Allow consume.
2 parents ed745a5 + e40581a commit 8ddbc8b

10 files changed

+615
-15
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7741,6 +7741,16 @@ ERROR(moveOnly_not_allowed_here,none,
77417741
"'@_moveOnly' attribute is only valid on structs or enums", ())
77427742
ERROR(consume_expression_not_passed_lvalue,none,
77437743
"'consume' can only be applied to a local binding ('let', 'var', or parameter)", ())
7744+
ERROR(consume_expression_partial_copyable,none,
7745+
"'consume' can only be used to partially consume storage of a noncopyable type", ())
7746+
ERROR(consume_expression_non_storage,none,
7747+
"'consume' can only be used to partially consume storage", ())
7748+
NOTE(note_consume_expression_non_storage_call,none,
7749+
"non-storage produced by this call", ())
7750+
NOTE(note_consume_expression_non_storage_subscript,none,
7751+
"non-storage produced by this subscript", ())
7752+
NOTE(note_consume_expression_non_storage_property,none,
7753+
"non-storage produced by this computed property", ())
77447754
ERROR(borrow_expression_not_passed_lvalue,none,
77457755
"'borrow' can only be applied to a local binding ('let', 'var', or parameter)", ())
77467756
ERROR(copy_expression_not_passed_lvalue,none,

lib/SILOptimizer/Mandatory/MoveOnlyAddressCheckerTester.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ class MoveOnlyAddressCheckerTesterPass : public SILFunctionTransform {
9494
llvm::SmallSetVector<MarkUnresolvedNonCopyableValueInst *, 32>
9595
moveIntroducersToProcess;
9696
searchForCandidateAddressMarkUnresolvedNonCopyableValueInsts(
97-
fn, moveIntroducersToProcess, diagnosticEmitter);
97+
fn, getAnalysis<PostOrderAnalysis>(), moveIntroducersToProcess,
98+
diagnosticEmitter);
9899

99100
LLVM_DEBUG(llvm::dbgs()
100101
<< "Emitting diagnostic when checking for mark must check inst: "

lib/SILOptimizer/Mandatory/MoveOnlyAddressCheckerUtils.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -479,14 +479,14 @@ static bool isCopyableValue(SILValue value) {
479479

480480
void swift::siloptimizer::
481481
searchForCandidateAddressMarkUnresolvedNonCopyableValueInsts(
482-
SILFunction *fn,
482+
SILFunction *fn, PostOrderAnalysis *poa,
483483
llvm::SmallSetVector<MarkUnresolvedNonCopyableValueInst *, 32>
484484
&moveIntroducersToProcess,
485485
DiagnosticEmitter &diagnosticEmitter) {
486-
for (auto &block : *fn) {
487-
for (auto ii = block.begin(), ie = block.end(); ii != ie;) {
488-
auto *mmci = dyn_cast<MarkUnresolvedNonCopyableValueInst>(&*ii);
489-
++ii;
486+
auto *po = poa->get(fn);
487+
for (auto *block : po->getPostOrder()) {
488+
for (auto &ii : llvm::make_range(block->rbegin(), block->rend())) {
489+
auto *mmci = dyn_cast<MarkUnresolvedNonCopyableValueInst>(&ii);
490490

491491
if (!mmci || !mmci->hasMoveCheckerKind() || !mmci->getType().isAddress())
492492
continue;
@@ -3877,6 +3877,9 @@ bool MoveOnlyAddressChecker::check(
38773877
// If we fail the address check in some way, set the diagnose!
38783878
diagnosticEmitter.emitCheckerDoesntUnderstandDiagnostic(markedValue);
38793879
}
3880+
3881+
markedValue->replaceAllUsesWith(markedValue->getOperand());
3882+
markedValue->eraseFromParent();
38803883
}
38813884

38823885
if (DumpSILBeforeRemovingMarkUnresolvedNonCopyableValueInst) {

lib/SILOptimizer/Mandatory/MoveOnlyAddressCheckerUtils.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
namespace swift {
1919

20+
class PostOrderAnalysis;
21+
2022
namespace siloptimizer {
2123

2224
class DiagnosticEmitter;
@@ -26,7 +28,7 @@ class DiagnosticEmitter;
2628
/// NOTE: To see if we emitted a diagnostic, use \p
2729
/// diagnosticEmitter.getDiagnosticCount().
2830
void searchForCandidateAddressMarkUnresolvedNonCopyableValueInsts(
29-
SILFunction *fn,
31+
SILFunction *fn, PostOrderAnalysis *poa,
3032
llvm::SmallSetVector<MarkUnresolvedNonCopyableValueInst *, 32>
3133
&moveIntroducersToProcess,
3234
DiagnosticEmitter &diagnosticEmitter);

lib/SILOptimizer/Mandatory/MoveOnlyChecker.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ void MoveOnlyChecker::checkAddresses() {
186186
llvm::SmallSetVector<MarkUnresolvedNonCopyableValueInst *, 32>
187187
moveIntroducersToProcess;
188188
searchForCandidateAddressMarkUnresolvedNonCopyableValueInsts(
189-
fn, moveIntroducersToProcess, diagnosticEmitter);
189+
fn, poa, moveIntroducersToProcess, diagnosticEmitter);
190190

191191
LLVM_DEBUG(
192192
llvm::dbgs()

lib/SILOptimizer/Mandatory/MoveOnlyTempAllocationFromLetTester.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ struct MoveOnlyTempAllocationFromLetTester : SILFunctionTransform {
5151

5252
unsigned diagCount = diagnosticEmitter.getDiagnosticCount();
5353
searchForCandidateAddressMarkUnresolvedNonCopyableValueInsts(
54-
getFunction(), moveIntroducersToProcess, diagnosticEmitter);
54+
getFunction(), getAnalysis<PostOrderAnalysis>(),
55+
moveIntroducersToProcess, diagnosticEmitter);
5556

5657
// Return early if we emitted a diagnostic.
5758
if (diagCount != diagnosticEmitter.getDiagnosticCount())

lib/Sema/MiscDiagnostics.cpp

Lines changed: 81 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -420,12 +420,89 @@ static void diagSyntacticUseRestrictions(const Expr *E, const DeclContext *DC,
420420
}
421421

422422
void checkConsumeExpr(ConsumeExpr *consumeExpr) {
423-
auto *subExpr = consumeExpr->getSubExpr();
424-
if (auto *li = dyn_cast<LoadExpr>(subExpr))
425-
subExpr = li->getSubExpr();
426-
if (!isa<DeclRefExpr>(subExpr)) {
423+
auto partialConsumptionEnabled =
424+
Ctx.LangOpts.hasFeature(Feature::MoveOnlyPartialConsumption);
425+
426+
bool noncopyable = false;
427+
bool partial = false;
428+
Expr *current = consumeExpr->getSubExpr();
429+
while (current) {
430+
if (current->getType()->getCanonicalType()->isNoncopyable()) {
431+
noncopyable = true;
432+
}
433+
if (auto *dre = dyn_cast<DeclRefExpr>(current)) {
434+
if (partial & !noncopyable) {
435+
Ctx.Diags.diagnose(consumeExpr->getLoc(),
436+
diag::consume_expression_partial_copyable);
437+
return;
438+
}
439+
// The chain of member_ref_exprs and load_exprs terminates at a
440+
// declref_expr. This is legal.
441+
return;
442+
}
443+
// Look through loads.
444+
if (auto *le = dyn_cast<LoadExpr>(current)) {
445+
current = le->getSubExpr();
446+
continue;
447+
}
448+
auto *mre = dyn_cast<MemberRefExpr>(current);
449+
if (mre && partialConsumptionEnabled) {
450+
auto *vd = dyn_cast<VarDecl>(mre->getMember().getDecl());
451+
if (!vd) {
452+
Ctx.Diags.diagnose(consumeExpr->getLoc(),
453+
diag::consume_expression_non_storage);
454+
return;
455+
}
456+
partial = true;
457+
AccessStrategy strategy = vd->getAccessStrategy(
458+
mre->getAccessSemantics(), AccessKind::Read,
459+
DC->getParentModule(), ResilienceExpansion::Minimal);
460+
if (strategy.getKind() != AccessStrategy::Storage) {
461+
if (noncopyable) {
462+
Ctx.Diags.diagnose(consumeExpr->getLoc(),
463+
diag::consume_expression_non_storage);
464+
Ctx.Diags.diagnose(
465+
mre->getLoc(),
466+
diag::note_consume_expression_non_storage_property);
467+
} else {
468+
Ctx.Diags.diagnose(consumeExpr->getLoc(),
469+
diag::consume_expression_partial_copyable);
470+
}
471+
return;
472+
}
473+
current = mre->getBase();
474+
continue;
475+
}
476+
auto *ce = dyn_cast<CallExpr>(current);
477+
if (ce && partialConsumptionEnabled) {
478+
if (noncopyable) {
479+
Ctx.Diags.diagnose(consumeExpr->getLoc(),
480+
diag::consume_expression_non_storage);
481+
Ctx.Diags.diagnose(ce->getLoc(),
482+
diag::note_consume_expression_non_storage_call);
483+
} else {
484+
Ctx.Diags.diagnose(consumeExpr->getLoc(),
485+
diag::consume_expression_partial_copyable);
486+
}
487+
return;
488+
}
489+
auto *se = dyn_cast<SubscriptExpr>(current);
490+
if (se && partialConsumptionEnabled) {
491+
if (noncopyable) {
492+
Ctx.Diags.diagnose(consumeExpr->getLoc(),
493+
diag::consume_expression_non_storage);
494+
Ctx.Diags.diagnose(
495+
se->getLoc(),
496+
diag::note_consume_expression_non_storage_subscript);
497+
} else {
498+
Ctx.Diags.diagnose(consumeExpr->getLoc(),
499+
diag::consume_expression_partial_copyable);
500+
}
501+
return;
502+
}
427503
Ctx.Diags.diagnose(consumeExpr->getLoc(),
428504
diag::consume_expression_not_passed_lvalue);
505+
return;
429506
}
430507
}
431508

test/Interop/Cxx/operators/move-only/move-only-synthesized-property-typecheck.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ func testNonCopyableHolderConstDerefPointee() {
3030
holder.pointee.mutMethod(1) // expected-error {{cannot use mutating member on immutable value: 'pointee' is a get-only property}}
3131
holder.pointee.x = 2 // expected-error {{cannot assign to property: 'pointee' is a get-only property}}
3232
#else
33-
consumingNC(holder.pointee) // CHECK: [[@LINE]]:{{.*}}: error:
34-
let consumeVal = holder.pointee // CHECK: [[@LINE]]:{{.*}}: error:
33+
consumingNC(holder.pointee) // CHECK-DAG: [[@LINE]]:{{.*}}: error:
34+
let consumeVal = holder.pointee // CHECK-DAG: [[@LINE]]:{{.*}}: error:
3535
#endif
3636
}
3737

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
// RUN: %target-run-simple-swift(-parse-as-library -Xfrontend -sil-verify-all -enable-experimental-feature MoveOnlyPartialConsumption) | %FileCheck %s
2+
// RUN: %target-run-simple-swift(-parse-as-library -O -Xfrontend -sil-verify-all -enable-experimental-feature MoveOnlyPartialConsumption) | %FileCheck %s
3+
4+
// REQUIRES: executable_test
5+
6+
@main struct App { static func main() {
7+
test1((1,2,3,4))
8+
}}
9+
10+
func barrier() { print("barrier") }
11+
12+
struct Ur<T> : ~Copyable {
13+
var t: T
14+
var name: String
15+
init(_ t: T, named name: String) {
16+
self.t = t
17+
self.name = name
18+
print("hi", name)
19+
}
20+
deinit {
21+
print("bye", name)
22+
}
23+
}
24+
func take<T>(_ u: consuming Ur<T>) {}
25+
26+
struct Pair<T> : ~Copyable {
27+
var u1: Ur<T>
28+
var u2: Ur<T>
29+
init(_ t1: T, _ t2: T, named name: String) {
30+
u1 = .init(t1, named: "\(name).u1")
31+
u2 = .init(t2, named: "\(name).u2")
32+
}
33+
}
34+
func take<T>(_ u: consuming Pair<T>) {}
35+
36+
struct Quad<T> : ~Copyable {
37+
var p1: Pair<T>
38+
var p2: Pair<T>
39+
init(_ t1: T, _ t2: T, _ t3: T, _ t4: T, named name: String) {
40+
p1 = .init(t1, t2, named: "\(name).p1")
41+
p2 = .init(t3, t4, named: "\(name).p2")
42+
}
43+
}
44+
func take<T>(_ u: consuming Quad<T>) {}
45+
46+
func test1<T>(_ ts: (T, T, T, T)) {
47+
do {
48+
let q1 = Quad<T>(ts.0, ts.1, ts.2, ts.3, named: "\(#function).q1")
49+
barrier()
50+
// CHECK: barrier
51+
// CHECK: bye test1(_:).q1.p1.u1
52+
// CHECK: bye test1(_:).q1.p1.u2
53+
// CHECK: bye test1(_:).q1.p2.u1
54+
// CHECK: bye test1(_:).q1.p2.u2
55+
}
56+
57+
let q2 = Quad<T>(ts.0, ts.1, ts.2, ts.3, named: "\(#function).q2")
58+
take(q2.p2.u2)
59+
take(q2.p2.u1)
60+
barrier()
61+
take(q2.p1.u2)
62+
take(q2.p1.u1)
63+
// CHECK: bye test1(_:).q2.p2.u2
64+
// CHECK: bye test1(_:).q2.p2.u1
65+
// CHECK: barrier
66+
// CHECK: bye test1(_:).q2.p1.u2
67+
// CHECK: bye test1(_:).q2.p1.u1
68+
69+
let q3 = Quad<T>(ts.0, ts.1, ts.2, ts.3, named: "\(#function).q3")
70+
_ = consume q3.p2.u2
71+
_ = consume q3.p2.u1
72+
_ = consume q3.p1.u2
73+
barrier()
74+
_ = consume q3.p1.u1
75+
// CHECK: bye test1(_:).q3.p2.u2
76+
// CHECK: bye test1(_:).q3.p2.u1
77+
// CHECK: bye test1(_:).q3.p1.u2
78+
// CHECK: barrier
79+
// CHECK: bye test1(_:).q3.p1.u1
80+
81+
let q4 = Quad<T>(ts.0, ts.1, ts.2, ts.3, named: "\(#function).q4")
82+
_ = consume q4.p1.u1
83+
barrier()
84+
_ = consume q4.p2.u1
85+
_ = consume q4.p1.u2
86+
_ = consume q4.p2.u2
87+
// CHECK: bye test1(_:).q4.p1.u1
88+
// CHECK: barrier
89+
// CHECK: bye test1(_:).q4.p2.u1
90+
// CHECK: bye test1(_:).q4.p1.u2
91+
// CHECK: bye test1(_:).q4.p2.u2
92+
93+
do {
94+
let q5 = Quad<T>(ts.0, ts.1, ts.2, ts.3, named: "\(#function).q5")
95+
take(q5.p1.u1)
96+
_ = consume q5.p2.u1
97+
_ = consume q5.p1.u2
98+
take(q5.p2.u2)
99+
// CHECK: bye test1(_:).q5.p1.u1
100+
// CHECK: bye test1(_:).q5.p2.u1
101+
// CHECK: bye test1(_:).q5.p1.u2
102+
// CHECK: bye test1(_:).q5.p2.u2
103+
}
104+
do {
105+
let q6 = Quad<T>(ts.0, ts.1, ts.2, ts.3, named: "\(#function).q6")
106+
if Bool.random() {
107+
take(q6.p1.u1)
108+
_ = consume q6.p2.u1
109+
} else {
110+
_ = consume q6.p1.u2
111+
take(q6.p2.u2)
112+
}
113+
}
114+
}
115+

0 commit comments

Comments
 (0)