Skip to content

Preliminary support for substitution with nested pack expansions #66096

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jun 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/ABI/Mangling.rst
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,8 @@ Types
type ::= assoc-type-list 'QZ' // shortcut for 'QYz'
type ::= opaque-type-decl-name bound-generic-args 'Qo' INDEX // opaque type

type ::= pack-type 'Qe' INDEX // pack element type

type ::= pattern-type count-type 'Qp' // pack expansion type
type ::= pack-element-list 'QP' // pack type
type ::= pack-element-list 'QS' DIRECTNESS // SIL pack type
Expand Down
18 changes: 3 additions & 15 deletions include/swift/AST/InFlightSubstitution.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,27 +49,15 @@ class InFlightSubstitution {
InFlightSubstitution(const InFlightSubstitution &) = delete;
InFlightSubstitution &operator=(const InFlightSubstitution &) = delete;

// TODO: when we add PackElementType, we should recognize it during
// substitution and either call different methods on this class or
// pass an extra argument for the pack-expansion depth D. We should
// be able to rely on that to mark a pack-element reference instead
// of checking whether the original type was a pack. Substitution
// should use the D'th entry from the end of ActivePackExpansions to
// guide the element substitution:
// - project the given index of the pack substitution
// - wrap it in a PackElementType if it's a subst expansion
// - the depth of that PackElementType is the number of subst
// expansions between the depth entry and the end of
// ActivePackExpansions

/// Perform primitive substitution on the given type. Returns Type()
/// if the type should not be substituted as a whole.
Type substType(SubstitutableType *origType);
Type substType(SubstitutableType *origType, unsigned level);

/// Perform primitive conformance lookup on the given type.
ProtocolConformanceRef lookupConformance(CanType dependentType,
Type conformingReplacementType,
ProtocolDecl *conformedProtocol);
ProtocolDecl *conformedProtocol,
unsigned level);

/// Given the shape type of a pack expansion, invoke the given callback
/// for each expanded component of it. If the substituted component
Expand Down
6 changes: 5 additions & 1 deletion include/swift/AST/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,11 @@ enum class SubstFlags {
/// Map member types to their desugared witness type.
DesugarMemberTypes = 0x02,
/// Substitute types involving opaque type archetypes.
SubstituteOpaqueArchetypes = 0x04
SubstituteOpaqueArchetypes = 0x04,
/// Don't increase pack expansion level for free pack references.
/// Do not introduce new usages of this flag.
/// FIXME: Remove this.
PreservePackExpansionLevel = 0x08,
};

/// Options for performing substitutions into a type.
Expand Down
9 changes: 7 additions & 2 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -724,8 +724,10 @@ class alignas(1 << TypeAlignInBits) TypeBase
SmallVectorImpl<OpenedArchetypeType *> &rootOpenedArchetypes) const;

/// Retrieve the set of type parameter packs that occur within this type.
void getTypeParameterPacks(
SmallVectorImpl<Type> &rootParameterPacks);
void getTypeParameterPacks(SmallVectorImpl<Type> &rootParameterPacks);

/// Retrieve the set of type parameter packs that occur within this type.
void walkPackReferences(llvm::function_ref<bool (Type)> fn);

/// Replace opened archetypes with the given root with their most
/// specific non-dependent upper bounds throughout this type.
Expand Down Expand Up @@ -792,6 +794,9 @@ class alignas(1 << TypeAlignInBits) TypeBase
/// this function will wrap into a pack containing a singleton expansion.
PackType *getPackSubstitutionAsPackType();

/// Increase the expansion level of each parameter pack appearing in this type.
Type increasePackElementLevel(unsigned level);

/// Determines whether this type is an lvalue. This includes both straight
/// lvalue types as well as tuples or optionals of lvalues.
bool hasLValueType() {
Expand Down
1 change: 1 addition & 0 deletions include/swift/Demangling/DemangleNodes.def
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ NODE(SILPackDirect)
NODE(SILPackIndirect)
NODE(PackExpansion)
NODE(PackElement)
NODE(PackElementLevel)
NODE(Type)
CONTEXT_NODE(TypeSymbolicReference)
CONTEXT_NODE(TypeAlias)
Expand Down
14 changes: 9 additions & 5 deletions include/swift/SIL/AbstractionPattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -938,10 +938,9 @@ class AbstractionPattern {
case Kind::Discard: {
auto type = getType();
if (isa<DependentMemberType>(type) ||
isa<GenericTypeParamType>(type)) {
return true;
}
if (isa<ArchetypeType>(type)) {
isa<GenericTypeParamType>(type) ||
isa<PackElementType>(type) ||
isa<ArchetypeType>(type)) {
return true;
}
return false;
Expand All @@ -960,7 +959,8 @@ class AbstractionPattern {
case Kind::Discard: {
auto type = getType();
if (isa<DependentMemberType>(type) ||
isa<GenericTypeParamType>(type)) {
isa<GenericTypeParamType>(type) ||
isa<PackElementType>(type)) {
return true;
}
if (auto archetype = dyn_cast<ArchetypeType>(type)) {
Expand Down Expand Up @@ -1448,6 +1448,10 @@ class AbstractionPattern {
/// the abstraction pattern for an element type.
AbstractionPattern getTupleElementType(unsigned index) const;

/// Given that the value being abstracted is a pack element type, return
/// the abstraction pattern for its pack type.
AbstractionPattern getPackElementPackType() const;

/// Given that the value being abstracted is a pack type, return
/// the abstraction pattern for an element type.
AbstractionPattern getPackElementType(unsigned index) const;
Expand Down
8 changes: 5 additions & 3 deletions lib/AST/ASTMangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,8 @@ static Type getTypeForDWARFMangling(Type t) {
return t;
},
MakeAbstractConformanceForGenericType(),
SubstFlags::AllowLoweredTypes);
SubstFlags::AllowLoweredTypes |
SubstFlags::PreservePackExpansionLevel);
}

std::string ASTMangler::mangleTypeForDebugger(Type Ty, GenericSignature sig) {
Expand Down Expand Up @@ -1288,8 +1289,9 @@ void ASTMangler::appendType(Type type, GenericSignature sig,
case TypeKind::PackElement: {
auto elementType = cast<PackElementType>(tybase);
appendType(elementType->getPackType(), sig, forDecl);

// FIXME: append expansion depth
// If this ever changes, just mangle level 0 as a plain type parameter.
assert(elementType->getLevel() > 0);
appendOperator("Qe", Index(elementType->getLevel() - 1));

return;
}
Expand Down
14 changes: 13 additions & 1 deletion lib/AST/ASTPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6075,12 +6075,24 @@ class TypePrinter : public TypeVisitor<TypePrinter> {
}

void visitPackExpansionType(PackExpansionType *T) {
SmallVector<Type, 2> rootParameterPacks;
T->getPatternType()->getTypeParameterPacks(rootParameterPacks);

if (rootParameterPacks.empty() &&
(T->getCountType()->isParameterPack() ||
T->getCountType()->is<PackArchetypeType>())) {
Printer << "/* shape: ";
visit(T->getCountType());
Printer << " */ ";
}

Printer << "repeat ";

visit(T->getPatternType());
}

void visitPackElementType(PackElementType *T) {
Printer << "@level(" << T->getLevel() << ") ";
Printer << "/* level: " << T->getLevel() << " */ ";
visit(T->getPackType());
}

Expand Down
29 changes: 17 additions & 12 deletions lib/AST/GenericEnvironment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,8 @@ Type TypeBase::mapTypeOutOfContext() {
assert(!hasTypeParameter() && "already have an interface type");
return Type(this).subst(MapTypeOutOfContext(),
MakeAbstractConformanceForGenericType(),
SubstFlags::AllowLoweredTypes);
SubstFlags::AllowLoweredTypes |
SubstFlags::PreservePackExpansionLevel);
}

class GenericEnvironment::NestedTypeStorage
Expand Down Expand Up @@ -637,7 +638,8 @@ Type GenericEnvironment::mapTypeIntoContext(
type = maybeApplyOuterContextSubstitutions(type);
Type result = type.subst(QueryInterfaceTypeSubstitutions(this),
lookupConformance,
SubstFlags::AllowLoweredTypes);
SubstFlags::AllowLoweredTypes |
SubstFlags::PreservePackExpansionLevel);
assert((!result->hasTypeParameter() || result->hasError() ||
getKind() == Kind::Opaque) &&
"not fully substituted");
Expand Down Expand Up @@ -787,16 +789,19 @@ GenericEnvironment::mapElementTypeIntoPackContext(Type type) const {
// Map element archetypes to the pack archetypes by converting
// element types to interface types and adding the isParameterPack
// bit. Then, map type parameters to archetypes.
return type.subst([&](SubstitutableType *type) {
auto *genericParam = type->getAs<GenericTypeParamType>();
if (!genericParam)
return Type();

if (auto *packParam = packParamForElement[{genericParam}])
return substitutions(packParam);

return substitutions(genericParam);
}, LookUpConformanceInSignature(sig.getPointer()));
return type.subst(
[&](SubstitutableType *type) {
auto *genericParam = type->getAs<GenericTypeParamType>();
if (!genericParam)
return Type();

if (auto *packParam = packParamForElement[{genericParam}])
return substitutions(packParam);

return substitutions(genericParam);
},
LookUpConformanceInSignature(sig.getPointer()),
SubstFlags::PreservePackExpansionLevel);
}

namespace {
Expand Down
77 changes: 64 additions & 13 deletions lib/AST/ParameterPack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,13 @@ namespace {

/// Collects all unique pack type parameters referenced from the pattern type,
/// skipping those captured by nested pack expansion types.
struct PackTypeParameterCollector: TypeWalker {
llvm::SetVector<Type> typeParams;
struct PackReferenceCollector: TypeWalker {
llvm::function_ref<bool (Type)> fn;
unsigned expansionLevel;
SmallVector<unsigned, 2> elementLevel;

PackTypeParameterCollector() : expansionLevel(0) {
PackReferenceCollector(llvm::function_ref<bool (Type)> fn)
: fn(fn), expansionLevel(0) {
elementLevel.push_back(0);
}

Expand Down Expand Up @@ -72,12 +73,8 @@ struct PackTypeParameterCollector: TypeWalker {
}

if (elementLevel.back() == expansionLevel) {
if (auto *paramTy = t->getAs<GenericTypeParamType>()) {
if (paramTy->isParameterPack())
typeParams.insert(paramTy);
} else if (auto *archetypeTy = t->getAs<PackArchetypeType>()) {
typeParams.insert(archetypeTy->getRoot());
}
if (fn(t))
return Action::Stop;
}

return Action::Continue;
Expand All @@ -96,13 +93,28 @@ struct PackTypeParameterCollector: TypeWalker {

}

void TypeBase::walkPackReferences(
llvm::function_ref<bool (Type)> fn) {
Type(this).walk(PackReferenceCollector(fn));
}

void TypeBase::getTypeParameterPacks(
SmallVectorImpl<Type> &rootParameterPacks) {
PackTypeParameterCollector collector;
Type(this).walk(collector);
llvm::SmallSetVector<Type, 2> rootParameterPackSet;

walkPackReferences([&](Type t) {
if (auto *paramTy = t->getAs<GenericTypeParamType>()) {
if (paramTy->isParameterPack())
rootParameterPackSet.insert(paramTy);
} else if (auto *archetypeTy = t->getAs<PackArchetypeType>()) {
rootParameterPackSet.insert(archetypeTy->getRoot());
}

return false;
});

rootParameterPacks.append(collector.typeParams.begin(),
collector.typeParams.end());
rootParameterPacks.append(rootParameterPackSet.begin(),
rootParameterPackSet.end());
}

bool GenericTypeParamType::isParameterPack() const {
Expand Down Expand Up @@ -133,6 +145,45 @@ PackType *TypeBase::getPackSubstitutionAsPackType() {
}
}

static Type increasePackElementLevelImpl(
Type type, unsigned level, unsigned outerLevel) {
assert(level > 0);

return type.transformRec([&](TypeBase *t) -> Optional<Type> {
if (auto *elementType = dyn_cast<PackElementType>(t)) {
if (elementType->getLevel() >= outerLevel) {
elementType = PackElementType::get(elementType->getPackType(),
elementType->getLevel() + level);
}

return Type(elementType);
}

if (auto *expansionType = dyn_cast<PackExpansionType>(t)) {
return Type(PackExpansionType::get(
increasePackElementLevelImpl(expansionType->getPatternType(),
level, outerLevel + 1),
expansionType->getCountType()));
}

if (t->isParameterPack() || isa<PackArchetypeType>(t)) {
if (outerLevel == 0)
return Type(PackElementType::get(t, level));

return Type(t);
}

return None;
});
}

Type TypeBase::increasePackElementLevel(unsigned level) {
if (level == 0)
return Type(this);

return increasePackElementLevelImpl(Type(this), level, 0);
}

CanType PackExpansionType::getReducedShape() {
auto reducedShape = countType->getReducedShape();
if (reducedShape == getASTContext().TheEmptyTupleType)
Expand Down
10 changes: 7 additions & 3 deletions lib/AST/ProtocolConformanceRef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ ProtocolConformanceRef::subst(Type origType, InFlightSubstitution &IFS) const {
}

// Check the conformance map.
return IFS.lookupConformance(origType->getCanonicalType(), substType, proto);
// FIXME: Pack element level?
return IFS.lookupConformance(origType->getCanonicalType(), substType, proto,
/*level=*/0);
}

ProtocolConformanceRef ProtocolConformanceRef::mapConformanceOutOfContext() const {
Expand All @@ -130,7 +132,8 @@ ProtocolConformanceRef ProtocolConformanceRef::mapConformanceOutOfContext() cons
return archetypeType->getInterfaceType();
return type;
},
MakeAbstractConformanceForGenericType());
MakeAbstractConformanceForGenericType(),
SubstFlags::PreservePackExpansionLevel);
return ProtocolConformanceRef(concrete);
} else if (isPack()) {
return getPack()->subst(
Expand All @@ -139,7 +142,8 @@ ProtocolConformanceRef ProtocolConformanceRef::mapConformanceOutOfContext() cons
return archetypeType->getInterfaceType();
return type;
},
MakeAbstractConformanceForGenericType());
MakeAbstractConformanceForGenericType(),
SubstFlags::PreservePackExpansionLevel);
}

return *this;
Expand Down
4 changes: 3 additions & 1 deletion lib/AST/RequirementEnvironment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ RequirementEnvironment::RequirementEnvironment(
MakeAbstractConformanceForGenericType();

auto substConcreteType = concreteType.subst(
conformanceToWitnessThunkTypeFn, conformanceToWitnessThunkConformanceFn);
conformanceToWitnessThunkTypeFn,
conformanceToWitnessThunkConformanceFn,
SubstFlags::PreservePackExpansionLevel);

// Calculate the depth at which the requirement's generic parameters
// appear in the witness thunk signature.
Expand Down
11 changes: 8 additions & 3 deletions lib/AST/SubstitutionMap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,8 @@ SubstitutionMap SubstitutionMap::get(GenericSignature genericSig,
CanType depTy = req.getFirstType()->getCanonicalType();
auto replacement = depTy.subst(IFS);
auto *proto = req.getProtocolDecl();
auto conformance = IFS.lookupConformance(depTy, replacement, proto);
auto conformance = IFS.lookupConformance(depTy, replacement, proto,
/*level=*/0);
conformances.push_back(conformance);
}

Expand Down Expand Up @@ -440,7 +441,9 @@ SubstitutionMap::lookupConformance(CanType type, ProtocolDecl *proto) const {
}

SubstitutionMap SubstitutionMap::mapReplacementTypesOutOfContext() const {
return subst(MapTypeOutOfContext(), MakeAbstractConformanceForGenericType());
return subst(MapTypeOutOfContext(),
MakeAbstractConformanceForGenericType(),
SubstFlags::PreservePackExpansionLevel);
}

SubstitutionMap SubstitutionMap::subst(SubstitutionMap subMap,
Expand Down Expand Up @@ -829,5 +832,7 @@ SubstitutionMap SubstitutionMap::mapIntoTypeExpansionContext(
ReplaceOpaqueTypesWithUnderlyingTypes replacer(
context.getContext(), context.getResilienceExpansion(),
context.isWholeModuleContext());
return this->subst(replacer, replacer, SubstFlags::SubstituteOpaqueArchetypes);
return this->subst(replacer, replacer,
SubstFlags::SubstituteOpaqueArchetypes |
SubstFlags::PreservePackExpansionLevel);
}
Loading