@@ -30,7 +30,7 @@ class PartitionOp {
30
30
PartitionOpKind OpKind;
31
31
llvm::SmallVector<unsigned , 2 > OpArgs;
32
32
33
- // record the SILInstruction that this PartitionOp was generated from, if
33
+ // Record the SILInstruction that this PartitionOp was generated from, if
34
34
// generated during compilation from a SILBasicBlock
35
35
SILInstruction *sourceInst;
36
36
@@ -75,7 +75,7 @@ class PartitionOp {
75
75
return sourceInst;
76
76
}
77
77
78
- void dump () const {
78
+ void dump () const LLVM_ATTRIBUTE_USED {
79
79
switch (OpKind) {
80
80
case PartitionOpKind::Assign:
81
81
llvm::dbgs () << " assign %" << OpArgs[0 ] << " = %" << OpArgs[1 ] << " \n " ;
@@ -110,6 +110,7 @@ static void horizontalUpdate(std::map<unsigned, signed> &map, unsigned key,
110
110
}
111
111
112
112
signed oldVal = map[key];
113
+ if (val == oldVal) return ;
113
114
114
115
for (auto [otherKey, otherVal] : map)
115
116
if (otherVal == oldVal)
@@ -118,20 +119,49 @@ static void horizontalUpdate(std::map<unsigned, signed> &map, unsigned key,
118
119
119
120
class Partition {
120
121
private:
121
- // label each index with a non-negative (unsigned) label if it is associated
122
+ // Label each index with a non-negative (unsigned) label if it is associated
122
123
// with a valid region, and with -1 if it is associated with a consumed region
123
- // in-order traversal relied upon
124
+ // in-order traversal relied upon.
124
125
std::map<unsigned , signed > labels;
125
126
126
- // track a label that is guaranteed to be fresh
127
+ // Track a label that is guaranteed to be strictly larger than all in use,
128
+ // and therefore safe for use as a fresh label.
127
129
unsigned fresh_label = 0 ;
128
130
129
- // in a canonical partition, all regions are labelled with the smallest index
131
+ // In a canonical partition, all regions are labelled with the smallest index
130
132
// of any member. Certain operations like join and equals rely on canonicality
131
133
// so when it's invalidated this boolean tracks that, and it must be
132
- // reestablished by a call to canonicalize()
134
+ // reestablished by a call to canonicalize().
133
135
bool canonical;
134
136
137
+ // Used only in assertions, check that Partitions promised to be canonical
138
+ // are actually canonical
139
+ bool is_canonical_correct () {
140
+ if (!canonical) return true ; // vacuously correct
141
+
142
+ auto fail = [&](unsigned i, int type) {
143
+ llvm::dbgs () << " FAIL(i=" << i << " ; type=" << type << " ): " ;
144
+ dump ();
145
+ return false ;
146
+ };
147
+
148
+ for (auto &[i, label] : labels) {
149
+ // correctness vacuous at consumed indices
150
+ if (label < 0 ) continue ;
151
+
152
+ // this label should not exceed fresh_label
153
+ if (label >= fresh_label) return fail (i, 0 );
154
+
155
+ // the label of a region should be at most as large as each index in it
156
+ if (i < label) return fail (i, 1 );
157
+
158
+ // each region label should refer to an index in that region
159
+ if (labels[label] != label) return fail (i, 2 );
160
+ }
161
+
162
+ return true ;
163
+ }
164
+
135
165
// linear time - For each region label that occurs, find the first index
136
166
// at which it occurs and relabel all instances of it to that index.
137
167
// This excludes the -1 label for consumed regions.
@@ -163,11 +193,32 @@ class Partition {
163
193
// set fresh_label
164
194
fresh_label = i + 1 ;
165
195
}
196
+
197
+ assert (is_canonical_correct ());
198
+ }
199
+
200
+ // linear time - merge the regions of two indices, maintaining canonicality
201
+ void merge (unsigned fst, unsigned snd) {
202
+ assert (labels.count (fst) && labels.count (snd));
203
+ if (labels[fst] == labels[snd])
204
+ return ;
205
+
206
+ // maintain canonicality by renaming the greater-numbered region
207
+ if (labels[fst] < labels[snd])
208
+ horizontalUpdate (labels, snd, labels[fst]);
209
+ else
210
+ horizontalUpdate (labels, fst, labels[snd]);
211
+
212
+ assert (is_canonical_correct ());
166
213
}
167
214
168
215
public:
169
216
Partition () : labels({}), canonical(true ) {}
170
217
218
+ // 1-arg constructor used when canonicality will be immediately invalidated,
219
+ // so set to false to begin with
220
+ Partition (bool canonical) : labels({}), canonical(canonical) {}
221
+
171
222
static Partition singleRegion (std::vector<unsigned > indices) {
172
223
Partition p;
173
224
if (!indices.empty ()) {
@@ -177,17 +228,9 @@ class Partition {
177
228
p.labels [index] = min_index;
178
229
}
179
230
}
180
- return p;
181
- }
182
231
183
- void dump () const {
184
- llvm::dbgs () << " Partition" ;
185
- if (canonical)
186
- llvm::dbgs () << " (canonical)" ;
187
- llvm::dbgs () << " {" ;
188
- for (const auto &[i, label] : labels)
189
- llvm::dbgs () << " [" << i << " : " << label << " ] " ;
190
- llvm::dbgs () << " }\n " ;
232
+ assert (p.is_canonical_correct ());
233
+ return p;
191
234
}
192
235
193
236
// linear time - Test two partititons for equality by first putting them
@@ -204,54 +247,52 @@ class Partition {
204
247
// and two indices are in the same region of the join iff they are in the same
205
248
// region in either operand.
206
249
static Partition join (Partition &fst, Partition &snd) {
207
- fst.canonicalize ();
208
- snd.canonicalize ();
209
-
210
- std::map<unsigned , signed > relabel_fst;
211
- std::map<unsigned , signed > relabel_snd;
212
- auto lookup_fst = [&](unsigned i) {
213
- // signed to unsigned conversion... ?
214
- return relabel_fst.count (fst.labels [i]) ? relabel_fst[fst.labels [i]]
215
- : fst.labels [i];
216
- };
217
-
218
- auto lookup_snd = [&](unsigned i) {
219
- // signed to unsigned conversion... safe?
220
- return relabel_snd.count (snd.labels [i]) ? relabel_snd[snd.labels [i]]
221
- : snd.labels [i];
222
- };
223
-
224
- for (const auto &[i, _] : fst.labels ) {
225
- // only consider indices present in both fst and snd
226
- if (!snd.labels .count (i))
227
- continue ;
228
-
229
- signed label_joined = std::min (lookup_fst (i), lookup_snd (i));
230
-
231
- horizontalUpdate (relabel_fst, fst.labels [i], label_joined);
232
- horizontalUpdate (relabel_snd, snd.labels [i], label_joined);
250
+ // ensure copies are made
251
+ Partition fst_reduced = false ;
252
+ Partition snd_reduced = false ;
253
+
254
+ // make canonical copies of fst and snd, reduced to their intersected domain
255
+ for (auto [i, _] : fst.labels )
256
+ if (snd.labels .count (i)) {
257
+ fst_reduced.labels [i] = fst.labels [i];
258
+ snd_reduced.labels [i] = snd.labels [i];
259
+ }
260
+ fst_reduced.canonicalize ();
261
+ snd_reduced.canonicalize ();
262
+
263
+ // merging each index in fst with its label in snd ensures that all pairs
264
+ // of indices that are in the same region in snd are also in the same region
265
+ // in fst - the desired property
266
+ for (const auto [i, snd_label] : snd_reduced.labels ) {
267
+ if (snd_label < 0 )
268
+ // if snd says that the region has been consumed, mark it consumed in fst
269
+ horizontalUpdate (fst_reduced.labels , i, -1 );
270
+ else
271
+ fst_reduced.merge (i, snd_label);
233
272
}
234
273
235
- Partition joined;
236
- joined.canonical = true ;
237
- for (const auto &[i, _] : fst.labels ) {
238
- if (!snd.labels .count (i))
239
- continue ;
240
- signed label_i = lookup_fst (i);
241
- joined.labels [i] = label_i;
242
- joined.fresh_label = std::max (joined.fresh_label , (unsigned ) label_i + 1 );
243
- }
274
+ assert (fst_reduced.is_canonical_correct ());
244
275
245
- return joined;
276
+ // fst_reduced is now the join
277
+ return fst_reduced;
246
278
}
247
279
248
280
// Apply the passed PartitionOp to this partition, performing its action.
249
281
// A `handleFailure` closure can optionally be passed in that will be called
250
282
// if a consumed region is required. The closure is given the PartitionOp that
251
283
// failed, and the index of the SIL value that was required but consumed.
284
+ // Additionally, a list of "nonconsumable" indices can be passed in along with
285
+ // a handleConsumeNonConsumable closure. In the event that a region containing
286
+ // one of the nonconsumable indices is consumed, the closure will be called
287
+ // with the offending Consume.
252
288
void apply (
253
- PartitionOp op, llvm::function_ref<void (const PartitionOp&, unsigned )> handleFailure =
254
- [](const PartitionOp&, unsigned ) {}) {
289
+ PartitionOp op,
290
+ llvm::function_ref<void (const PartitionOp&, unsigned )>
291
+ handleFailure = [](const PartitionOp&, unsigned ) {},
292
+ std::vector<unsigned > nonconsumables = {},
293
+ llvm::function_ref<void (const PartitionOp&, unsigned )>
294
+ handleConsumeNonConsumable = [](const PartitionOp&, unsigned ) {}
295
+ ) {
255
296
switch (op.OpKind ) {
256
297
case PartitionOpKind::Assign:
257
298
assert (op.OpArgs .size () == 2 &&
@@ -290,6 +331,19 @@ class Partition {
290
331
291
332
// mark region as consumed
292
333
horizontalUpdate (labels, op.OpArgs [0 ], -1 );
334
+
335
+ // check if any nonconsumables were consumed, and handle the failure if so
336
+ for (unsigned nonconsumable : nonconsumables) {
337
+ assert (labels.count (nonconsumable) &&
338
+ " nonconsumables should be function args and self, and therefore"
339
+ " always present in the label map because of initialization at "
340
+ " entry" );
341
+ if (labels[nonconsumable] < 0 ) {
342
+ handleConsumeNonConsumable (op, nonconsumable);
343
+ break ;
344
+ }
345
+ }
346
+
293
347
break ;
294
348
case PartitionOpKind::Merge:
295
349
assert (op.OpArgs .size () == 2 &&
@@ -302,14 +356,7 @@ class Partition {
302
356
if (labels[op.OpArgs [1 ]] < 0 )
303
357
handleFailure (op, op.OpArgs [1 ]);
304
358
305
- if (labels[op.OpArgs [0 ]] == labels[op.OpArgs [1 ]])
306
- break ;
307
-
308
- // maintain canonicality by renaming the greater-numbered region
309
- if (labels[op.OpArgs [0 ]] < labels[op.OpArgs [1 ]])
310
- horizontalUpdate (labels, op.OpArgs [1 ], labels[op.OpArgs [0 ]]);
311
- else
312
- horizontalUpdate (labels, op.OpArgs [0 ], labels[op.OpArgs [1 ]]);
359
+ merge (op.OpArgs [0 ], op.OpArgs [1 ]);
313
360
break ;
314
361
case PartitionOpKind::Require:
315
362
assert (op.OpArgs .size () == 1 &&
@@ -319,9 +366,21 @@ class Partition {
319
366
if (labels[op.OpArgs [0 ]] < 0 )
320
367
handleFailure (op, op.OpArgs [0 ]);
321
368
}
369
+
370
+ assert (is_canonical_correct ());
371
+ }
372
+
373
+ void dump_labels () const LLVM_ATTRIBUTE_USED {
374
+ llvm::dbgs () << " Partition" ;
375
+ if (canonical)
376
+ llvm::dbgs () << " (canonical)" ;
377
+ llvm::dbgs () << " (fresh=" << fresh_label << " ){" ;
378
+ for (const auto &[i, label] : labels)
379
+ llvm::dbgs () << " [" << i << " : " << label << " ] " ;
380
+ llvm::dbgs () << " }\n " ;
322
381
}
323
382
324
- void dump () {
383
+ void dump () LLVM_ATTRIBUTE_USED {
325
384
std::map<signed , std::vector<unsigned >> buckets;
326
385
327
386
for (auto [i, label] : labels) {
@@ -331,12 +390,15 @@ class Partition {
331
390
llvm::dbgs () << " [" ;
332
391
for (auto [label, indices] : buckets) {
333
392
llvm::dbgs () << (label < 0 ? " {" : " (" );
393
+ int j = 0 ;
334
394
for (unsigned i : indices) {
335
- llvm::dbgs () << i << " " ;
395
+ llvm::dbgs () << (j++? " " : " " ) << i ;
336
396
}
337
397
llvm::dbgs () << (label < 0 ? " }" : " )" );
338
398
}
339
- llvm::dbgs () << " ]\n " ;
399
+ llvm::dbgs () << " ] | " ;
400
+
401
+ dump_labels ();
340
402
}
341
403
};
342
404
}
0 commit comments