@@ -128,7 +128,11 @@ struct UnboundImport {
128
128
129
129
private:
130
130
void validatePrivate (ModuleDecl *topLevelModule);
131
- void validateImplementationOnly (ASTContext &ctx);
131
+
132
+ // / Check that no import has more than one of the following modifiers:
133
+ // / @_exported, @_implementationOnly, and @_spiOnly.
134
+ void validateRestrictedImport (ASTContext &ctx);
135
+
132
136
void validateTestable (ModuleDecl *topLevelModule);
133
137
void validateResilience (NullablePtr<ModuleDecl> topLevelModule,
134
138
SourceFile &SF);
@@ -598,7 +602,7 @@ bool UnboundImport::checkModuleLoaded(ModuleDecl *M, SourceFile &SF) {
598
602
599
603
void UnboundImport::validateOptions (NullablePtr<ModuleDecl> topLevelModule,
600
604
SourceFile &SF) {
601
- validateImplementationOnly (SF.getASTContext ());
605
+ validateRestrictedImport (SF.getASTContext ());
602
606
603
607
if (auto *top = topLevelModule.getPtrOrNull ()) {
604
608
// FIXME: Having these two calls in this if condition seems dubious.
@@ -633,16 +637,55 @@ void UnboundImport::validatePrivate(ModuleDecl *topLevelModule) {
633
637
import .sourceFileArg = StringRef ();
634
638
}
635
639
636
- void UnboundImport::validateImplementationOnly (ASTContext &ctx) {
637
- if (!import .options .contains (ImportFlags::ImplementationOnly) ||
638
- !import .options .contains (ImportFlags::Exported))
640
+ void UnboundImport::validateRestrictedImport (ASTContext &ctx) {
641
+ static llvm::SmallVector<ImportFlags, 2 > flags = {ImportFlags::Exported,
642
+ ImportFlags::ImplementationOnly,
643
+ ImportFlags::SPIOnly};
644
+ llvm::SmallVector<ImportFlags, 2 > conflicts;
645
+
646
+ for (auto flag : flags) {
647
+ if (import .options .contains (flag))
648
+ conflicts.push_back (flag);
649
+ }
650
+
651
+ // Quit if there's no conflicting attributes.
652
+ if (conflicts.size () < 2 )
639
653
return ;
640
654
641
- // Remove one flag to maintain the invariant.
642
- import .options -= ImportFlags::ImplementationOnly;
655
+ // Remove all but one flag to maintain the invariant.
656
+ for (auto iter = conflicts.begin (); iter != std::prev (conflicts.end ()); iter ++)
657
+ import .options -= *iter;
658
+
659
+ DeclAttrKind attrToRemove = conflicts[0 ] == ImportFlags::ImplementationOnly?
660
+ DAK_Exported : DAK_ImplementationOnly;
661
+
662
+ auto flagName = [](ImportFlags flag) {
663
+ switch (flag) {
664
+ case ImportFlags::ImplementationOnly:
665
+ return " implementation-only" ;
666
+ case ImportFlags::SPIOnly:
667
+ return " SPI only" ;
668
+ case ImportFlags::Exported:
669
+ return " exported" ;
670
+ default :
671
+ llvm_unreachable (" Unexpected ImportFlag" );
672
+ }
673
+ };
674
+
675
+ // Report the conflict, only the first two conflicts should be enough.
676
+ auto diag = ctx.Diags .diagnose (import .module .getModulePath ().front ().Loc ,
677
+ diag::import_restriction_conflict,
678
+ import .module .getModulePath ().front ().Item ,
679
+ flagName (conflicts[0 ]),
680
+ flagName (conflicts[1 ]));
643
681
644
- diagnoseInvalidAttr (DAK_ImplementationOnly, ctx.Diags ,
645
- diag::import_implementation_cannot_be_exported);
682
+ auto *ID = getImportDecl ().getPtrOrNull ();
683
+ if (!ID) return ;
684
+ auto *attr = ID->getAttrs ().getAttribute (attrToRemove);
685
+ if (!attr) return ;
686
+
687
+ diag.fixItRemove (attr->getRangeWithAt ());
688
+ attr->setInvalid ();
646
689
}
647
690
648
691
void UnboundImport::validateTestable (ModuleDecl *topLevelModule) {
0 commit comments