@@ -236,9 +236,75 @@ struct RequirementMachine::Implementation {
236
236
: Context(ctx),
237
237
System(Context),
238
238
Map(Context, System.getProtocols()) {}
239
+ void verify (const MutableTerm &term);
239
240
void dump (llvm::raw_ostream &out);
240
241
};
241
242
243
+ void RequirementMachine::Implementation::verify (const MutableTerm &term) {
244
+ #ifndef NDEBUG
245
+ MutableTerm erased;
246
+
247
+ // First, "erase" resolved associated types from the term, and try
248
+ // to simplify it again.
249
+ for (auto atom : term) {
250
+ if (erased.empty ()) {
251
+ switch (atom.getKind ()) {
252
+ case Atom::Kind::Protocol:
253
+ case Atom::Kind::GenericParam:
254
+ erased.add (atom);
255
+ continue ;
256
+
257
+ case Atom::Kind::AssociatedType:
258
+ erased.add (Atom::forProtocol (atom.getProtocols ()[0 ], Context));
259
+ break ;
260
+
261
+ case Atom::Kind::Name:
262
+ case Atom::Kind::Layout:
263
+ case Atom::Kind::Superclass:
264
+ case Atom::Kind::ConcreteType:
265
+ llvm::errs () << " Bad initial atom in " << term << " \n " ;
266
+ abort ();
267
+ break ;
268
+ }
269
+ }
270
+
271
+ switch (atom.getKind ()) {
272
+ case Atom::Kind::Name:
273
+ assert (!erased.empty ());
274
+ erased.add (atom);
275
+ break ;
276
+
277
+ case Atom::Kind::AssociatedType:
278
+ erased.add (Atom::forName (atom.getName (), Context));
279
+ break ;
280
+
281
+ case Atom::Kind::Protocol:
282
+ case Atom::Kind::GenericParam:
283
+ case Atom::Kind::Layout:
284
+ case Atom::Kind::Superclass:
285
+ case Atom::Kind::ConcreteType:
286
+ llvm::errs () << " Bad interior atom " << atom << " in " << term << " \n " ;
287
+ abort ();
288
+ break ;
289
+ }
290
+ }
291
+
292
+ MutableTerm simplified = erased;
293
+ System.simplify (simplified);
294
+
295
+ // We should end up with the same term.
296
+ if (simplified != term) {
297
+ llvm::errs () << " Term verification failed\n " ;
298
+ llvm::errs () << " Initial term: " << term << " \n " ;
299
+ llvm::errs () << " Erased term: " << erased << " \n " ;
300
+ llvm::errs () << " Simplified term: " << simplified << " \n " ;
301
+ llvm::errs () << " \n " ;
302
+ dump (llvm::errs ());
303
+ abort ();
304
+ }
305
+ #endif
306
+ }
307
+
242
308
void RequirementMachine::Implementation::dump (llvm::raw_ostream &out) {
243
309
out << " Requirement machine for " << Sig << " \n " ;
244
310
System.dump (out);
@@ -368,6 +434,7 @@ bool RequirementMachine::requiresClass(Type depType) const {
368
434
auto term = Impl->Context .getMutableTermForType (depType->getCanonicalType (),
369
435
/* proto=*/ nullptr );
370
436
Impl->System .simplify (term);
437
+ Impl->verify (term);
371
438
372
439
auto *equivClass = Impl->Map .lookUpEquivalenceClass (term);
373
440
if (!equivClass)
@@ -384,6 +451,7 @@ LayoutConstraint RequirementMachine::getLayoutConstraint(Type depType) const {
384
451
auto term = Impl->Context .getMutableTermForType (depType->getCanonicalType (),
385
452
/* proto=*/ nullptr );
386
453
Impl->System .simplify (term);
454
+ Impl->verify (term);
387
455
388
456
auto *equivClass = Impl->Map .lookUpEquivalenceClass (term);
389
457
if (!equivClass)
@@ -397,6 +465,7 @@ bool RequirementMachine::requiresProtocol(Type depType,
397
465
auto term = Impl->Context .getMutableTermForType (depType->getCanonicalType (),
398
466
/* proto=*/ nullptr );
399
467
Impl->System .simplify (term);
468
+ Impl->verify (term);
400
469
401
470
auto *equivClass = Impl->Map .lookUpEquivalenceClass (term);
402
471
if (!equivClass)
@@ -418,6 +487,7 @@ RequirementMachine::getRequiredProtocols(Type depType) const {
418
487
auto term = Impl->Context .getMutableTermForType (depType->getCanonicalType (),
419
488
/* proto=*/ nullptr );
420
489
Impl->System .simplify (term);
490
+ Impl->verify (term);
421
491
422
492
auto *equivClass = Impl->Map .lookUpEquivalenceClass (term);
423
493
if (!equivClass)
@@ -440,6 +510,7 @@ bool RequirementMachine::isConcreteType(Type depType) const {
440
510
auto term = Impl->Context .getMutableTermForType (depType->getCanonicalType (),
441
511
/* proto=*/ nullptr );
442
512
Impl->System .simplify (term);
513
+ Impl->verify (term);
443
514
444
515
auto *equivClass = Impl->Map .lookUpEquivalenceClass (term);
445
516
if (!equivClass)
@@ -453,10 +524,12 @@ bool RequirementMachine::areSameTypeParameterInContext(Type depType1,
453
524
auto term1 = Impl->Context .getMutableTermForType (depType1->getCanonicalType (),
454
525
/* proto=*/ nullptr );
455
526
Impl->System .simplify (term1);
527
+ Impl->verify (term1);
456
528
457
529
auto term2 = Impl->Context .getMutableTermForType (depType2->getCanonicalType (),
458
530
/* proto=*/ nullptr );
459
531
Impl->System .simplify (term2);
532
+ Impl->verify (term2);
460
533
461
534
return (term1 == term2);
462
535
}
0 commit comments