@@ -458,6 +458,25 @@ swift::rewriting::desugarRequirement(Requirement req, SourceLoc loc,
458
458
}
459
459
}
460
460
461
+ void swift::rewriting::desugarRequirements (SmallVector<StructuralRequirement, 2 > &reqs,
462
+ SmallVectorImpl<RequirementError> &errors) {
463
+ SmallVector<StructuralRequirement, 2 > result;
464
+ for (auto req : reqs) {
465
+ SmallVector<Requirement, 2 > desugaredReqs;
466
+ SmallVector<RequirementError, 2 > ignoredErrors;
467
+
468
+ if (req.inferred )
469
+ desugarRequirement (req.req , SourceLoc (), desugaredReqs, ignoredErrors);
470
+ else
471
+ desugarRequirement (req.req , req.loc , desugaredReqs, errors);
472
+
473
+ for (auto desugaredReq : desugaredReqs)
474
+ result.push_back ({desugaredReq, req.loc , req.inferred });
475
+ }
476
+
477
+ std::swap (reqs, result);
478
+ }
479
+
461
480
//
462
481
// Requirement realization and inference.
463
482
//
@@ -467,8 +486,6 @@ static void realizeTypeRequirement(DeclContext *dc,
467
486
SourceLoc loc,
468
487
SmallVectorImpl<StructuralRequirement> &result,
469
488
SmallVectorImpl<RequirementError> &errors) {
470
- SmallVector<Requirement, 2 > reqs;
471
-
472
489
// The GenericSignatureBuilder allowed the right hand side of a
473
490
// conformance or superclass requirement to reference a protocol
474
491
// typealias whose underlying type was a protocol or class.
@@ -497,22 +514,19 @@ static void realizeTypeRequirement(DeclContext *dc,
497
514
}
498
515
499
516
if (constraintType->isConstraintType ()) {
500
- Requirement req (RequirementKind::Conformance, subjectType, constraintType);
501
- desugarRequirement (req, loc, reqs, errors);
517
+ result.push_back ({Requirement (RequirementKind::Conformance,
518
+ subjectType, constraintType),
519
+ loc, /* wasInferred=*/ false });
502
520
} else if (constraintType->getClassOrBoundGenericClass ()) {
503
- Requirement req (RequirementKind::Superclass, subjectType, constraintType);
504
- desugarRequirement (req, loc, reqs, errors);
521
+ result.push_back ({Requirement (RequirementKind::Superclass,
522
+ subjectType, constraintType),
523
+ loc, /* wasInferred=*/ false });
505
524
} else {
506
525
errors.push_back (
507
526
RequirementError::forInvalidTypeRequirement (subjectType,
508
527
constraintType,
509
528
loc));
510
- return ;
511
529
}
512
-
513
- // Add source location information.
514
- for (auto req : reqs)
515
- result.push_back ({req, loc, /* wasInferred=*/ false });
516
530
}
517
531
518
532
namespace {
@@ -521,11 +535,11 @@ namespace {
521
535
struct InferRequirementsWalker : public TypeWalker {
522
536
ModuleDecl *module ;
523
537
DeclContext *dc;
524
- SmallVector<Requirement, 2 > reqs;
525
- SmallVector<RequirementError, 2 > errors;
538
+ SmallVectorImpl<StructuralRequirement> &reqs;
526
539
527
- explicit InferRequirementsWalker (ModuleDecl *module , DeclContext *dc)
528
- : module(module ), dc(dc) {}
540
+ explicit InferRequirementsWalker (ModuleDecl *module , DeclContext *dc,
541
+ SmallVectorImpl<StructuralRequirement> &reqs)
542
+ : module(module ), dc(dc), reqs(reqs) {}
529
543
530
544
Action walkToTypePre (Type ty) override {
531
545
// Unbound generic types are the result of recovered-but-invalid code, and
@@ -555,8 +569,7 @@ struct InferRequirementsWalker : public TypeWalker {
555
569
return false ;
556
570
557
571
return (req.getKind () == RequirementKind::Conformance &&
558
- req.getSecondType ()->castTo <ProtocolType>()->getDecl ()
559
- ->isSpecificProtocol (KnownProtocolKind::Sendable));
572
+ req.getProtocolDecl ()->isSpecificProtocol (KnownProtocolKind::Sendable));
560
573
};
561
574
562
575
// Infer from generic typealiases.
@@ -567,7 +580,7 @@ struct InferRequirementsWalker : public TypeWalker {
567
580
if (skipRequirement (rawReq, decl))
568
581
continue ;
569
582
570
- desugarRequirement ( rawReq.subst (subMap), SourceLoc (), reqs, errors );
583
+ reqs. push_back ({ rawReq.subst (subMap), SourceLoc (), /* inferred= */ true } );
571
584
}
572
585
573
586
return Action::Continue;
@@ -581,10 +594,9 @@ struct InferRequirementsWalker : public TypeWalker {
581
594
packExpansion->getPatternType ()->getTypeParameterPacks (packReferences);
582
595
583
596
auto countType = packExpansion->getCountType ();
584
- for (auto pack : packReferences) {
585
- Requirement req (RequirementKind::SameShape, countType, pack);
586
- desugarRequirement (req, SourceLoc (), reqs, errors);
587
- }
597
+ for (auto pack : packReferences)
598
+ reqs.push_back ({Requirement (RequirementKind::SameShape, countType, pack),
599
+ SourceLoc (), /* inferred=*/ true });
588
600
}
589
601
590
602
// Infer requirements from `@differentiable` function types.
@@ -596,9 +608,9 @@ struct InferRequirementsWalker : public TypeWalker {
596
608
if (auto *fnTy = ty->getAs <AnyFunctionType>()) {
597
609
// Add a new conformance constraint for a fixed protocol.
598
610
auto addConformanceConstraint = [&](Type type, ProtocolDecl *protocol) {
599
- Requirement req (RequirementKind::Conformance, type,
600
- protocol->getDeclaredInterfaceType ());
601
- desugarRequirement (req, SourceLoc (), reqs, errors );
611
+ reqs. push_back ({ Requirement (RequirementKind::Conformance, type,
612
+ protocol->getDeclaredInterfaceType ()),
613
+ SourceLoc (), /* inferred= */ true } );
602
614
};
603
615
604
616
auto &ctx = module ->getASTContext ();
@@ -610,8 +622,9 @@ struct InferRequirementsWalker : public TypeWalker {
610
622
auto secondType = assocType->getDeclaredInterfaceType ()
611
623
->castTo <DependentMemberType>()
612
624
->substBaseType (module , firstType);
613
- Requirement req (RequirementKind::SameType, firstType, secondType);
614
- desugarRequirement (req, SourceLoc (), reqs, errors);
625
+ reqs.push_back ({Requirement (RequirementKind::SameType,
626
+ firstType, secondType),
627
+ SourceLoc (), /* inferred=*/ true });
615
628
};
616
629
auto *tangentVectorAssocType =
617
630
differentiableProtocol->getAssociatedType (ctx.Id_TangentVector );
@@ -659,8 +672,7 @@ struct InferRequirementsWalker : public TypeWalker {
659
672
if (skipRequirement (rawReq, decl))
660
673
continue ;
661
674
662
- auto req = rawReq.subst (subMap);
663
- desugarRequirement (req, SourceLoc (), reqs, errors);
675
+ reqs.push_back ({rawReq.subst (subMap), SourceLoc (), /* inferred=*/ true });
664
676
}
665
677
666
678
return Action::Continue;
@@ -683,15 +695,12 @@ void swift::rewriting::inferRequirements(
683
695
if (!type)
684
696
return ;
685
697
686
- InferRequirementsWalker walker (module , dc);
698
+ InferRequirementsWalker walker (module , dc, result );
687
699
type.walk (walker);
688
-
689
- for (const auto &req : walker.reqs )
690
- result.push_back ({req, loc, /* wasInferred=*/ true });
691
700
}
692
701
693
- // / Desugar a requirement and perform requirement inference if requested
694
- // / to obtain zero or more structural requirements .
702
+ // / Perform requirement inference from the type representations in the
703
+ // / requirement itself (eg, `T == Set<U>` infers `U: Hashable`) .
695
704
void swift::rewriting::realizeRequirement (
696
705
DeclContext *dc,
697
706
Requirement req, RequirementRepr *reqRepr,
@@ -732,12 +741,7 @@ void swift::rewriting::realizeRequirement(
732
741
inferRequirements (firstType, firstLoc, moduleForInference, dc, result);
733
742
}
734
743
735
- SmallVector<Requirement, 2 > reqs;
736
- desugarRequirement (req, loc, reqs, errors);
737
-
738
- for (auto req : reqs)
739
- result.push_back ({req, loc, /* wasInferred=*/ false });
740
-
744
+ result.push_back ({req, loc, /* wasInferred=*/ false });
741
745
break ;
742
746
}
743
747
@@ -754,11 +758,7 @@ void swift::rewriting::realizeRequirement(
754
758
inferRequirements (secondType, secondLoc, moduleForInference, dc, result);
755
759
}
756
760
757
- SmallVector<Requirement, 2 > reqs;
758
- desugarRequirement (req, loc, reqs, errors);
759
-
760
- for (auto req : reqs)
761
- result.push_back ({req, loc, /* wasInferred=*/ false });
761
+ result.push_back ({req, loc, /* wasInferred=*/ false });
762
762
break ;
763
763
}
764
764
}
@@ -903,13 +903,13 @@ StructuralRequirementsRequest::evaluate(Evaluator &evaluator,
903
903
ProtocolDecl *proto) const {
904
904
assert (!proto->hasLazyRequirementSignature ());
905
905
906
- SmallVector<StructuralRequirement, 4 > result;
907
- SmallVector<RequirementError, 4 > errors;
906
+ SmallVector<StructuralRequirement, 2 > result;
907
+ SmallVector<RequirementError, 2 > errors;
908
908
909
909
auto &ctx = proto->getASTContext ();
910
910
auto selfTy = proto->getSelfInterfaceType ();
911
911
912
- SmallVector<Type, 4 > needsDefaultReqirements ({selfTy});
912
+ SmallVector<Type, 4 > needsDefaultRequirements ({selfTy});
913
913
914
914
unsigned errorCount = errors.size ();
915
915
realizeInheritedRequirements (proto, selfTy,
@@ -950,7 +950,12 @@ StructuralRequirementsRequest::evaluate(Evaluator &evaluator,
950
950
result.push_back ({Requirement (RequirementKind::Layout, selfTy, layout),
951
951
proto->getLoc (), /* inferred=*/ true });
952
952
953
- expandDefaultRequirements (ctx, needsDefaultReqirements, result, errors);
953
+ desugarRequirements (result, errors);
954
+ expandDefaultRequirements (ctx, needsDefaultRequirements, result, errors);
955
+
956
+ diagnoseRequirementErrors (ctx, errors,
957
+ AllowConcreteTypePolicy::NestedAssocTypes);
958
+
954
959
return ctx.AllocateCopy (result);
955
960
}
956
961
@@ -976,7 +981,7 @@ StructuralRequirementsRequest::evaluate(Evaluator &evaluator,
976
981
return false ;
977
982
});
978
983
979
- needsDefaultReqirements .push_back (assocType);
984
+ needsDefaultRequirements .push_back (assocType);
980
985
}
981
986
982
987
// Add requirements for each typealias.
@@ -1014,7 +1019,8 @@ StructuralRequirementsRequest::evaluate(Evaluator &evaluator,
1014
1019
}
1015
1020
}
1016
1021
1017
- expandDefaultRequirements (ctx, needsDefaultReqirements, result, errors);
1022
+ desugarRequirements (result, errors);
1023
+ expandDefaultRequirements (ctx, needsDefaultRequirements, result, errors);
1018
1024
1019
1025
diagnoseRequirementErrors (ctx, errors,
1020
1026
AllowConcreteTypePolicy::NestedAssocTypes);
0 commit comments