Skip to content

Commit d3da1b8

Browse files
committed
AST: Generalize findGenericParameterReferences()
1 parent 21cc894 commit d3da1b8

File tree

3 files changed

+66
-42
lines changed

3 files changed

+66
-42
lines changed

include/swift/AST/Decl.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9525,7 +9525,8 @@ class MacroExpansionDecl : public Decl, public FreestandingMacroExpansion {
95259525
/// specifies the index of the parameter that shall be skipped.
95269526
GenericParameterReferenceInfo
95279527
findGenericParameterReferences(const ValueDecl *value, CanGenericSignature sig,
9528-
GenericTypeParamType *genericParam,
9528+
GenericTypeParamType *origParam,
9529+
GenericTypeParamType *openedParam,
95299530
std::optional<unsigned> skipParamIndex);
95309531

95319532
inline void

lib/AST/Decl.cpp

Lines changed: 63 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4825,7 +4825,9 @@ GenericParameterReferenceInfo::operator|=(const GenericParameterReferenceInfo &o
48254825

48264826
/// Forward declaration.
48274827
static GenericParameterReferenceInfo
4828-
findGenericParameterReferencesRec(CanGenericSignature, GenericTypeParamType *,
4828+
findGenericParameterReferencesRec(CanGenericSignature,
4829+
GenericTypeParamType *,
4830+
GenericTypeParamType *,
48294831
Type, TypePosition, bool);
48304832

48314833
/// Determine whether a function type with the given result type may have
@@ -4846,7 +4848,9 @@ static bool canResultTypeHaveCovariantGenericParameterResult(Type resultTy) {
48464848
/// \param position The current position in terms of variance.
48474849
/// \param skipParamIndex The index of the parameter that shall be skipped.
48484850
static GenericParameterReferenceInfo findGenericParameterReferencesInFunction(
4849-
CanGenericSignature genericSig, GenericTypeParamType *genericParam,
4851+
CanGenericSignature genericSig,
4852+
GenericTypeParamType *origParam,
4853+
GenericTypeParamType *openedParam,
48504854
const AnyFunctionType *fnType, TypePosition position,
48514855
bool canBeCovariantResult, std::optional<unsigned> skipParamIndex) {
48524856
// If there are no type parameters, we're done.
@@ -4864,7 +4868,7 @@ static GenericParameterReferenceInfo findGenericParameterReferencesInFunction(
48644868
// inout types are invariant.
48654869
if (param.isInOut()) {
48664870
inputInfo |= ::findGenericParameterReferencesRec(
4867-
genericSig, genericParam, param.getPlainType(),
4871+
genericSig, origParam, openedParam, param.getPlainType(),
48684872
TypePosition::Invariant, /*canBeCovariantResult=*/false);
48694873
continue;
48704874
}
@@ -4877,7 +4881,7 @@ static GenericParameterReferenceInfo findGenericParameterReferencesInFunction(
48774881
paramPos = TypePosition::Invariant;
48784882

48794883
inputInfo |= ::findGenericParameterReferencesRec(
4880-
genericSig, genericParam, param.getParameterType(), paramPos,
4884+
genericSig, origParam, openedParam, param.getParameterType(), paramPos,
48814885
/*canBeCovariantResult=*/false);
48824886
}
48834887

@@ -4887,7 +4891,7 @@ static GenericParameterReferenceInfo findGenericParameterReferencesInFunction(
48874891
canResultTypeHaveCovariantGenericParameterResult(fnType->getResult());
48884892

48894893
const auto resultInfo = ::findGenericParameterReferencesRec(
4890-
genericSig, genericParam, fnType->getResult(),
4894+
genericSig, origParam, openedParam, fnType->getResult(),
48914895
position, canBeCovariantResult);
48924896

48934897
return inputInfo |= resultInfo;
@@ -4899,7 +4903,8 @@ static GenericParameterReferenceInfo findGenericParameterReferencesInFunction(
48994903
/// \param position The current position in terms of variance.
49004904
static GenericParameterReferenceInfo
49014905
findGenericParameterReferencesRec(CanGenericSignature genericSig,
4902-
GenericTypeParamType *genericParam,
4906+
GenericTypeParamType *origParam,
4907+
GenericTypeParamType *openedParam,
49034908
Type type,
49044909
TypePosition position,
49054910
bool canBeCovariantResult) {
@@ -4912,7 +4917,7 @@ findGenericParameterReferencesRec(CanGenericSignature genericSig,
49124917
auto info = GenericParameterReferenceInfo();
49134918
for (auto &elt : tuple->getElements()) {
49144919
info |= findGenericParameterReferencesRec(
4915-
genericSig, genericParam, elt.getType(), position,
4920+
genericSig, origParam, openedParam, elt.getType(), position,
49164921
/*canBeCovariantResult=*/false);
49174922
}
49184923

@@ -4923,27 +4928,28 @@ findGenericParameterReferencesRec(CanGenericSignature genericSig,
49234928
// the parameter type.
49244929
if (auto funcTy = type->getAs<AnyFunctionType>()) {
49254930
return findGenericParameterReferencesInFunction(
4926-
genericSig, genericParam, funcTy,
4931+
genericSig, origParam, openedParam, funcTy,
49274932
position, canBeCovariantResult,
49284933
/*skipParamIndex=*/std::nullopt);
49294934
}
49304935

49314936
// Metatypes preserve variance.
49324937
if (auto metaTy = type->getAs<MetatypeType>()) {
4933-
return findGenericParameterReferencesRec(genericSig, genericParam,
4938+
return findGenericParameterReferencesRec(genericSig, origParam, openedParam,
49344939
metaTy->getInstanceType(),
49354940
position, canBeCovariantResult);
49364941
}
49374942

49384943
// Optionals preserve variance.
49394944
if (auto optType = type->getOptionalObjectType()) {
49404945
return findGenericParameterReferencesRec(
4941-
genericSig, genericParam, optType, position, canBeCovariantResult);
4946+
genericSig, origParam, openedParam, optType,
4947+
position, canBeCovariantResult);
49424948
}
49434949

49444950
// DynamicSelfType preserves variance.
49454951
if (auto selfType = type->getAs<DynamicSelfType>()) {
4946-
return findGenericParameterReferencesRec(genericSig, genericParam,
4952+
return findGenericParameterReferencesRec(genericSig, origParam, openedParam,
49474953
selfType->getSelfType(), position,
49484954
/*canBeCovariantResult=*/false);
49494955
}
@@ -4954,7 +4960,7 @@ findGenericParameterReferencesRec(CanGenericSignature genericSig,
49544960
// Don't forget to look in the parent.
49554961
if (const auto parent = nominal->getParent()) {
49564962
info |= findGenericParameterReferencesRec(
4957-
genericSig, genericParam, parent, TypePosition::Invariant,
4963+
genericSig, origParam, openedParam, parent, TypePosition::Invariant,
49584964
/*canBeCovariantResult=*/false);
49594965
}
49604966

@@ -4963,20 +4969,20 @@ findGenericParameterReferencesRec(CanGenericSignature genericSig,
49634969
if (bgt->isArray()) {
49644970
// Swift.Array preserves variance in its 'Value' type.
49654971
info |= findGenericParameterReferencesRec(
4966-
genericSig, genericParam, bgt->getGenericArgs().front(),
4972+
genericSig, origParam, openedParam, bgt->getGenericArgs().front(),
49674973
position, /*canBeCovariantResult=*/false);
49684974
} else if (bgt->isDictionary()) {
49694975
// Swift.Dictionary preserves variance in its 'Element' type.
49704976
info |= findGenericParameterReferencesRec(
4971-
genericSig, genericParam, bgt->getGenericArgs().front(),
4977+
genericSig, origParam, openedParam, bgt->getGenericArgs().front(),
49724978
TypePosition::Invariant, /*canBeCovariantResult=*/false);
49734979
info |= findGenericParameterReferencesRec(
4974-
genericSig, genericParam, bgt->getGenericArgs().back(),
4980+
genericSig, origParam, openedParam, bgt->getGenericArgs().back(),
49754981
position, /*canBeCovariantResult=*/false);
49764982
} else {
49774983
for (const auto &paramType : bgt->getGenericArgs()) {
49784984
info |= findGenericParameterReferencesRec(
4979-
genericSig, genericParam, paramType,
4985+
genericSig, origParam, openedParam, paramType,
49804986
TypePosition::Invariant, /*canBeCovariantResult=*/false);
49814987
}
49824988
}
@@ -5001,14 +5007,14 @@ findGenericParameterReferencesRec(CanGenericSignature genericSig,
50015007

50025008
case RequirementKind::SameType:
50035009
info |= findGenericParameterReferencesRec(
5004-
genericSig, genericParam, req.getFirstType(),
5010+
genericSig, origParam, openedParam, req.getFirstType(),
50055011
TypePosition::Invariant, /*canBeCovariantResult=*/false);
50065012

50075013
LLVM_FALLTHROUGH;
50085014

50095015
case RequirementKind::Superclass:
50105016
info |= findGenericParameterReferencesRec(
5011-
genericSig, genericParam, req.getSecondType(),
5017+
genericSig, origParam, openedParam, req.getSecondType(),
50125018
TypePosition::Invariant, /*canBeCovariantResult=*/false);
50135019
break;
50145020
}
@@ -5026,7 +5032,7 @@ findGenericParameterReferencesRec(CanGenericSignature genericSig,
50265032

50275033
for (auto member : comp->getMembers()) {
50285034
info |= findGenericParameterReferencesRec(
5029-
genericSig, genericParam, member,
5035+
genericSig, origParam, openedParam, member,
50305036
TypePosition::Invariant, /*canBeCovariantResult=*/false);
50315037
}
50325038

@@ -5039,7 +5045,7 @@ findGenericParameterReferencesRec(CanGenericSignature genericSig,
50395045

50405046
for (auto arg : pack->getElementTypes()) {
50415047
info |= findGenericParameterReferencesRec(
5042-
genericSig, genericParam, arg,
5048+
genericSig, origParam, openedParam, arg,
50435049
TypePosition::Invariant, /*canBeCovariantResult=*/false);
50445050
}
50455051

@@ -5049,7 +5055,7 @@ findGenericParameterReferencesRec(CanGenericSignature genericSig,
50495055
// Pack expansions are invariant.
50505056
if (auto *expansion = type->getAs<PackExpansionType>()) {
50515057
return findGenericParameterReferencesRec(
5052-
genericSig, genericParam, expansion->getPatternType(),
5058+
genericSig, origParam, openedParam, expansion->getPatternType(),
50535059
TypePosition::Invariant, /*canBeCovariantResult=*/false);
50545060
}
50555061

@@ -5067,29 +5073,46 @@ findGenericParameterReferencesRec(CanGenericSignature genericSig,
50675073
abort();
50685074
}
50695075

5070-
Type selfTy(genericParam);
5071-
if (!type->getRootGenericParam()->isEqual(selfTy)) {
5076+
if (!type->getRootGenericParam()->isEqual(origParam)) {
50725077
return GenericParameterReferenceInfo();
50735078
}
50745079

50755080
// A direct reference to 'Self'.
5076-
if (selfTy->isEqual(type)) {
5081+
if (type->is<GenericTypeParamType>()) {
50775082
if (position == TypePosition::Covariant && canBeCovariantResult)
50785083
return GenericParameterReferenceInfo::forCovariantResult();
50795084

50805085
return GenericParameterReferenceInfo::forSelfRef(position);
50815086
}
50825087

5083-
// If the type parameter is beyond the domain of the existential generic
5084-
// signature, ignore it.
5085-
if (!genericSig->isValidTypeParameter(type)) {
5086-
return GenericParameterReferenceInfo();
5087-
}
5088-
5089-
if (const auto concreteTy = genericSig->getConcreteType(type)) {
5090-
return findGenericParameterReferencesRec(
5091-
genericSig, genericParam, concreteTy,
5092-
position, canBeCovariantResult);
5088+
if (origParam != openedParam) {
5089+
// Replace the original parameter with the parameter in the opened
5090+
// signature.
5091+
type = type.subst(
5092+
[&](SubstitutableType *type) {
5093+
ASSERT(type == origParam);
5094+
return openedParam;
5095+
},
5096+
MakeAbstractConformanceForGenericType());
5097+
}
5098+
5099+
if (genericSig) {
5100+
// If the type parameter is beyond the domain of the opened
5101+
// signature, ignore it.
5102+
if (!genericSig->isValidTypeParameter(type)) {
5103+
return GenericParameterReferenceInfo();
5104+
}
5105+
5106+
if (auto reducedTy = genericSig.getReducedType(type)) {
5107+
if (!reducedTy->isEqual(type)) {
5108+
// Note: origParam becomes openedParam for the recursive call,
5109+
// because concreteTy is written in terms of genericSig and not
5110+
// the signature of the old origParam.
5111+
return findGenericParameterReferencesRec(
5112+
CanGenericSignature(), openedParam, openedParam, reducedTy,
5113+
position, canBeCovariantResult);
5114+
}
5115+
}
50935116
}
50945117

50955118
// A reference to an associated type rooted on 'Self'.
@@ -5099,11 +5122,10 @@ findGenericParameterReferencesRec(CanGenericSignature genericSig,
50995122
GenericParameterReferenceInfo
51005123
swift::findGenericParameterReferences(const ValueDecl *value,
51015124
CanGenericSignature sig,
5102-
GenericTypeParamType *genericParam,
5125+
GenericTypeParamType *origParam,
5126+
GenericTypeParamType *openedParam,
51035127
std::optional<unsigned> skipParamIndex) {
51045128
assert(!isa<TypeDecl>(value));
5105-
assert(sig->getGenericParamOrdinal(genericParam) <
5106-
sig.getGenericParams().size());
51075129

51085130
auto type = value->getInterfaceType();
51095131

@@ -5118,12 +5140,12 @@ swift::findGenericParameterReferences(const ValueDecl *value,
51185140
type = type->castTo<AnyFunctionType>()->getResult();
51195141

51205142
return ::findGenericParameterReferencesInFunction(
5121-
sig, genericParam, type->castTo<AnyFunctionType>(),
5143+
sig, origParam, openedParam, type->castTo<AnyFunctionType>(),
51225144
TypePosition::Covariant, /*canBeCovariantResult=*/true,
51235145
skipParamIndex);
51245146
}
51255147

5126-
return ::findGenericParameterReferencesRec(sig, genericParam, type,
5148+
return ::findGenericParameterReferencesRec(sig, origParam, openedParam, type,
51275149
TypePosition::Covariant,
51285150
/*canBeCovariantResult=*/true);
51295151
}
@@ -5148,7 +5170,8 @@ GenericParameterReferenceInfo ValueDecl::findExistentialSelfReferences(
51485170
GenericSignature());
51495171

51505172
auto genericParam = sig.getGenericParams().front();
5151-
return findGenericParameterReferences(this, sig, genericParam, std::nullopt);
5173+
return findGenericParameterReferences(this, sig, genericParam, genericParam,
5174+
std::nullopt);
51525175
}
51535176

51545177
TypeDecl::CanBeInvertible::Result

lib/Sema/CSSimplify.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1607,7 +1607,7 @@ shouldOpenExistentialCallArgument(ValueDecl *callee, unsigned paramIdx,
16071607
// Ensure that the formal parameter is only used in covariant positions,
16081608
// because it won't match anywhere else.
16091609
auto referenceInfo = findGenericParameterReferences(
1610-
callee, genericSig, genericParam,
1610+
callee, genericSig, genericParam, genericParam,
16111611
/*skipParamIdx=*/paramIdx);
16121612
if (referenceInfo.selfRef > TypePosition::Covariant ||
16131613
referenceInfo.assocTypeRef > TypePosition::Covariant)

0 commit comments

Comments
 (0)