Skip to content

Commit b834493

Browse files
committed
AST: Implement Type::transformTypeParameterPacks()
1 parent 4723e55 commit b834493

File tree

2 files changed

+66
-0
lines changed

2 files changed

+66
-0
lines changed

include/swift/AST/Type.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,14 @@ class Type {
316316
llvm::function_ref<llvm::Optional<Type>(TypeBase *, TypePosition)> fn)
317317
const;
318318

319+
/// Transform free pack element references, that is, those not captured by a
320+
/// pack expansion.
321+
///
322+
/// This is the 'map' counterpart to TypeBase::getTypeParameterPacks().
323+
Type transformTypeParameterPacks(
324+
llvm::function_ref<llvm::Optional<Type>(SubstitutableType *)> fn)
325+
const;
326+
319327
/// Look through the given type and its children and apply fn to them.
320328
void visit(llvm::function_ref<void (Type)> fn) const {
321329
findIf([&fn](Type t) -> bool {

lib/AST/ParameterPack.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,64 @@
2525

2626
using namespace swift;
2727

28+
/// FV(PackExpansionType(Pattern, Count), N) = FV(Pattern, N+1)
29+
/// FV(PackElementType(Param, M), N) = FV(Param, 0) if M >= N, {} otherwise
30+
/// FV(Param, N) = {Param}
31+
static Type transformTypeParameterPacksRec(
32+
Type t, llvm::function_ref<llvm::Optional<Type>(SubstitutableType *)> fn,
33+
unsigned expansionLevel) {
34+
return t.transformWithPosition(
35+
TypePosition::Invariant,
36+
[&](TypeBase *t, TypePosition p) -> llvm::Optional<Type> {
37+
38+
// If we're already inside N levels of PackExpansionType, and we're
39+
// walking into another PackExpansionType, a type parameter pack
40+
// reference now needs level (N+1) to be free.
41+
if (auto *expansionType = dyn_cast<PackExpansionType>(t)) {
42+
auto countType = expansionType->getCountType();
43+
auto patternType = expansionType->getPatternType();
44+
auto newPatternType = transformTypeParameterPacksRec(
45+
patternType, fn, expansionLevel + 1);
46+
if (patternType.getPointer() != newPatternType.getPointer())
47+
return Type(PackExpansionType::get(patternType, countType));
48+
49+
return Type(expansionType);
50+
}
51+
52+
// A PackElementType with level N reaches past N levels of
53+
// nested PackExpansionType. So a type parameter pack reference
54+
// therein is free if N is greater than or equal to our current
55+
// expansion level.
56+
if (auto *eltType = dyn_cast<PackElementType>(t)) {
57+
if (eltType->getLevel() >= expansionLevel) {
58+
return transformTypeParameterPacksRec(eltType->getPackType(), fn,
59+
/*expansionLevel=*/0);
60+
}
61+
62+
return Type(eltType);
63+
}
64+
65+
// A bare type parameter pack is like a PackElementType with level 0.
66+
if (auto *paramType = dyn_cast<SubstitutableType>(t)) {
67+
if (expansionLevel == 0 &&
68+
(isa<PackArchetypeType>(paramType) ||
69+
(isa<GenericTypeParamType>(paramType) &&
70+
cast<GenericTypeParamType>(paramType)->isParameterPack()))) {
71+
return fn(paramType);
72+
}
73+
74+
return Type(paramType);
75+
}
76+
77+
return llvm::None;
78+
});
79+
}
80+
81+
Type Type::transformTypeParameterPacks(
82+
llvm::function_ref<llvm::Optional<Type>(SubstitutableType *)> fn) const {
83+
return transformTypeParameterPacksRec(*this, fn, /*expansionLevel=*/0);
84+
}
85+
2886
namespace {
2987

3088
/// Collects all unique pack type parameters referenced from the pattern type,

0 commit comments

Comments
 (0)