Skip to content

Commit 5b7e375

Browse files
authored
Merge pull request #42480 from eeckstein/improve-cow-opts
COWOpts: handle `struct_extract` and `struct` instructions.
2 parents dbf1cb2 + ecfb431 commit 5b7e375

File tree

2 files changed

+51
-5
lines changed

2 files changed

+51
-5
lines changed

lib/SILOptimizer/Transforms/COWOpts.cpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,22 @@ void COWOptsPass::run() {
103103
}
104104
}
105105

106+
static SILValue skipStructAndExtract(SILValue value) {
107+
while (true) {
108+
if (auto *si = dyn_cast<StructInst>(value)) {
109+
if (si->getNumOperands() != 1)
110+
return value;
111+
value = si->getOperand(0);
112+
continue;
113+
}
114+
if (auto *sei = dyn_cast<StructExtractInst>(value)) {
115+
value = sei->getOperand();
116+
continue;
117+
}
118+
return value;
119+
}
120+
}
121+
106122
bool COWOptsPass::optimizeBeginCOW(BeginCOWMutationInst *BCM) {
107123
VoidPointerSet handled;
108124
SmallVector<SILValue, 8> workList;
@@ -112,7 +128,7 @@ bool COWOptsPass::optimizeBeginCOW(BeginCOWMutationInst *BCM) {
112128
// looking through block phi-arguments.
113129
workList.push_back(BCM->getOperand());
114130
while (!workList.empty()) {
115-
SILValue v = workList.pop_back_val();
131+
SILValue v = skipStructAndExtract(workList.pop_back_val());
116132
if (SILPhiArgument *arg = dyn_cast<SILPhiArgument>(v)) {
117133
if (handled.insert(arg).second) {
118134
SmallVector<SILValue, 4> incomingVals;
@@ -241,7 +257,9 @@ void COWOptsPass::collectEscapePoints(SILValue v,
241257
escapePoints, handled);
242258
break;
243259
case SILInstructionKind::StructInst:
260+
case SILInstructionKind::StructExtractInst:
244261
case SILInstructionKind::TupleInst:
262+
case SILInstructionKind::TupleExtractInst:
245263
case SILInstructionKind::UncheckedRefCastInst:
246264
collectEscapePoints(cast<SingleValueInstruction>(user),
247265
escapePoints, handled);

test/SILOptimizer/cow_opts.sil

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ final class Buffer {
1111
init()
1212
}
1313

14+
struct Str {
15+
@_hasStorage var b: Buffer { get set }
16+
}
17+
1418
sil @unknown : $@convention(thin) (@guaranteed Buffer) -> ()
1519

1620
// CHECK-LABEL: sil @test_simple
@@ -92,11 +96,35 @@ bb0(%0 : $*Buffer):
9296
sil @test_loop : $@convention(thin) (@owned Buffer) -> (Builtin.Int1, @owned Buffer) {
9397
bb0(%0 : $Buffer):
9498
%e = end_cow_mutation %0 : $Buffer
95-
br bb1(%e : $Buffer)
96-
bb1(%a : $Buffer):
97-
(%u, %b) = begin_cow_mutation %a : $Buffer
99+
%s1 = struct $Str (%e: $Buffer)
100+
br bb1(%s1 : $Str)
101+
bb1(%a : $Str):
102+
%as = struct_extract %a : $Str, #Str.b
103+
(%u, %b) = begin_cow_mutation %as : $Buffer
98104
%e2 = end_cow_mutation %b : $Buffer
99-
cond_br undef, bb1(%e2 : $Buffer), bb2
105+
%s2 = struct $Str (%e2: $Buffer)
106+
cond_br undef, bb1(%s2 : $Str), bb2
107+
bb2:
108+
%t = tuple (%u : $Builtin.Int1, %e2 : $Buffer)
109+
return %t : $(Builtin.Int1, Buffer)
110+
}
111+
112+
// CHECK-LABEL: sil @not_all_incoming_values_are_end_cow_mutation
113+
// CHECK: ([[U:%[0-9]+]], {{.*}}) = begin_cow_mutation
114+
// CHECK: [[B:%[0-9]+]] = end_cow_mutation
115+
// CHECK: [[T:%[0-9]+]] = tuple ([[U]] : $Builtin.Int1, [[B]] : $Buffer)
116+
// CHECK: return [[T]]
117+
// CHECK: } // end sil function 'not_all_incoming_values_are_end_cow_mutation'
118+
sil @not_all_incoming_values_are_end_cow_mutation : $@convention(thin) (@owned Buffer) -> (Builtin.Int1, @owned Buffer) {
119+
bb0(%0 : $Buffer):
120+
%s1 = struct $Str (%0: $Buffer)
121+
br bb1(%s1 : $Str)
122+
bb1(%a : $Str):
123+
%as = struct_extract %a : $Str, #Str.b
124+
(%u, %b) = begin_cow_mutation %as : $Buffer
125+
%e2 = end_cow_mutation %b : $Buffer
126+
%s2 = struct $Str (%e2: $Buffer)
127+
cond_br undef, bb1(%s2 : $Str), bb2
100128
bb2:
101129
%t = tuple (%u : $Builtin.Int1, %e2 : $Buffer)
102130
return %t : $(Builtin.Int1, Buffer)

0 commit comments

Comments
 (0)