Skip to content

Lift Requirement and Parameter Accessors up to GenericSignature #38403

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/swift/AST/ASTMangler.h
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ class ASTMangler : public Mangler {

void appendRequirement(const Requirement &reqt);

void appendGenericSignatureParts(TypeArrayView<GenericTypeParamType> params,
void appendGenericSignatureParts(ArrayRef<CanTypeWrapper<GenericTypeParamType>> params,
unsigned initialParamDepth,
ArrayRef<Requirement> requirements);

Expand Down
15 changes: 3 additions & 12 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -727,10 +727,7 @@ template <> struct DenseMapInfo<AutoDiffConfig> {
}

static unsigned getHashValue(const AutoDiffConfig &Val) {
auto canGenSig =
Val.derivativeGenericSignature
? Val.derivativeGenericSignature->getCanonicalSignature()
: nullptr;
auto canGenSig = Val.derivativeGenericSignature.getCanonicalSignature();
unsigned combinedHash = hash_combine(
~1U, DenseMapInfo<void *>::getHashValue(Val.parameterIndices),
DenseMapInfo<void *>::getHashValue(Val.resultIndices),
Expand All @@ -739,14 +736,8 @@ template <> struct DenseMapInfo<AutoDiffConfig> {
}

static bool isEqual(const AutoDiffConfig &LHS, const AutoDiffConfig &RHS) {
auto lhsCanGenSig =
LHS.derivativeGenericSignature
? LHS.derivativeGenericSignature->getCanonicalSignature()
: nullptr;
auto rhsCanGenSig =
RHS.derivativeGenericSignature
? RHS.derivativeGenericSignature->getCanonicalSignature()
: nullptr;
auto lhsCanGenSig = LHS.derivativeGenericSignature.getCanonicalSignature();
auto rhsCanGenSig = RHS.derivativeGenericSignature.getCanonicalSignature();
return LHS.parameterIndices == RHS.parameterIndices &&
LHS.resultIndices == RHS.resultIndices &&
DenseMapInfo<GenericSignature>::isEqual(lhsCanGenSig, rhsCanGenSig);
Expand Down
78 changes: 50 additions & 28 deletions include/swift/AST/GenericSignature.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,27 @@ class GenericSignature {
// first, or use isEqual.
void operator==(GenericSignature T) const = delete;
void operator!=(GenericSignature T) const = delete;

public:
/// Retrieve the generic parameters.
TypeArrayView<GenericTypeParamType> getGenericParams() const;

/// Retrieve the innermost generic parameters.
///
/// Given a generic signature for a nested generic type, produce an
/// array of the generic parameters for the innermost generic type.
TypeArrayView<GenericTypeParamType> getInnermostGenericParams() const;

/// Retrieve the requirements.
ArrayRef<Requirement> getRequirements() const;

/// Whether this generic signature involves a type variable.
bool hasTypeVariable() const;

/// Returns the generic environment that provides fresh contextual types
/// (archetypes) that correspond to the interface types in this generic
/// signature.
GenericEnvironment *getGenericEnvironment() const;
};

/// A reference to a canonical generic signature.
Expand Down Expand Up @@ -255,23 +276,6 @@ class alignas(1 << TypeAlignInBits) GenericSignatureImpl final
friend class ArchetypeType;

public:
/// Retrieve the generic parameters.
TypeArrayView<GenericTypeParamType> getGenericParams() const {
return TypeArrayView<GenericTypeParamType>(
{getTrailingObjects<Type>(), NumGenericParams});
}

/// Retrieve the innermost generic parameters.
///
/// Given a generic signature for a nested generic type, produce an
/// array of the generic parameters for the innermost generic type.
TypeArrayView<GenericTypeParamType> getInnermostGenericParams() const;

/// Retrieve the requirements.
ArrayRef<Requirement> getRequirements() const {
return {getTrailingObjects<Requirement>(), NumRequirements};
}

/// Only allow allocation by doing a placement new.
void *operator new(size_t Bytes, void *Mem) {
assert(Mem);
Expand Down Expand Up @@ -312,20 +316,12 @@ class alignas(1 << TypeAlignInBits) GenericSignatureImpl final

ASTContext &getASTContext() const;

/// Returns the canonical generic signature. The result is cached.
CanGenericSignature getCanonicalSignature() const;

/// Retrieve the generic signature builder for the given generic signature.
GenericSignatureBuilder *getGenericSignatureBuilder() const;

/// Retrieve the requirement machine for the given generic signature.
RequirementMachine *getRequirementMachine() const;

/// Returns the generic environment that provides fresh contextual types
/// (archetypes) that correspond to the interface types in this generic
/// signature.
GenericEnvironment *getGenericEnvironment() const;

/// Collects a set of requirements on a type parameter. Used by
/// GenericEnvironment for building archetypes.
GenericSignature::LocalRequirements getLocalRequirements(Type depType) const;
Expand Down Expand Up @@ -428,9 +424,6 @@ class alignas(1 << TypeAlignInBits) GenericSignatureImpl final
/// generic parameter types by their sugared form.
Type getSugaredType(Type type) const;

/// Whether this generic signature involves a type variable.
bool hasTypeVariable() const;

static void Profile(llvm::FoldingSetNodeID &ID,
TypeArrayView<GenericTypeParamType> genericParams,
ArrayRef<Requirement> requirements);
Expand All @@ -439,6 +432,35 @@ class alignas(1 << TypeAlignInBits) GenericSignatureImpl final
void print(ASTPrinter &Printer, PrintOptions Opts = PrintOptions()) const;
SWIFT_DEBUG_DUMP;
std::string getAsString() const;

private:
friend GenericSignature;
friend CanGenericSignature;

/// Retrieve the generic parameters.
TypeArrayView<GenericTypeParamType> getGenericParams() const {
return TypeArrayView<GenericTypeParamType>(
{getTrailingObjects<Type>(), NumGenericParams});
}

/// Retrieve the innermost generic parameters.
///
/// Given a generic signature for a nested generic type, produce an
/// array of the generic parameters for the innermost generic type.
TypeArrayView<GenericTypeParamType> getInnermostGenericParams() const;

/// Retrieve the requirements.
ArrayRef<Requirement> getRequirements() const {
return {getTrailingObjects<Requirement>(), NumRequirements};
}

/// Returns the canonical generic signature. The result is cached.
CanGenericSignature getCanonicalSignature() const;

/// Returns the generic environment that provides fresh contextual types
/// (archetypes) that correspond to the interface types in this generic
/// signature.
GenericEnvironment *getGenericEnvironment() const;
};

void simple_display(raw_ostream &out, GenericSignature sig);
Expand Down
53 changes: 25 additions & 28 deletions lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -913,8 +913,7 @@ static VarDecl *getPointeeProperty(VarDecl *&cache,
NominalTypeDecl *nominal = (ctx.*getNominal)();
if (!nominal) return nullptr;
auto sig = nominal->getGenericSignature();
if (!sig) return nullptr;
if (sig->getGenericParams().size() != 1) return nullptr;
if (sig.getGenericParams().size() != 1) return nullptr;

// There must be a property named "pointee".
auto identifier = ctx.getIdentifier("pointee");
Expand All @@ -924,7 +923,7 @@ static VarDecl *getPointeeProperty(VarDecl *&cache,
// The property must have type T.
auto *property = dyn_cast<VarDecl>(results[0]);
if (!property) return nullptr;
if (!property->getInterfaceType()->isEqual(sig->getGenericParams()[0]))
if (!property->getInterfaceType()->isEqual(sig.getGenericParams()[0]))
return nullptr;

cache = property;
Expand Down Expand Up @@ -1776,8 +1775,8 @@ static AllocationArena getArena(GenericSignature genericSig) {
if (!genericSig)
return AllocationArena::Permanent;

if (genericSig->hasTypeVariable()) {
assert(false && "What's going on");
if (genericSig.hasTypeVariable()) {
assert(false && "Unsubstituted type variable leaked into generic signature");
return AllocationArena::ConstraintSolver;
}

Expand Down Expand Up @@ -1844,13 +1843,13 @@ GenericSignatureBuilder *ASTContext::getOrCreateGenericSignatureBuilder(
reprocessedSig->print(llvm::errs());
llvm::errs() << "\n";

if (sig->getGenericParams().size() ==
reprocessedSig->getGenericParams().size() &&
sig->getRequirements().size() ==
reprocessedSig->getRequirements().size()) {
for (unsigned i : indices(sig->getRequirements())) {
auto sigReq = sig->getRequirements()[i];
auto reprocessedReq = reprocessedSig->getRequirements()[i];
if (sig.getGenericParams().size() ==
reprocessedSig.getGenericParams().size() &&
sig.getRequirements().size() ==
reprocessedSig.getRequirements().size()) {
for (unsigned i : indices(sig.getRequirements())) {
auto sigReq = sig.getRequirements()[i];
auto reprocessedReq = reprocessedSig.getRequirements()[i];
if (sigReq.getKind() != reprocessedReq.getKind()) {
llvm::errs() << "Requirement mismatch:\n";
llvm::errs() << " Original: ";
Expand Down Expand Up @@ -1895,7 +1894,7 @@ GenericSignatureBuilder *ASTContext::getOrCreateGenericSignatureBuilder(

RequirementMachine *ASTContext::getOrCreateRequirementMachine(
CanGenericSignature sig) {
assert(!sig->hasTypeVariable());
assert(!sig.hasTypeVariable());

auto &rewriteCtx = getImpl().TheRewriteContext;
if (!rewriteCtx)
Expand Down Expand Up @@ -3608,12 +3607,12 @@ GenericTypeParamType *GenericTypeParamType::get(unsigned depth, unsigned index,

TypeArrayView<GenericTypeParamType>
GenericFunctionType::getGenericParams() const {
return Signature->getGenericParams();
return Signature.getGenericParams();
}

/// Retrieve the requirements of this polymorphic function type.
ArrayRef<Requirement> GenericFunctionType::getRequirements() const {
return Signature->getRequirements();
return Signature.getRequirements();
}

void SILFunctionType::Profile(
Expand Down Expand Up @@ -3747,7 +3746,7 @@ SILFunctionType::SILFunctionType(
"If all generic parameters are concrete, SILFunctionType should "
"not have a generic signature at all");

for (auto gparam : genericSig->getGenericParams()) {
for (auto gparam : genericSig.getGenericParams()) {
(void)gparam;
assert(gparam->isCanonical() && "generic signature is not canonicalized");
}
Expand Down Expand Up @@ -4123,7 +4122,7 @@ OpaqueTypeArchetypeType::get(OpaqueTypeDecl *Decl,
// Same-type-constrain the arguments in the outer signature to their
// replacements in the substitution map.
if (auto outerSig = Decl->getGenericSignature()) {
for (auto outerParam : outerSig->getGenericParams()) {
for (auto outerParam : outerSig.getGenericParams()) {
auto boundType = Type(outerParam).subst(Substitutions);
newRequirements.push_back(
Requirement(RequirementKind::SameType, Type(outerParam), boundType));
Expand All @@ -4138,7 +4137,7 @@ OpaqueTypeArchetypeType::get(OpaqueTypeDecl *Decl,
(void)newRequirements;
# ifndef NDEBUG
for (auto reqt :
Decl->getOpaqueInterfaceGenericSignature()->getRequirements()) {
Decl->getOpaqueInterfaceGenericSignature().getRequirements()) {
auto reqtBase = reqt.getFirstType()->getRootGenericParam();
if (reqtBase->isEqual(Decl->getUnderlyingInterfaceType())) {
assert(reqt.getKind() != RequirementKind::SameType
Expand Down Expand Up @@ -4261,7 +4260,7 @@ GenericEnvironment *OpenedArchetypeType::getGenericEnvironment() const {
// Create a generic environment to represent the opened type.
auto signature = ctx.getOpenedArchetypeSignature(Opened);
auto *env = GenericEnvironment::getIncomplete(signature);
env->addMapping(signature->getGenericParams()[0], thisType);
env->addMapping(signature.getGenericParams().front().getPointer(), thisType);
Environment = env;

return env;
Expand Down Expand Up @@ -4419,7 +4418,7 @@ GenericEnvironment *GenericEnvironment::getIncomplete(
auto &ctx = signature->getASTContext();

// Allocate and construct the new environment.
unsigned numGenericParams = signature->getGenericParams().size();
unsigned numGenericParams = signature.getGenericParams().size();
size_t bytes = totalSizeToAlloc<Type>(numGenericParams);
void *mem = ctx.Allocate(bytes, alignof(GenericEnvironment));
return new (mem) GenericEnvironment(signature);
Expand Down Expand Up @@ -4913,7 +4912,7 @@ CanGenericSignature ASTContext::getOpenedArchetypeSignature(Type type) {
// The opened archetype signature for a protocol type is identical
// to the protocol's own canonical generic signature.
if (const auto protoTy = dyn_cast<ProtocolType>(existential)) {
return protoTy->getDecl()->getGenericSignature()->getCanonicalSignature();
return protoTy->getDecl()->getGenericSignature().getCanonicalSignature();
}

auto found = getImpl().ExistentialSignatures.find(existential);
Expand Down Expand Up @@ -4982,9 +4981,9 @@ ASTContext::getOverrideGenericSignature(const ValueDecl *base,
unsigned derivedDepth = 0;
unsigned baseDepth = 0;
if (derivedClassSig)
derivedDepth = derivedClassSig->getGenericParams().back()->getDepth() + 1;
derivedDepth = derivedClassSig.getGenericParams().back()->getDepth() + 1;
if (const auto baseClassSig = baseClass->getGenericSignature())
baseDepth = baseClassSig->getGenericParams().back()->getDepth() + 1;
baseDepth = baseClassSig.getGenericParams().back()->getDepth() + 1;

SmallVector<GenericTypeParamType *, 2> addedGenericParams;
if (const auto *gpList = derived->getAsGenericContext()->getGenericParams()) {
Expand Down Expand Up @@ -5018,7 +5017,7 @@ ASTContext::getOverrideGenericSignature(const ValueDecl *base,
};

SmallVector<Requirement, 2> addedRequirements;
for (auto reqt : baseGenericSig->getRequirements()) {
for (auto reqt : baseGenericSig.getRequirements()) {
if (auto substReqt = reqt.subst(substFn, lookupConformanceFn)) {
addedRequirements.push_back(*substReqt);
}
Expand Down Expand Up @@ -5116,7 +5115,7 @@ CanSILBoxType SILBoxType::get(ASTContext &C,
CanSILBoxType SILBoxType::get(CanType boxedType) {
auto &ctx = boxedType->getASTContext();
auto singleGenericParamSignature = ctx.getSingleGenericParameterSignature();
auto genericParam = singleGenericParamSignature->getGenericParams()[0];
auto genericParam = singleGenericParamSignature.getGenericParams()[0];
auto layout = SILLayout::get(ctx, singleGenericParamSignature,
SILField(CanType(genericParam),
/*mutable*/ true));
Expand Down Expand Up @@ -5274,9 +5273,7 @@ AutoDiffDerivativeFunctionIdentifier *AutoDiffDerivativeFunctionIdentifier::get(
llvm::FoldingSetNodeID id;
id.AddInteger((unsigned)kind);
id.AddPointer(parameterIndices);
CanGenericSignature derivativeCanGenSig;
if (derivativeGenericSignature)
derivativeCanGenSig = derivativeGenericSignature->getCanonicalSignature();
auto derivativeCanGenSig = derivativeGenericSignature.getCanonicalSignature();
id.AddPointer(derivativeCanGenSig.getPointer());

void *insertPos;
Expand Down
4 changes: 2 additions & 2 deletions lib/AST/ASTDemangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ createSubstitutionMapFromGenericArgs(GenericSignature genericSig,
if (!genericSig)
return SubstitutionMap();

if (genericSig->getGenericParams().size() != args.size())
if (genericSig.getGenericParams().size() != args.size())
return SubstitutionMap();

return SubstitutionMap::get(
Expand Down Expand Up @@ -306,7 +306,7 @@ Type ASTBuilder::createBoundGenericType(GenericTypeDecl *decl,

auto genericSig = aliasDecl->getGenericSignature();
for (unsigned i = 0, e = args.size(); i < e; ++i) {
auto origTy = genericSig->getInnermostGenericParams()[i];
auto origTy = genericSig.getInnermostGenericParams()[i];
auto substTy = args[i];

subs[origTy->getCanonicalType()->castTo<GenericTypeParamType>()] =
Expand Down
4 changes: 2 additions & 2 deletions lib/AST/ASTDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3286,7 +3286,7 @@ static void dumpSubstitutionMapRec(
}

genericSig->print(out);
auto genericParams = genericSig->getGenericParams();
auto genericParams = genericSig.getGenericParams();
auto replacementTypes =
static_cast<const SubstitutionMap &>(map).getReplacementTypesBuffer();
for (unsigned i : indices(genericParams)) {
Expand Down Expand Up @@ -3315,7 +3315,7 @@ static void dumpSubstitutionMapRec(
return;

auto conformances = map.getConformances();
for (const auto &req : genericSig->getRequirements()) {
for (const auto &req : genericSig.getRequirements()) {
if (req.getKind() != RequirementKind::Conformance)
continue;

Expand Down
Loading