9
9
10
10
namespace swift {
11
11
12
+ // PartitionOpKind represents the different kinds of PartitionOps that
13
+ // SILInstructions can be translated to
12
14
enum class PartitionOpKind : uint8_t {
13
15
// Assign one value to the region of another, takes two args, second arg
14
16
// must already be tracked with a non-consumed region
15
17
Assign,
18
+
16
19
// Assign one value to a fresh region, takes one arg.
17
20
AssignFresh,
18
- // Consume the region of a value, takes one arg
21
+
22
+ // Consume the region of a value, takes one arg. Region of arg must be
23
+ // non-consumed before the op.
19
24
Consume,
25
+
26
+ // Merge the regions of two values, takes two args, both must be from
27
+ // non-consumed regions.
20
28
Merge,
29
+
30
+ // Require the region of a value to be non-consumed, takes one arg.
21
31
Require
22
32
};
23
33
@@ -30,7 +40,7 @@ class PartitionOp {
30
40
PartitionOpKind OpKind;
31
41
llvm::SmallVector<unsigned , 2 > OpArgs;
32
42
33
- // record the SILInstruction that this PartitionOp was generated from, if
43
+ // Record the SILInstruction that this PartitionOp was generated from, if
34
44
// generated during compilation from a SILBasicBlock
35
45
SILInstruction *sourceInst;
36
46
@@ -75,7 +85,7 @@ class PartitionOp {
75
85
return sourceInst;
76
86
}
77
87
78
- void dump () const {
88
+ void dump () const LLVM_ATTRIBUTE_USED {
79
89
switch (OpKind) {
80
90
case PartitionOpKind::Assign:
81
91
llvm::dbgs () << " assign %" << OpArgs[0 ] << " = %" << OpArgs[1 ] << " \n " ;
@@ -110,6 +120,7 @@ static void horizontalUpdate(std::map<unsigned, signed> &map, unsigned key,
110
120
}
111
121
112
122
signed oldVal = map[key];
123
+ if (val == oldVal) return ;
113
124
114
125
for (auto [otherKey, otherVal] : map)
115
126
if (otherVal == oldVal)
@@ -118,20 +129,49 @@ static void horizontalUpdate(std::map<unsigned, signed> &map, unsigned key,
118
129
119
130
class Partition {
120
131
private:
121
- // label each index with a non-negative (unsigned) label if it is associated
132
+ // Label each index with a non-negative (unsigned) label if it is associated
122
133
// with a valid region, and with -1 if it is associated with a consumed region
123
- // in-order traversal relied upon
134
+ // in-order traversal relied upon.
124
135
std::map<unsigned , signed > labels;
125
136
126
- // track a label that is guaranteed to be fresh
137
+ // Track a label that is guaranteed to be strictly larger than all in use,
138
+ // and therefore safe for use as a fresh label.
127
139
unsigned fresh_label = 0 ;
128
140
129
- // in a canonical partition, all regions are labelled with the smallest index
141
+ // In a canonical partition, all regions are labelled with the smallest index
130
142
// of any member. Certain operations like join and equals rely on canonicality
131
143
// so when it's invalidated this boolean tracks that, and it must be
132
- // reestablished by a call to canonicalize()
144
+ // reestablished by a call to canonicalize().
133
145
bool canonical;
134
146
147
+ // Used only in assertions, check that Partitions promised to be canonical
148
+ // are actually canonical
149
+ bool is_canonical_correct () {
150
+ if (!canonical) return true ; // vacuously correct
151
+
152
+ auto fail = [&](unsigned i, int type) {
153
+ llvm::dbgs () << " FAIL(i=" << i << " ; type=" << type << " ): " ;
154
+ dump ();
155
+ return false ;
156
+ };
157
+
158
+ for (auto &[i, label] : labels) {
159
+ // correctness vacuous at consumed indices
160
+ if (label < 0 ) continue ;
161
+
162
+ // this label should not exceed fresh_label
163
+ if (label >= fresh_label) return fail (i, 0 );
164
+
165
+ // the label of a region should be at most as large as each index in it
166
+ if (i < label) return fail (i, 1 );
167
+
168
+ // each region label should refer to an index in that region
169
+ if (labels[label] != label) return fail (i, 2 );
170
+ }
171
+
172
+ return true ;
173
+ }
174
+
135
175
// linear time - For each region label that occurs, find the first index
136
176
// at which it occurs and relabel all instances of it to that index.
137
177
// This excludes the -1 label for consumed regions.
@@ -163,11 +203,32 @@ class Partition {
163
203
// set fresh_label
164
204
fresh_label = i + 1 ;
165
205
}
206
+
207
+ assert (is_canonical_correct ());
208
+ }
209
+
210
+ // linear time - merge the regions of two indices, maintaining canonicality
211
+ void merge (unsigned fst, unsigned snd) {
212
+ assert (labels.count (fst) && labels.count (snd));
213
+ if (labels[fst] == labels[snd])
214
+ return ;
215
+
216
+ // maintain canonicality by renaming the greater-numbered region
217
+ if (labels[fst] < labels[snd])
218
+ horizontalUpdate (labels, snd, labels[fst]);
219
+ else
220
+ horizontalUpdate (labels, fst, labels[snd]);
221
+
222
+ assert (is_canonical_correct ());
166
223
}
167
224
168
225
public:
169
226
Partition () : labels({}), canonical(true ) {}
170
227
228
+ // 1-arg constructor used when canonicality will be immediately invalidated,
229
+ // so set to false to begin with
230
+ Partition (bool canonical) : labels({}), canonical(canonical) {}
231
+
171
232
static Partition singleRegion (std::vector<unsigned > indices) {
172
233
Partition p;
173
234
if (!indices.empty ()) {
@@ -177,17 +238,9 @@ class Partition {
177
238
p.labels [index] = min_index;
178
239
}
179
240
}
180
- return p;
181
- }
182
241
183
- void dump () const {
184
- llvm::dbgs () << " Partition" ;
185
- if (canonical)
186
- llvm::dbgs () << " (canonical)" ;
187
- llvm::dbgs () << " {" ;
188
- for (const auto &[i, label] : labels)
189
- llvm::dbgs () << " [" << i << " : " << label << " ] " ;
190
- llvm::dbgs () << " }\n " ;
242
+ assert (p.is_canonical_correct ());
243
+ return p;
191
244
}
192
245
193
246
// linear time - Test two partititons for equality by first putting them
@@ -204,54 +257,52 @@ class Partition {
204
257
// and two indices are in the same region of the join iff they are in the same
205
258
// region in either operand.
206
259
static Partition join (Partition &fst, Partition &snd) {
207
- fst.canonicalize ();
208
- snd.canonicalize ();
209
-
210
- std::map<unsigned , signed > relabel_fst;
211
- std::map<unsigned , signed > relabel_snd;
212
- auto lookup_fst = [&](unsigned i) {
213
- // signed to unsigned conversion... ?
214
- return relabel_fst.count (fst.labels [i]) ? relabel_fst[fst.labels [i]]
215
- : fst.labels [i];
216
- };
217
-
218
- auto lookup_snd = [&](unsigned i) {
219
- // signed to unsigned conversion... safe?
220
- return relabel_snd.count (snd.labels [i]) ? relabel_snd[snd.labels [i]]
221
- : snd.labels [i];
222
- };
223
-
224
- for (const auto &[i, _] : fst.labels ) {
225
- // only consider indices present in both fst and snd
226
- if (!snd.labels .count (i))
227
- continue ;
228
-
229
- signed label_joined = std::min (lookup_fst (i), lookup_snd (i));
230
-
231
- horizontalUpdate (relabel_fst, fst.labels [i], label_joined);
232
- horizontalUpdate (relabel_snd, snd.labels [i], label_joined);
260
+ // ensure copies are made
261
+ Partition fst_reduced = false ;
262
+ Partition snd_reduced = false ;
263
+
264
+ // make canonical copies of fst and snd, reduced to their intersected domain
265
+ for (auto [i, _] : fst.labels )
266
+ if (snd.labels .count (i)) {
267
+ fst_reduced.labels [i] = fst.labels [i];
268
+ snd_reduced.labels [i] = snd.labels [i];
269
+ }
270
+ fst_reduced.canonicalize ();
271
+ snd_reduced.canonicalize ();
272
+
273
+ // merging each index in fst with its label in snd ensures that all pairs
274
+ // of indices that are in the same region in snd are also in the same region
275
+ // in fst - the desired property
276
+ for (const auto [i, snd_label] : snd_reduced.labels ) {
277
+ if (snd_label < 0 )
278
+ // if snd says that the region has been consumed, mark it consumed in fst
279
+ horizontalUpdate (fst_reduced.labels , i, -1 );
280
+ else
281
+ fst_reduced.merge (i, snd_label);
233
282
}
234
283
235
- Partition joined;
236
- joined.canonical = true ;
237
- for (const auto &[i, _] : fst.labels ) {
238
- if (!snd.labels .count (i))
239
- continue ;
240
- signed label_i = lookup_fst (i);
241
- joined.labels [i] = label_i;
242
- joined.fresh_label = std::max (joined.fresh_label , (unsigned ) label_i + 1 );
243
- }
284
+ assert (fst_reduced.is_canonical_correct ());
244
285
245
- return joined;
286
+ // fst_reduced is now the join
287
+ return fst_reduced;
246
288
}
247
289
248
290
// Apply the passed PartitionOp to this partition, performing its action.
249
291
// A `handleFailure` closure can optionally be passed in that will be called
250
292
// if a consumed region is required. The closure is given the PartitionOp that
251
293
// failed, and the index of the SIL value that was required but consumed.
294
+ // Additionally, a list of "nonconsumable" indices can be passed in along with
295
+ // a handleConsumeNonConsumable closure. In the event that a region containing
296
+ // one of the nonconsumable indices is consumed, the closure will be called
297
+ // with the offending Consume.
252
298
void apply (
253
- PartitionOp op, llvm::function_ref<void (const PartitionOp&, unsigned )> handleFailure =
254
- [](const PartitionOp&, unsigned ) {}) {
299
+ PartitionOp op,
300
+ llvm::function_ref<void (const PartitionOp&, unsigned )>
301
+ handleFailure = [](const PartitionOp&, unsigned ) {},
302
+ std::vector<unsigned > nonconsumables = {},
303
+ llvm::function_ref<void (const PartitionOp&, unsigned )>
304
+ handleConsumeNonConsumable = [](const PartitionOp&, unsigned ) {}
305
+ ) {
255
306
switch (op.OpKind ) {
256
307
case PartitionOpKind::Assign:
257
308
assert (op.OpArgs .size () == 2 &&
@@ -290,6 +341,19 @@ class Partition {
290
341
291
342
// mark region as consumed
292
343
horizontalUpdate (labels, op.OpArgs [0 ], -1 );
344
+
345
+ // check if any nonconsumables were consumed, and handle the failure if so
346
+ for (unsigned nonconsumable : nonconsumables) {
347
+ assert (labels.count (nonconsumable) &&
348
+ " nonconsumables should be function args and self, and therefore"
349
+ " always present in the label map because of initialization at "
350
+ " entry" );
351
+ if (labels[nonconsumable] < 0 ) {
352
+ handleConsumeNonConsumable (op, nonconsumable);
353
+ break ;
354
+ }
355
+ }
356
+
293
357
break ;
294
358
case PartitionOpKind::Merge:
295
359
assert (op.OpArgs .size () == 2 &&
@@ -302,14 +366,7 @@ class Partition {
302
366
if (labels[op.OpArgs [1 ]] < 0 )
303
367
handleFailure (op, op.OpArgs [1 ]);
304
368
305
- if (labels[op.OpArgs [0 ]] == labels[op.OpArgs [1 ]])
306
- break ;
307
-
308
- // maintain canonicality by renaming the greater-numbered region
309
- if (labels[op.OpArgs [0 ]] < labels[op.OpArgs [1 ]])
310
- horizontalUpdate (labels, op.OpArgs [1 ], labels[op.OpArgs [0 ]]);
311
- else
312
- horizontalUpdate (labels, op.OpArgs [0 ], labels[op.OpArgs [1 ]]);
369
+ merge (op.OpArgs [0 ], op.OpArgs [1 ]);
313
370
break ;
314
371
case PartitionOpKind::Require:
315
372
assert (op.OpArgs .size () == 1 &&
@@ -319,9 +376,21 @@ class Partition {
319
376
if (labels[op.OpArgs [0 ]] < 0 )
320
377
handleFailure (op, op.OpArgs [0 ]);
321
378
}
379
+
380
+ assert (is_canonical_correct ());
322
381
}
323
382
324
- void dump () {
383
+ void dump_labels () const LLVM_ATTRIBUTE_USED {
384
+ llvm::dbgs () << " Partition" ;
385
+ if (canonical)
386
+ llvm::dbgs () << " (canonical)" ;
387
+ llvm::dbgs () << " (fresh=" << fresh_label << " ){" ;
388
+ for (const auto &[i, label] : labels)
389
+ llvm::dbgs () << " [" << i << " : " << label << " ] " ;
390
+ llvm::dbgs () << " }\n " ;
391
+ }
392
+
393
+ void dump () LLVM_ATTRIBUTE_USED {
325
394
std::map<signed , std::vector<unsigned >> buckets;
326
395
327
396
for (auto [i, label] : labels) {
@@ -331,12 +400,15 @@ class Partition {
331
400
llvm::dbgs () << " [" ;
332
401
for (auto [label, indices] : buckets) {
333
402
llvm::dbgs () << (label < 0 ? " {" : " (" );
403
+ int j = 0 ;
334
404
for (unsigned i : indices) {
335
- llvm::dbgs () << i << " " ;
405
+ llvm::dbgs () << (j++? " " : " " ) << i ;
336
406
}
337
407
llvm::dbgs () << (label < 0 ? " }" : " )" );
338
408
}
339
- llvm::dbgs () << " ]\n " ;
409
+ llvm::dbgs () << " ] | " ;
410
+
411
+ dump_labels ();
340
412
}
341
413
};
342
414
}
0 commit comments