Skip to content

Commit 064eb27

Browse files
authored
Merge pull request #41622 from slavapestov/fix-protocol-overrides
Fix generics invariant violations in protocol requirement override checking
2 parents d8ebecf + 38d188b commit 064eb27

File tree

5 files changed

+151
-71
lines changed

5 files changed

+151
-71
lines changed

lib/AST/ASTContext.cpp

Lines changed: 56 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -101,14 +101,14 @@ using AssociativityCacheType =
101101
struct OverrideSignatureKey {
102102
GenericSignature baseMethodSig;
103103
GenericSignature derivedMethodSig;
104-
Decl *subclassDecl;
104+
NominalTypeDecl *derivedNominal;
105105

106106
OverrideSignatureKey(GenericSignature baseMethodSignature,
107107
GenericSignature derivedMethodSignature,
108-
Decl *subclassDecl)
108+
NominalTypeDecl *derivedNominal)
109109
: baseMethodSig(baseMethodSignature),
110110
derivedMethodSig(derivedMethodSignature),
111-
subclassDecl(subclassDecl) {}
111+
derivedNominal(derivedNominal) {}
112112
};
113113

114114
namespace llvm {
@@ -120,27 +120,27 @@ template <> struct DenseMapInfo<OverrideSignatureKey> {
120120
const OverrideSignatureKey rhs) {
121121
return lhs.baseMethodSig.getPointer() == rhs.baseMethodSig.getPointer() &&
122122
lhs.derivedMethodSig.getPointer() == rhs.derivedMethodSig.getPointer() &&
123-
lhs.subclassDecl == rhs.subclassDecl;
123+
lhs.derivedNominal == rhs.derivedNominal;
124124
}
125125

126126
static inline OverrideSignatureKey getEmptyKey() {
127127
return OverrideSignatureKey(DenseMapInfo<GenericSignature>::getEmptyKey(),
128128
DenseMapInfo<GenericSignature>::getEmptyKey(),
129-
DenseMapInfo<Decl *>::getEmptyKey());
129+
DenseMapInfo<NominalTypeDecl *>::getEmptyKey());
130130
}
131131

132132
static inline OverrideSignatureKey getTombstoneKey() {
133133
return OverrideSignatureKey(
134134
DenseMapInfo<GenericSignature>::getTombstoneKey(),
135135
DenseMapInfo<GenericSignature>::getTombstoneKey(),
136-
DenseMapInfo<Decl *>::getTombstoneKey());
136+
DenseMapInfo<NominalTypeDecl *>::getTombstoneKey());
137137
}
138138

139139
static unsigned getHashValue(const OverrideSignatureKey &Val) {
140140
return hash_combine(
141141
DenseMapInfo<GenericSignature>::getHashValue(Val.baseMethodSig),
142142
DenseMapInfo<GenericSignature>::getHashValue(Val.derivedMethodSig),
143-
DenseMapInfo<Decl *>::getHashValue(Val.subclassDecl));
143+
DenseMapInfo<NominalTypeDecl *>::getHashValue(Val.derivedNominal));
144144
}
145145
};
146146
} // namespace llvm
@@ -5214,11 +5214,11 @@ ASTContext::getOverrideGenericSignature(const ValueDecl *base,
52145214
assert(isa<AbstractFunctionDecl>(base) || isa<SubscriptDecl>(base));
52155215
assert(isa<AbstractFunctionDecl>(derived) || isa<SubscriptDecl>(derived));
52165216

5217-
const auto baseClass = base->getDeclContext()->getSelfClassDecl();
5218-
const auto derivedClass = derived->getDeclContext()->getSelfClassDecl();
5217+
const auto baseNominal = base->getDeclContext()->getSelfNominalTypeDecl();
5218+
const auto derivedNominal = derived->getDeclContext()->getSelfNominalTypeDecl();
52195219

5220-
assert(baseClass != nullptr);
5221-
assert(derivedClass != nullptr);
5220+
assert(baseNominal != nullptr);
5221+
assert(derivedNominal != nullptr);
52225222

52235223
const auto baseGenericSig =
52245224
base->getAsGenericContext()->getGenericSignature();
@@ -5228,10 +5228,6 @@ ASTContext::getOverrideGenericSignature(const ValueDecl *base,
52285228
if (base == derived)
52295229
return derivedGenericSig;
52305230

5231-
const auto derivedSuperclass = derivedClass->getSuperclass();
5232-
if (derivedSuperclass.isNull())
5233-
return nullptr;
5234-
52355231
if (derivedGenericSig.isNull())
52365232
return nullptr;
52375233

@@ -5240,21 +5236,14 @@ ASTContext::getOverrideGenericSignature(const ValueDecl *base,
52405236

52415237
auto key = OverrideSignatureKey(baseGenericSig,
52425238
derivedGenericSig,
5243-
derivedClass);
5239+
derivedNominal);
52445240

52455241
if (getImpl().overrideSigCache.find(key) !=
52465242
getImpl().overrideSigCache.end()) {
52475243
return getImpl().overrideSigCache.lookup(key);
52485244
}
52495245

5250-
const auto derivedClassSig = derivedClass->getGenericSignature();
5251-
5252-
unsigned derivedDepth = 0;
5253-
unsigned baseDepth = 0;
5254-
if (derivedClassSig)
5255-
derivedDepth = derivedClassSig.getGenericParams().back()->getDepth() + 1;
5256-
if (const auto baseClassSig = baseClass->getGenericSignature())
5257-
baseDepth = baseClassSig.getGenericParams().back()->getDepth() + 1;
5246+
const auto derivedNominalSig = derivedNominal->getGenericSignature();
52585247

52595248
SmallVector<GenericTypeParamType *, 2> addedGenericParams;
52605249
if (const auto *gpList = derived->getAsGenericContext()->getGenericParams()) {
@@ -5264,38 +5253,59 @@ ASTContext::getOverrideGenericSignature(const ValueDecl *base,
52645253
}
52655254
}
52665255

5267-
const auto subMap = derivedSuperclass->getContextSubstitutionMap(
5268-
derivedClass->getModuleContext(), baseClass);
5256+
SmallVector<Requirement, 2> addedRequirements;
52695257

5270-
auto substFn = [&](SubstitutableType *type) -> Type {
5271-
auto *gp = cast<GenericTypeParamType>(type);
5258+
if (isa<ProtocolDecl>(baseNominal)) {
5259+
assert(isa<ProtocolDecl>(derivedNominal));
52725260

5273-
if (gp->getDepth() < baseDepth) {
5274-
return Type(gp).subst(subMap);
5261+
for (auto reqt : baseGenericSig.getRequirements()) {
5262+
addedRequirements.push_back(reqt);
52755263
}
5264+
} else {
5265+
const auto derivedSuperclass = cast<ClassDecl>(derivedNominal)
5266+
->getSuperclass();
5267+
if (derivedSuperclass.isNull())
5268+
return nullptr;
52765269

5277-
return CanGenericTypeParamType::get(
5278-
gp->isTypeSequence(), gp->getDepth() - baseDepth + derivedDepth,
5279-
gp->getIndex(), *this);
5280-
};
5270+
unsigned derivedDepth = 0;
5271+
unsigned baseDepth = 0;
5272+
if (derivedNominalSig)
5273+
derivedDepth = derivedNominalSig.getGenericParams().back()->getDepth() + 1;
5274+
if (const auto baseNominalSig = baseNominal->getGenericSignature())
5275+
baseDepth = baseNominalSig.getGenericParams().back()->getDepth() + 1;
52815276

5282-
auto lookupConformanceFn =
5283-
[&](CanType depTy, Type substTy,
5284-
ProtocolDecl *proto) -> ProtocolConformanceRef {
5285-
if (auto conf = subMap.lookupConformance(depTy, proto))
5286-
return conf;
5277+
const auto subMap = derivedSuperclass->getContextSubstitutionMap(
5278+
derivedNominal->getModuleContext(), baseNominal);
52875279

5288-
return ProtocolConformanceRef(proto);
5289-
};
5280+
auto substFn = [&](SubstitutableType *type) -> Type {
5281+
auto *gp = cast<GenericTypeParamType>(type);
52905282

5291-
SmallVector<Requirement, 2> addedRequirements;
5292-
for (auto reqt : baseGenericSig.getRequirements()) {
5293-
if (auto substReqt = reqt.subst(substFn, lookupConformanceFn)) {
5294-
addedRequirements.push_back(*substReqt);
5283+
if (gp->getDepth() < baseDepth) {
5284+
return Type(gp).subst(subMap);
5285+
}
5286+
5287+
return CanGenericTypeParamType::get(
5288+
gp->isTypeSequence(), gp->getDepth() - baseDepth + derivedDepth,
5289+
gp->getIndex(), *this);
5290+
};
5291+
5292+
auto lookupConformanceFn =
5293+
[&](CanType depTy, Type substTy,
5294+
ProtocolDecl *proto) -> ProtocolConformanceRef {
5295+
if (auto conf = subMap.lookupConformance(depTy, proto))
5296+
return conf;
5297+
5298+
return ProtocolConformanceRef(proto);
5299+
};
5300+
5301+
for (auto reqt : baseGenericSig.getRequirements()) {
5302+
if (auto substReqt = reqt.subst(substFn, lookupConformanceFn)) {
5303+
addedRequirements.push_back(*substReqt);
5304+
}
52955305
}
52965306
}
52975307

5298-
auto genericSig = buildGenericSignature(*this, derivedClassSig,
5308+
auto genericSig = buildGenericSignature(*this, derivedNominalSig,
52995309
std::move(addedGenericParams),
53005310
std::move(addedRequirements));
53015311
getImpl().overrideSigCache.insert(std::make_pair(key, genericSig));

lib/AST/SubstitutionMap.cpp

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -511,15 +511,9 @@ SubstitutionMap::getOverrideSubstitutions(
511511
Optional<SubstitutionMap> derivedSubs) {
512512
// For overrides within a protocol hierarchy, substitute the Self type.
513513
if (auto baseProto = baseDecl->getDeclContext()->getSelfProtocolDecl()) {
514-
if (auto derivedProtoSelf =
515-
derivedDecl->getDeclContext()->getSelfInterfaceType()) {
516-
return SubstitutionMap::getProtocolSubstitutions(
517-
baseProto,
518-
derivedProtoSelf,
519-
ProtocolConformanceRef(baseProto));
520-
}
521-
522-
return SubstitutionMap();
514+
auto baseSig = baseDecl->getInnermostDeclContext()
515+
->getGenericSignatureOfContext();
516+
return baseSig->getIdentitySubstitutionMap();
523517
}
524518

525519
auto *baseClass = baseDecl->getDeclContext()->getSelfClassDecl();

lib/Sema/TypeCheckDeclOverride.cpp

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -164,25 +164,32 @@ bool swift::isOverrideBasedOnType(const ValueDecl *decl, Type declTy,
164164
//
165165
// We can still succeed with a subtype match later in
166166
// OverrideMatcher::match().
167-
if (decl->getDeclContext()->getSelfClassDecl()) {
168-
if (auto declCtx = decl->getAsGenericContext()) {
169-
auto *parentCtx = parentDecl->getAsGenericContext();
167+
if (auto declCtx = decl->getAsGenericContext()) {
168+
// The below logic now works correctly for protocol requirements which are
169+
// themselves generic, but that would be an ABI break, since we would now
170+
// drop the protocol requirements from witness tables. Simulate the old
171+
// behavior by not considering generic declarations in protocols as
172+
// overrides at all.
173+
if (decl->getDeclContext()->getSelfProtocolDecl() &&
174+
declCtx->isGeneric())
175+
return false;
170176

171-
if (declCtx->isGeneric() != parentCtx->isGeneric())
172-
return false;
177+
auto *parentCtx = parentDecl->getAsGenericContext();
173178

174-
if (declCtx->isGeneric() &&
175-
(declCtx->getGenericParams()->size() !=
176-
parentCtx->getGenericParams()->size()))
177-
return false;
179+
if (declCtx->isGeneric() != parentCtx->isGeneric())
180+
return false;
178181

179-
auto &ctx = decl->getASTContext();
180-
auto sig = ctx.getOverrideGenericSignature(parentDecl, decl);
181-
if (sig &&
182-
declCtx->getGenericSignature().getCanonicalSignature() !=
183-
sig.getCanonicalSignature()) {
184-
return false;
185-
}
182+
if (declCtx->isGeneric() &&
183+
(declCtx->getGenericParams()->size() !=
184+
parentCtx->getGenericParams()->size()))
185+
return false;
186+
187+
auto &ctx = decl->getASTContext();
188+
auto sig = ctx.getOverrideGenericSignature(parentDecl, decl);
189+
if (sig &&
190+
declCtx->getGenericSignature().getCanonicalSignature() !=
191+
sig.getCanonicalSignature()) {
192+
return false;
186193
}
187194
}
188195

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// RUN: %target-swift-frontend -emit-silgen %s -verify -warn-implicit-overrides | %FileCheck %s
2+
3+
protocol Base {
4+
func foo1<T : P>(_: T)
5+
func foo2<T : P>(_: T, _: T.T)
6+
func foo3<T : P>(_: T, _: T.T)
7+
}
8+
9+
protocol Derived : Base {
10+
func foo1<T : P>(_: T)
11+
func foo2<T : P>(_: T, _: T.T)
12+
func foo3<T : Q>(_: T, _: T.T)
13+
}
14+
15+
protocol P {
16+
associatedtype T
17+
}
18+
19+
protocol Q {
20+
associatedtype T
21+
}
22+
23+
struct S : Derived {
24+
func foo1<T : P>(_: T) {}
25+
func foo2<T : P>(_: T, _: T.T) {}
26+
func foo3<T : P>(_: T, _: T.T) {}
27+
func foo3<T : Q>(_: T, _: T.T) {}
28+
}
29+
30+
// Make sure that Derived.foo1 and Derived.foo2 are not counted as overrides of
31+
// Base.foo1 and Base.foo2 respectively. Even though their types match, bugs
32+
// in Swift 5.6 and earlier prevented them from being overrides. We can't fix
33+
// it now because it would be an ABI break.
34+
35+
// CHECK-LABEL: sil_witness_table hidden S: Derived module override_generic {
36+
// CHECK-NEXT: base_protocol Base: S: Base module override_generic
37+
// CHECK-NEXT: method #Derived.foo1: <Self where Self : Derived><T where T : P> (Self) -> (T) -> () : @$s16override_generic1SVAA7DerivedA2aDP4foo1yyqd__AA1PRd__lFTW
38+
// CHECK-NEXT: method #Derived.foo2: <Self where Self : Derived><T where T : P> (Self) -> (T, T.T) -> () : @$s16override_generic1SVAA7DerivedA2aDP4foo2yyqd___1TQyd__tAA1PRd__lFTW
39+
// CHECK-NEXT: method #Derived.foo3: <Self where Self : Derived><T where T : Q> (Self) -> (T, T.T) -> () : @$s16override_generic1SVAA7DerivedA2aDP4foo3yyqd___1TQyd__tAA1QRd__lFTW
40+
// CHECK-NEXT: }
41+
42+
// CHECK-LABEL: sil_witness_table hidden S: Base module override_generic {
43+
// CHECK-NEXT: method #Base.foo1: <Self where Self : Base><T where T : P> (Self) -> (T) -> () : @$s16override_generic1SVAA4BaseA2aDP4foo1yyqd__AA1PRd__lFTW
44+
// CHECK-NEXT: method #Base.foo2: <Self where Self : Base><T where T : P> (Self) -> (T, T.T) -> () : @$s16override_generic1SVAA4BaseA2aDP4foo2yyqd___1TQyd__tAA1PRd__lFTW
45+
// CHECK-NEXT: method #Base.foo3: <Self where Self : Base><T where T : P> (Self) -> (T, T.T) -> () : @$s16override_generic1SVAA4BaseA2aDP4foo3yyqd___1TQyd__tAA1PRd__lFTW
46+
// CHECK-NEXT: }
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: %target-swift-frontend -emit-ir %s
2+
3+
public struct Observable<T> {}
4+
5+
public protocol BaseVariant: CaseIterable, Equatable {}
6+
7+
public protocol FeatureGate {
8+
associatedtype Variant: BaseVariant
9+
}
10+
11+
public enum FeatureVariantState<T: BaseVariant>: Equatable {}
12+
13+
public protocol BaseGatingProvider {
14+
func exposeFeatureVariantState<G: FeatureGate>(for featureGate: G)
15+
-> Observable<FeatureVariantState<G.Variant>>
16+
}
17+
18+
public struct UserFeatureGate<Variant: BaseVariant>: FeatureGate {}
19+
20+
public protocol UserGatingProvider: BaseGatingProvider {
21+
func exposeFeatureVariantState<V>(for featureGate: UserFeatureGate<V>)
22+
-> Observable<FeatureVariantState<V>>
23+
}

0 commit comments

Comments
 (0)