Skip to content

Commit aa026f4

Browse files
authored
Merge pull request #67266 from slavapestov/transform-type-parameter-packs
Fix various places where we didn't handle "bound" pack references correctly
2 parents e973ff9 + 8a132c0 commit aa026f4

File tree

10 files changed

+135
-72
lines changed

10 files changed

+135
-72
lines changed

include/swift/AST/Type.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,14 @@ class Type {
316316
llvm::function_ref<llvm::Optional<Type>(TypeBase *, TypePosition)> fn)
317317
const;
318318

319+
/// Transform free pack element references, that is, those not captured by a
320+
/// pack expansion.
321+
///
322+
/// This is the 'map' counterpart to TypeBase::getTypeParameterPacks().
323+
Type transformTypeParameterPacks(
324+
llvm::function_ref<llvm::Optional<Type>(SubstitutableType *)> fn)
325+
const;
326+
319327
/// Look through the given type and its children and apply fn to them.
320328
void visit(llvm::function_ref<void (Type)> fn) const {
321329
findIf([&fn](Type t) -> bool {

lib/AST/ASTContext.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5867,13 +5867,13 @@ ASTContext::getOpenedElementSignature(CanGenericSignature baseGenericSig,
58675867
}
58685868

58695869
auto eraseParameterPackRec = [&](Type type) -> Type {
5870-
return type.transformRec([&](Type t) -> llvm::Optional<Type> {
5871-
if (auto *paramType = t->getAs<GenericTypeParamType>()) {
5870+
return type.transformTypeParameterPacks([&](SubstitutableType *t) -> llvm::Optional<Type> {
5871+
if (auto *paramType = dyn_cast<GenericTypeParamType>(t)) {
58725872
if (packElementParams.find(paramType) != packElementParams.end()) {
58735873
return Type(packElementParams[paramType]);
58745874
}
58755875

5876-
return t;
5876+
return Type(t);
58775877
}
58785878
return llvm::None;
58795879
});

lib/AST/ASTPrinter.cpp

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -803,7 +803,7 @@ class PrintAST : public ASTVisitor<PrintAST> {
803803

804804
void printType(Type T) { printTypeWithOptions(T, Options); }
805805

806-
void printTransformedTypeWithOptions(Type T, PrintOptions options) {
806+
Type getTransformedType(Type T) {
807807
if (CurrentType && Current && CurrentType->mayHaveMembers()) {
808808
auto *M = Current->getDeclContext()->getParentModule();
809809
SubstitutionMap subMap;
@@ -824,9 +824,16 @@ class PrintAST : public ASTVisitor<PrintAST> {
824824
}
825825

826826
T = T.subst(subMap, SubstFlags::DesugarMemberTypes);
827+
}
828+
829+
return T;
830+
}
827831

832+
void printTransformedTypeWithOptions(Type T, PrintOptions options) {
833+
T = getTransformedType(T);
834+
835+
if (CurrentType && Current && CurrentType->mayHaveMembers())
828836
options.TransformContext = TypeTransformContext(CurrentType);
829-
}
830837

831838
printTypeWithOptions(T, options);
832839
}
@@ -1832,6 +1839,11 @@ void PrintAST::printSingleDepthOfGenericSignature(
18321839
}
18331840

18341841
void PrintAST::printRequirement(const Requirement &req) {
1842+
SmallVector<Type, 2> rootParameterPacks;
1843+
getTransformedType(req.getFirstType())
1844+
->getTypeParameterPacks(rootParameterPacks);
1845+
bool isPackRequirement = !rootParameterPacks.empty();
1846+
18351847
switch (req.getKind()) {
18361848
case RequirementKind::SameShape:
18371849
Printer << "(repeat (";
@@ -1841,22 +1853,21 @@ void PrintAST::printRequirement(const Requirement &req) {
18411853
Printer << ")) : Any";
18421854
return;
18431855
case RequirementKind::Layout:
1844-
if (req.getFirstType()->hasParameterPack())
1856+
if (isPackRequirement)
18451857
Printer << "repeat ";
18461858
printTransformedType(req.getFirstType());
18471859
Printer << " : ";
18481860
req.getLayoutConstraint()->print(Printer, Options);
18491861
return;
18501862
case RequirementKind::Conformance:
18511863
case RequirementKind::Superclass:
1852-
if (req.getFirstType()->hasParameterPack())
1864+
if (isPackRequirement)
18531865
Printer << "repeat ";
18541866
printTransformedType(req.getFirstType());
18551867
Printer << " : ";
18561868
break;
18571869
case RequirementKind::SameType:
1858-
if (req.getFirstType()->hasParameterPack() ||
1859-
req.getSecondType()->hasParameterPack())
1870+
if (isPackRequirement)
18601871
Printer << "repeat ";
18611872
printTransformedType(req.getFirstType());
18621873
Printer << " == ";

lib/AST/GenericEnvironment.cpp

Lines changed: 23 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,10 @@ Type GenericEnvironment::mapTypeIntoContext(GenericTypeParamType *type) const {
659659
return result;
660660
}
661661

662+
/// So this expects a type written with the archetypes of the original generic
663+
/// environment, not 'this', the opened element environment, because it is the
664+
/// original PackArchetypes that become ElementArchetypes. Also this function
665+
/// does not apply outer substitutions, which might not be what you expect.
662666
Type
663667
GenericEnvironment::mapContextualPackTypeIntoElementContext(Type type) const {
664668
assert(getKind() == Kind::OpenedElement);
@@ -672,77 +676,40 @@ GenericEnvironment::mapContextualPackTypeIntoElementContext(Type type) const {
672676
FindElementArchetypeForOpenedPackParam
673677
findElementArchetype(this, getOpenedPackParams());
674678

675-
return type.transformRec([&](TypeBase *ty) -> llvm::Optional<Type> {
676-
// We're only directly substituting pack archetypes.
677-
auto archetype = ty->getAs<PackArchetypeType>();
678-
if (!archetype) {
679-
// Don't recurse into nested pack expansions.
680-
if (ty->is<PackExpansionType>())
681-
return Type(ty);
679+
return type.transformTypeParameterPacks(
680+
[&](SubstitutableType *ty) -> llvm::Optional<Type> {
681+
if (auto *packArchetype = dyn_cast<PackArchetypeType>(ty)) {
682+
auto interfaceType = packArchetype->getInterfaceType();
683+
if (sig->haveSameShape(interfaceType, shapeClass))
684+
return Type(findElementArchetype(interfaceType));
685+
}
682686

683-
// Recurse into any other type.
684-
return llvm::None;
685-
}
686-
687-
auto rootArchetype = cast<PackArchetypeType>(archetype->getRoot());
688-
689-
// TODO: assert that the generic environment of the pack archetype
690-
// matches the signature that was originally opened to make this
691-
// environment. Unfortunately, that isn't a trivial check because of
692-
// the extra opened-element parameters.
693-
694-
// If the archetype isn't the shape that was opened by this
695-
// environment, ignore it.
696-
auto rootParam = cast<GenericTypeParamType>(
697-
rootArchetype->getInterfaceType().getPointer());
698-
assert(rootParam->isParameterPack());
699-
if (!sig->haveSameShape(rootParam, shapeClass))
700-
return Type(ty);
701-
702-
return Type(findElementArchetype(archetype->getInterfaceType()));
703-
});
687+
return llvm::None;
688+
});
704689
}
705690

706691
CanType
707692
GenericEnvironment::mapContextualPackTypeIntoElementContext(CanType type) const {
708693
return CanType(mapContextualPackTypeIntoElementContext(Type(type)));
709694
}
710695

696+
/// Unlike mapContextualPackTypeIntoElementContext(), this also applies outer
697+
/// substitutions, so it behaves like mapTypeIntoContext() in that respect.
711698
Type
712699
GenericEnvironment::mapPackTypeIntoElementContext(Type type) const {
713700
assert(getKind() == Kind::OpenedElement);
714701
assert(!type->hasArchetype());
715702

716-
auto sig = getGenericSignature();
717-
auto shapeClass = getOpenedElementShapeClass();
718-
719-
FindElementArchetypeForOpenedPackParam
720-
findElementArchetype(this, getOpenedPackParams());
703+
if (!type->hasTypeParameter()) return type;
721704

722-
// Map the interface type to the element type by stripping
723-
// away the isParameterPack bit before mapping type parameters
724-
// to archetypes.
725-
return type.transformRec([&](TypeBase *ty) -> llvm::Optional<Type> {
726-
// We're only directly substituting pack parameters.
727-
if (!ty->isTypeParameter()) {
728-
// Don't recurse into nested pack expansions; just map it into
729-
// context.
730-
if (ty->is<PackExpansionType>())
731-
return mapTypeIntoContext(ty);
732-
733-
// Recurse into any other type.
734-
return llvm::None;
735-
}
736-
737-
// Just do normal mapping for types that are not rooted in
738-
// opened type parameters.
739-
auto rootParam = ty->getRootGenericParam();
740-
if (!rootParam->isParameterPack() ||
741-
!sig->haveSameShape(rootParam, shapeClass))
742-
return mapTypeIntoContext(ty);
705+
// Get a contextual type in the original generic environment, not the
706+
// substituted one, which is what mapContextualPackTypeIntoElementContext()
707+
// expects.
708+
auto contextualType = getPackElementContextSubstitutions()
709+
.getGenericSignature().getGenericEnvironment()->mapTypeIntoContext(type);
743710

744-
return Type(findElementArchetype(ty));
745-
});
711+
contextualType = mapContextualPackTypeIntoElementContext(contextualType);
712+
return maybeApplyOuterContextSubstitutions(contextualType);
746713
}
747714

748715
Type

lib/AST/ParameterPack.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,64 @@
2525

2626
using namespace swift;
2727

28+
/// FV(PackExpansionType(Pattern, Count), N) = FV(Pattern, N+1)
29+
/// FV(PackElementType(Param, M), N) = FV(Param, 0) if M >= N, {} otherwise
30+
/// FV(Param, N) = {Param}
31+
static Type transformTypeParameterPacksRec(
32+
Type t, llvm::function_ref<llvm::Optional<Type>(SubstitutableType *)> fn,
33+
unsigned expansionLevel) {
34+
return t.transformWithPosition(
35+
TypePosition::Invariant,
36+
[&](TypeBase *t, TypePosition p) -> llvm::Optional<Type> {
37+
38+
// If we're already inside N levels of PackExpansionType, and we're
39+
// walking into another PackExpansionType, a type parameter pack
40+
// reference now needs level (N+1) to be free.
41+
if (auto *expansionType = dyn_cast<PackExpansionType>(t)) {
42+
auto countType = expansionType->getCountType();
43+
auto patternType = expansionType->getPatternType();
44+
auto newPatternType = transformTypeParameterPacksRec(
45+
patternType, fn, expansionLevel + 1);
46+
if (patternType.getPointer() != newPatternType.getPointer())
47+
return Type(PackExpansionType::get(patternType, countType));
48+
49+
return Type(expansionType);
50+
}
51+
52+
// A PackElementType with level N reaches past N levels of
53+
// nested PackExpansionType. So a type parameter pack reference
54+
// therein is free if N is greater than or equal to our current
55+
// expansion level.
56+
if (auto *eltType = dyn_cast<PackElementType>(t)) {
57+
if (eltType->getLevel() >= expansionLevel) {
58+
return transformTypeParameterPacksRec(eltType->getPackType(), fn,
59+
/*expansionLevel=*/0);
60+
}
61+
62+
return Type(eltType);
63+
}
64+
65+
// A bare type parameter pack is like a PackElementType with level 0.
66+
if (auto *paramType = dyn_cast<SubstitutableType>(t)) {
67+
if (expansionLevel == 0 &&
68+
(isa<PackArchetypeType>(paramType) ||
69+
(isa<GenericTypeParamType>(paramType) &&
70+
cast<GenericTypeParamType>(paramType)->isParameterPack()))) {
71+
return fn(paramType);
72+
}
73+
74+
return Type(paramType);
75+
}
76+
77+
return llvm::None;
78+
});
79+
}
80+
81+
Type Type::transformTypeParameterPacks(
82+
llvm::function_ref<llvm::Optional<Type>(SubstitutableType *)> fn) const {
83+
return transformTypeParameterPacksRec(*this, fn, /*expansionLevel=*/0);
84+
}
85+
2886
namespace {
2987

3088
/// Collects all unique pack type parameters referenced from the pattern type,

lib/ClangImporter/ImportName.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2247,7 +2247,6 @@ ImportedName NameImporter::importNameImpl(const clang::NamedDecl *D,
22472247
if (arg.getKind() == clang::TemplateArgument::Type) {
22482248
auto ty = arg.getAsType().getTypePtr();
22492249
if (auto builtin = dyn_cast<clang::BuiltinType>(ty)) {
2250-
auto &ctx = swiftCtx;
22512250
if (auto swiftTypeName = getSwiftBuiltinTypeName(builtin)) {
22522251
buffer << *swiftTypeName;
22532252
return;

lib/IRGen/GenPack.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -330,8 +330,8 @@ static void bindElementSignatureRequirementsAtIndex(
330330
IGF.emitTypeMetadataRef(patternPackArchetype, request);
331331
auto elementArchetype =
332332
context.environment
333-
->mapPackTypeIntoElementContext(
334-
patternPackArchetype->getInterfaceType())
333+
->mapContextualPackTypeIntoElementContext(
334+
patternPackArchetype)
335335
->getCanonicalType();
336336
auto *patternPack = response.getMetadata();
337337
auto elementMetadata = bindMetadataAtIndex(
@@ -346,8 +346,8 @@ static void bindElementSignatureRequirementsAtIndex(
346346
auto patternPackArchetype = getMappedPackArchetypeType(context, ty);
347347
auto elementArchetype =
348348
context.environment
349-
->mapPackTypeIntoElementContext(
350-
patternPackArchetype->getInterfaceType())
349+
->mapContextualPackTypeIntoElementContext(
350+
patternPackArchetype)
351351
->getCanonicalType();
352352
llvm::Value *_metadata = nullptr;
353353
auto packConformance =

lib/Sema/CSSimplify.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9162,7 +9162,6 @@ ConstraintSystem::simplifyPackElementOfConstraint(Type first, Type second,
91629162
if (!patternType->hasTypeVariable()) {
91639163
auto *loc = getConstraintLocator(locator);
91649164
auto shapeClass = patternType->getReducedShape();
9165-
patternType = patternType->mapTypeOutOfContext();
91669165
auto *elementEnv = getPackElementEnvironment(loc, shapeClass);
91679166

91689167
// Without an opened element environment, we cannot derive the
@@ -9185,7 +9184,7 @@ ConstraintSystem::simplifyPackElementOfConstraint(Type first, Type second,
91859184
}
91869185

91879186
auto expectedElementTy =
9188-
elementEnv->mapPackTypeIntoElementContext(patternType);
9187+
elementEnv->mapContextualPackTypeIntoElementContext(patternType);
91899188
assert(!expectedElementTy->is<PackType>());
91909189

91919190
addConstraint(ConstraintKind::Equal, elementType, expectedElementTy,

test/ModuleInterface/pack_expansion_type.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
// RUN: %target-swift-emit-module-interface(%t/PackExpansionType.swiftinterface) %s -module-name PackExpansionType -disable-availability-checking
33
// RUN: %FileCheck %s < %t/PackExpansionType.swiftinterface
44

5+
/// Requirements
6+
57
// CHECK: #if compiler(>=5.3) && $ParameterPacks
68
// CHECK-NEXT: public func variadicFunction<each T, each U>(t: repeat each T, u: repeat each U) -> (repeat (each T, each U)) where (repeat (each T, each U)) : Any
79
public func variadicFunction<each T, each U>(t: repeat each T, u: repeat each U) -> (repeat (each T, each U)) {}
@@ -24,6 +26,13 @@ public struct VariadicType<each T> {
2426
// CHECK: }
2527
// CHECK-NEXT: #endif
2628

29+
// The second requirement should not be prefixed with 'repeat'
30+
// CHECK: public struct SameTypeReq<T, each U> where T : Swift.Sequence, T.Element == PackExpansionType.VariadicType<repeat each U> {
31+
public struct SameTypeReq<T: Sequence, each U> where T.Element == VariadicType<repeat each U> {}
32+
// CHECK: }
33+
34+
/// Pack expansion types
35+
2736
// CHECK: public func returnsVariadicType() -> PackExpansionType.VariadicType<>
2837
public func returnsVariadicType() -> VariadicType< > {}
2938

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// RUN: %target-swift-frontend -emit-ir %s -disable-availability-checking
2+
3+
public protocol P {
4+
associatedtype A
5+
}
6+
7+
public struct G<each T> {}
8+
9+
public struct S<T: P, each U: P> where repeat T.A == G<repeat each U> {
10+
public init(predicate: repeat each U) {}
11+
}
12+

0 commit comments

Comments
 (0)