Skip to content

Commit 11416bc

Browse files
author
jturcotti
committed
Tweak, improve, and debug the PartitionAnalysis engine until a fairly comprehensive suite of simple tests passes (region_based_sendability.swift)
1 parent 29bd728 commit 11416bc

File tree

6 files changed

+699
-184
lines changed

6 files changed

+699
-184
lines changed

include/swift/AST/DiagnosticsSIL.def

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -856,11 +856,12 @@ NOTE(sil_referencebinding_inout_binding_here, none,
856856
"inout binding here",
857857
())
858858

859-
// SendNonSendable checking
860-
861-
WARNING(send_non_sendable, none,
862-
"Non-Sendable value %0 consumed, then used at this site; could be race",
863-
(unsigned))
859+
// Warnings arising from the flow-sensitive checking of Sendability of
860+
// non-Sendable values
861+
WARNING(consumed_value_used, none,
862+
"Non-Sendable value consumed, then used at this site; could yield race with another thread", ())
863+
WARNING(arg_region_consumed, none,
864+
"This application could pass `self` or a Non-Sendable argument of this function to another thread, potentially yielding a race with the caller", ())
864865

865866
#define UNDEFINE_DIAGNOSTIC_MACROS
866867
#include "DefineDiagnosticMacros.h"

include/swift/SILOptimizer/Utils/PartitionUtils.h

Lines changed: 128 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class PartitionOp {
3030
PartitionOpKind OpKind;
3131
llvm::SmallVector<unsigned, 2> OpArgs;
3232

33-
// record the SILInstruction that this PartitionOp was generated from, if
33+
// Record the SILInstruction that this PartitionOp was generated from, if
3434
// generated during compilation from a SILBasicBlock
3535
SILInstruction *sourceInst;
3636

@@ -75,7 +75,7 @@ class PartitionOp {
7575
return sourceInst;
7676
}
7777

78-
void dump() const {
78+
void dump() const LLVM_ATTRIBUTE_USED {
7979
switch (OpKind) {
8080
case PartitionOpKind::Assign:
8181
llvm::dbgs() << "assign %" << OpArgs[0] << " = %" << OpArgs[1] << "\n";
@@ -110,6 +110,7 @@ static void horizontalUpdate(std::map<unsigned, signed> &map, unsigned key,
110110
}
111111

112112
signed oldVal = map[key];
113+
if (val == oldVal) return;
113114

114115
for (auto [otherKey, otherVal] : map)
115116
if (otherVal == oldVal)
@@ -118,20 +119,49 @@ static void horizontalUpdate(std::map<unsigned, signed> &map, unsigned key,
118119

119120
class Partition {
120121
private:
121-
// label each index with a non-negative (unsigned) label if it is associated
122+
// Label each index with a non-negative (unsigned) label if it is associated
122123
// with a valid region, and with -1 if it is associated with a consumed region
123-
// in-order traversal relied upon
124+
// in-order traversal relied upon.
124125
std::map<unsigned, signed> labels;
125126

126-
// track a label that is guaranteed to be fresh
127+
// Track a label that is guaranteed to be strictly larger than all in use,
128+
// and therefore safe for use as a fresh label.
127129
unsigned fresh_label = 0;
128130

129-
// in a canonical partition, all regions are labelled with the smallest index
131+
// In a canonical partition, all regions are labelled with the smallest index
130132
// of any member. Certain operations like join and equals rely on canonicality
131133
// so when it's invalidated this boolean tracks that, and it must be
132-
// reestablished by a call to canonicalize()
134+
// reestablished by a call to canonicalize().
133135
bool canonical;
134136

137+
// Used only in assertions, check that Partitions promised to be canonical
138+
// are actually canonical
139+
bool is_canonical_correct() {
140+
if (!canonical) return true; // vacuously correct
141+
142+
auto fail = [&](unsigned i, int type) {
143+
llvm::dbgs() << "FAIL(i=" << i << "; type=" << type << "): ";
144+
dump();
145+
return false;
146+
};
147+
148+
for (auto &[i, label] : labels) {
149+
// correctness vacuous at consumed indices
150+
if (label < 0) continue;
151+
152+
// this label should not exceed fresh_label
153+
if (label >= fresh_label) return fail(i, 0);
154+
155+
// the label of a region should be at most as large as each index in it
156+
if (i < label) return fail(i, 1);
157+
158+
// each region label should refer to an index in that region
159+
if (labels[label] != label) return fail(i, 2);
160+
}
161+
162+
return true;
163+
}
164+
135165
// linear time - For each region label that occurs, find the first index
136166
// at which it occurs and relabel all instances of it to that index.
137167
// This excludes the -1 label for consumed regions.
@@ -163,11 +193,32 @@ class Partition {
163193
// set fresh_label
164194
fresh_label = i + 1;
165195
}
196+
197+
assert(is_canonical_correct());
198+
}
199+
200+
// linear time - merge the regions of two indices, maintaining canonicality
201+
void merge(unsigned fst, unsigned snd) {
202+
assert(labels.count(fst) && labels.count(snd));
203+
if (labels[fst] == labels[snd])
204+
return;
205+
206+
// maintain canonicality by renaming the greater-numbered region
207+
if (labels[fst] < labels[snd])
208+
horizontalUpdate(labels, snd, labels[fst]);
209+
else
210+
horizontalUpdate(labels, fst, labels[snd]);
211+
212+
assert(is_canonical_correct());
166213
}
167214

168215
public:
169216
Partition() : labels({}), canonical(true) {}
170217

218+
// 1-arg constructor used when canonicality will be immediately invalidated,
219+
// so set to false to begin with
220+
Partition(bool canonical) : labels({}), canonical(canonical) {}
221+
171222
static Partition singleRegion(std::vector<unsigned> indices) {
172223
Partition p;
173224
if (!indices.empty()) {
@@ -177,17 +228,9 @@ class Partition {
177228
p.labels[index] = min_index;
178229
}
179230
}
180-
return p;
181-
}
182231

183-
void dump() const {
184-
llvm::dbgs() << "Partition";
185-
if (canonical)
186-
llvm::dbgs() << "(canonical)";
187-
llvm::dbgs() << "{";
188-
for (const auto &[i, label] : labels)
189-
llvm::dbgs() << "[" << i << ": " << label << "] ";
190-
llvm::dbgs() << "}\n";
232+
assert(p.is_canonical_correct());
233+
return p;
191234
}
192235

193236
// linear time - Test two partititons for equality by first putting them
@@ -204,54 +247,52 @@ class Partition {
204247
// and two indices are in the same region of the join iff they are in the same
205248
// region in either operand.
206249
static Partition join(Partition &fst, Partition &snd) {
207-
fst.canonicalize();
208-
snd.canonicalize();
209-
210-
std::map<unsigned, signed> relabel_fst;
211-
std::map<unsigned, signed> relabel_snd;
212-
auto lookup_fst = [&](unsigned i) {
213-
// signed to unsigned conversion... ?
214-
return relabel_fst.count(fst.labels[i]) ? relabel_fst[fst.labels[i]]
215-
: fst.labels[i];
216-
};
217-
218-
auto lookup_snd = [&](unsigned i) {
219-
// signed to unsigned conversion... safe?
220-
return relabel_snd.count(snd.labels[i]) ? relabel_snd[snd.labels[i]]
221-
: snd.labels[i];
222-
};
223-
224-
for (const auto &[i, _] : fst.labels) {
225-
// only consider indices present in both fst and snd
226-
if (!snd.labels.count(i))
227-
continue;
228-
229-
signed label_joined = std::min(lookup_fst(i), lookup_snd(i));
230-
231-
horizontalUpdate(relabel_fst, fst.labels[i], label_joined);
232-
horizontalUpdate(relabel_snd, snd.labels[i], label_joined);
250+
//ensure copies are made
251+
Partition fst_reduced = false;
252+
Partition snd_reduced = false;
253+
254+
// make canonical copies of fst and snd, reduced to their intersected domain
255+
for (auto [i, _] : fst.labels)
256+
if (snd.labels.count(i)) {
257+
fst_reduced.labels[i] = fst.labels[i];
258+
snd_reduced.labels[i] = snd.labels[i];
259+
}
260+
fst_reduced.canonicalize();
261+
snd_reduced.canonicalize();
262+
263+
// merging each index in fst with its label in snd ensures that all pairs
264+
// of indices that are in the same region in snd are also in the same region
265+
// in fst - the desired property
266+
for (const auto [i, snd_label] : snd_reduced.labels) {
267+
if (snd_label < 0)
268+
// if snd says that the region has been consumed, mark it consumed in fst
269+
horizontalUpdate(fst_reduced.labels, i, -1);
270+
else
271+
fst_reduced.merge(i, snd_label);
233272
}
234273

235-
Partition joined;
236-
joined.canonical = true;
237-
for (const auto &[i, _] : fst.labels) {
238-
if (!snd.labels.count(i))
239-
continue;
240-
signed label_i = lookup_fst(i);
241-
joined.labels[i] = label_i;
242-
joined.fresh_label = std::max(joined.fresh_label, (unsigned) label_i + 1);
243-
}
274+
assert(fst_reduced.is_canonical_correct());
244275

245-
return joined;
276+
// fst_reduced is now the join
277+
return fst_reduced;
246278
}
247279

248280
// Apply the passed PartitionOp to this partition, performing its action.
249281
// A `handleFailure` closure can optionally be passed in that will be called
250282
// if a consumed region is required. The closure is given the PartitionOp that
251283
// failed, and the index of the SIL value that was required but consumed.
284+
// Additionally, a list of "nonconsumable" indices can be passed in along with
285+
// a handleConsumeNonConsumable closure. In the event that a region containing
286+
// one of the nonconsumable indices is consumed, the closure will be called
287+
// with the offending Consume.
252288
void apply(
253-
PartitionOp op, llvm::function_ref<void(const PartitionOp&, unsigned)> handleFailure =
254-
[](const PartitionOp&, unsigned) {}) {
289+
PartitionOp op,
290+
llvm::function_ref<void(const PartitionOp&, unsigned)>
291+
handleFailure = [](const PartitionOp&, unsigned) {},
292+
std::vector<unsigned> nonconsumables = {},
293+
llvm::function_ref<void(const PartitionOp&, unsigned)>
294+
handleConsumeNonConsumable = [](const PartitionOp&, unsigned) {}
295+
) {
255296
switch (op.OpKind) {
256297
case PartitionOpKind::Assign:
257298
assert(op.OpArgs.size() == 2 &&
@@ -290,6 +331,19 @@ class Partition {
290331

291332
// mark region as consumed
292333
horizontalUpdate(labels, op.OpArgs[0], -1);
334+
335+
// check if any nonconsumables were consumed, and handle the failure if so
336+
for (unsigned nonconsumable : nonconsumables) {
337+
assert(labels.count(nonconsumable) &&
338+
"nonconsumables should be function args and self, and therefore"
339+
"always present in the label map because of initialization at "
340+
"entry");
341+
if (labels[nonconsumable] < 0) {
342+
handleConsumeNonConsumable(op, nonconsumable);
343+
break;
344+
}
345+
}
346+
293347
break;
294348
case PartitionOpKind::Merge:
295349
assert(op.OpArgs.size() == 2 &&
@@ -302,14 +356,7 @@ class Partition {
302356
if (labels[op.OpArgs[1]] < 0)
303357
handleFailure(op, op.OpArgs[1]);
304358

305-
if (labels[op.OpArgs[0]] == labels[op.OpArgs[1]])
306-
break;
307-
308-
// maintain canonicality by renaming the greater-numbered region
309-
if (labels[op.OpArgs[0]] < labels[op.OpArgs[1]])
310-
horizontalUpdate(labels, op.OpArgs[1], labels[op.OpArgs[0]]);
311-
else
312-
horizontalUpdate(labels, op.OpArgs[0], labels[op.OpArgs[1]]);
359+
merge(op.OpArgs[0], op.OpArgs[1]);
313360
break;
314361
case PartitionOpKind::Require:
315362
assert(op.OpArgs.size() == 1 &&
@@ -319,9 +366,21 @@ class Partition {
319366
if (labels[op.OpArgs[0]] < 0)
320367
handleFailure(op, op.OpArgs[0]);
321368
}
369+
370+
assert(is_canonical_correct());
371+
}
372+
373+
void dump_labels() const LLVM_ATTRIBUTE_USED {
374+
llvm::dbgs() << "Partition";
375+
if (canonical)
376+
llvm::dbgs() << "(canonical)";
377+
llvm::dbgs() << "(fresh=" << fresh_label << "){";
378+
for (const auto &[i, label] : labels)
379+
llvm::dbgs() << "[" << i << ": " << label << "] ";
380+
llvm::dbgs() << "}\n";
322381
}
323382

324-
void dump() {
383+
void dump() LLVM_ATTRIBUTE_USED {
325384
std::map<signed, std::vector<unsigned>> buckets;
326385

327386
for (auto [i, label] : labels) {
@@ -331,12 +390,15 @@ class Partition {
331390
llvm::dbgs() << "[";
332391
for (auto [label, indices] : buckets) {
333392
llvm::dbgs() << (label < 0 ? "{" : "(");
393+
int j = 0;
334394
for (unsigned i : indices) {
335-
llvm::dbgs() << i << " ";
395+
llvm::dbgs() << (j++? " " : "") << i;
336396
}
337397
llvm::dbgs() << (label < 0 ? "}" : ")");
338398
}
339-
llvm::dbgs() << "]\n";
399+
llvm::dbgs() << "] | ";
400+
401+
dump_labels();
340402
}
341403
};
342404
}

0 commit comments

Comments
 (0)