Skip to content

Commit f6870b7

Browse files
authored
Merge pull request #67310 from slavapestov/transform-type-parameter-packs-5.9
Sema: Fix ASTPrinter and opened element environments to handle nested pack expansion types [5.9]
2 parents d12674a + 5452ece commit f6870b7

File tree

7 files changed

+117
-11
lines changed

7 files changed

+117
-11
lines changed

include/swift/AST/Type.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,14 @@ class Type {
316316
llvm::function_ref<llvm::Optional<Type>(TypeBase *, TypePosition)> fn)
317317
const;
318318

319+
/// Transform free pack element references, that is, those not captured by a
320+
/// pack expansion.
321+
///
322+
/// This is the 'map' counterpart to TypeBase::getTypeParameterPacks().
323+
Type transformTypeParameterPacks(
324+
llvm::function_ref<llvm::Optional<Type>(SubstitutableType *)> fn)
325+
const;
326+
319327
/// Look through the given type and its children and apply fn to them.
320328
void visit(llvm::function_ref<void (Type)> fn) const {
321329
findIf([&fn](Type t) -> bool {

lib/AST/ASTContext.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5856,13 +5856,13 @@ ASTContext::getOpenedElementSignature(CanGenericSignature baseGenericSig,
58565856
}
58575857

58585858
auto eraseParameterPackRec = [&](Type type) -> Type {
5859-
return type.transformRec([&](Type t) -> llvm::Optional<Type> {
5860-
if (auto *paramType = t->getAs<GenericTypeParamType>()) {
5859+
return type.transformTypeParameterPacks([&](SubstitutableType *t) -> llvm::Optional<Type> {
5860+
if (auto *paramType = dyn_cast<GenericTypeParamType>(t)) {
58615861
if (packElementParams.find(paramType) != packElementParams.end()) {
58625862
return Type(packElementParams[paramType]);
58635863
}
58645864

5865-
return t;
5865+
return Type(t);
58665866
}
58675867
return llvm::None;
58685868
});

lib/AST/ASTPrinter.cpp

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -804,7 +804,7 @@ class PrintAST : public ASTVisitor<PrintAST> {
804804

805805
void printType(Type T) { printTypeWithOptions(T, Options); }
806806

807-
void printTransformedTypeWithOptions(Type T, PrintOptions options) {
807+
Type getTransformedType(Type T) {
808808
if (CurrentType && Current && CurrentType->mayHaveMembers()) {
809809
auto *M = Current->getDeclContext()->getParentModule();
810810
SubstitutionMap subMap;
@@ -825,9 +825,16 @@ class PrintAST : public ASTVisitor<PrintAST> {
825825
}
826826

827827
T = T.subst(subMap, SubstFlags::DesugarMemberTypes);
828+
}
829+
830+
return T;
831+
}
828832

833+
void printTransformedTypeWithOptions(Type T, PrintOptions options) {
834+
T = getTransformedType(T);
835+
836+
if (CurrentType && Current && CurrentType->mayHaveMembers())
829837
options.TransformContext = TypeTransformContext(CurrentType);
830-
}
831838

832839
printTypeWithOptions(T, options);
833840
}
@@ -1833,6 +1840,11 @@ void PrintAST::printSingleDepthOfGenericSignature(
18331840
}
18341841

18351842
void PrintAST::printRequirement(const Requirement &req) {
1843+
SmallVector<Type, 2> rootParameterPacks;
1844+
getTransformedType(req.getFirstType())
1845+
->getTypeParameterPacks(rootParameterPacks);
1846+
bool isPackRequirement = !rootParameterPacks.empty();
1847+
18361848
switch (req.getKind()) {
18371849
case RequirementKind::SameShape:
18381850
Printer << "(repeat (";
@@ -1842,22 +1854,21 @@ void PrintAST::printRequirement(const Requirement &req) {
18421854
Printer << ")) : Any";
18431855
return;
18441856
case RequirementKind::Layout:
1845-
if (req.getFirstType()->hasParameterPack())
1857+
if (isPackRequirement)
18461858
Printer << "repeat ";
18471859
printTransformedType(req.getFirstType());
18481860
Printer << " : ";
18491861
req.getLayoutConstraint()->print(Printer, Options);
18501862
return;
18511863
case RequirementKind::Conformance:
18521864
case RequirementKind::Superclass:
1853-
if (req.getFirstType()->hasParameterPack())
1865+
if (isPackRequirement)
18541866
Printer << "repeat ";
18551867
printTransformedType(req.getFirstType());
18561868
Printer << " : ";
18571869
break;
18581870
case RequirementKind::SameType:
1859-
if (req.getFirstType()->hasParameterPack() ||
1860-
req.getSecondType()->hasParameterPack())
1871+
if (isPackRequirement)
18611872
Printer << "repeat ";
18621873
printTransformedType(req.getFirstType());
18631874
Printer << " == ";

lib/AST/ParameterPack.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,64 @@
2525

2626
using namespace swift;
2727

28+
/// FV(PackExpansionType(Pattern, Count), N) = FV(Pattern, N+1)
29+
/// FV(PackElementType(Param, M), N) = FV(Param, 0) if M >= N, {} otherwise
30+
/// FV(Param, N) = {Param}
31+
static Type transformTypeParameterPacksRec(
32+
Type t, llvm::function_ref<llvm::Optional<Type>(SubstitutableType *)> fn,
33+
unsigned expansionLevel) {
34+
return t.transformWithPosition(
35+
TypePosition::Invariant,
36+
[&](TypeBase *t, TypePosition p) -> llvm::Optional<Type> {
37+
38+
// If we're already inside N levels of PackExpansionType, and we're
39+
// walking into another PackExpansionType, a type parameter pack
40+
// reference now needs level (N+1) to be free.
41+
if (auto *expansionType = dyn_cast<PackExpansionType>(t)) {
42+
auto countType = expansionType->getCountType();
43+
auto patternType = expansionType->getPatternType();
44+
auto newPatternType = transformTypeParameterPacksRec(
45+
patternType, fn, expansionLevel + 1);
46+
if (patternType.getPointer() != newPatternType.getPointer())
47+
return Type(PackExpansionType::get(patternType, countType));
48+
49+
return Type(expansionType);
50+
}
51+
52+
// A PackElementType with level N reaches past N levels of
53+
// nested PackExpansionType. So a type parameter pack reference
54+
// therein is free if N is greater than or equal to our current
55+
// expansion level.
56+
if (auto *eltType = dyn_cast<PackElementType>(t)) {
57+
if (eltType->getLevel() >= expansionLevel) {
58+
return transformTypeParameterPacksRec(eltType->getPackType(), fn,
59+
/*expansionLevel=*/0);
60+
}
61+
62+
return Type(eltType);
63+
}
64+
65+
// A bare type parameter pack is like a PackElementType with level 0.
66+
if (auto *paramType = dyn_cast<SubstitutableType>(t)) {
67+
if (expansionLevel == 0 &&
68+
(isa<PackArchetypeType>(paramType) ||
69+
(isa<GenericTypeParamType>(paramType) &&
70+
cast<GenericTypeParamType>(paramType)->isParameterPack()))) {
71+
return fn(paramType);
72+
}
73+
74+
return Type(paramType);
75+
}
76+
77+
return llvm::None;
78+
});
79+
}
80+
81+
Type Type::transformTypeParameterPacks(
82+
llvm::function_ref<llvm::Optional<Type>(SubstitutableType *)> fn) const {
83+
return transformTypeParameterPacksRec(*this, fn, /*expansionLevel=*/0);
84+
}
85+
2886
namespace {
2987

3088
/// Collects all unique pack type parameters referenced from the pattern type,

lib/Sema/ConstraintSystem.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,8 +1024,11 @@ Type ConstraintSystem::openType(Type type, OpenedTypeMap &replacements,
10241024
// that gets introduced by the interface type, see
10251025
// \c openUnboundGenericType for more details.
10261026
if (auto *packTy = type->getAs<PackType>()) {
1027-
if (auto expansion = packTy->unwrapSingletonPackExpansion())
1028-
type = expansion->getPatternType();
1027+
if (auto expansionTy = packTy->unwrapSingletonPackExpansion()) {
1028+
auto patternTy = expansionTy->getPatternType();
1029+
if (patternTy->isTypeParameter())
1030+
return openType(patternTy, replacements, locator);
1031+
}
10291032
}
10301033

10311034
if (auto *expansion = type->getAs<PackExpansionType>()) {

test/ModuleInterface/pack_expansion_type.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
// RUN: %target-swift-emit-module-interface(%t/PackExpansionType.swiftinterface) %s -module-name PackExpansionType -disable-availability-checking
33
// RUN: %FileCheck %s < %t/PackExpansionType.swiftinterface
44

5+
/// Requirements
6+
57
// CHECK: #if compiler(>=5.3) && $ParameterPacks
68
// CHECK-NEXT: public func variadicFunction<each T, each U>(t: repeat each T, u: repeat each U) -> (repeat (each T, each U)) where (repeat (each T, each U)) : Any
79
public func variadicFunction<each T, each U>(t: repeat each T, u: repeat each U) -> (repeat (each T, each U)) {}
@@ -24,6 +26,13 @@ public struct VariadicType<each T> {
2426
// CHECK: }
2527
// CHECK-NEXT: #endif
2628

29+
// The second requirement should not be prefixed with 'repeat'
30+
// CHECK: public struct SameTypeReq<T, each U> where T : Swift.Sequence, T.Element == PackExpansionType.VariadicType<repeat each U> {
31+
public struct SameTypeReq<T: Sequence, each U> where T.Element == VariadicType<repeat each U> {}
32+
// CHECK: }
33+
34+
/// Pack expansion types
35+
2736
// CHECK: public func returnsVariadicType() -> PackExpansionType.VariadicType<>
2837
public func returnsVariadicType() -> VariadicType< > {}
2938

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// RUN: %target-swift-frontend -emit-ir %s -disable-availability-checking
2+
3+
public protocol P {
4+
associatedtype A
5+
}
6+
7+
public struct G<each T> {}
8+
9+
public protocol Q {
10+
init()
11+
}
12+
13+
public struct S<T: P, each U: P>: Q where repeat T.A == G<repeat each U.A> {
14+
public init() {}
15+
public init(predicate: repeat each U) {}
16+
}
17+

0 commit comments

Comments
 (0)