Skip to content

Commit 5a9e43e

Browse files
committed
AST: Re-implement transformWithPosition()'s handling of pack expansions with new utilities
1 parent ee8f45c commit 5a9e43e

File tree

2 files changed

+15
-109
lines changed

2 files changed

+15
-109
lines changed

lib/AST/Type.cpp

Lines changed: 12 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -5528,119 +5528,33 @@ case TypeKind::Id:
55285528
anyChanged = true;
55295529
}
55305530

5531-
if (auto *transformedPack = transformedEltTy->getAs<PackType>()) {
5532-
elements.append(transformedPack->getElementTypes().begin(),
5533-
transformedPack->getElementTypes().end());
5534-
} else {
5535-
elements.push_back(transformedEltTy);
5536-
}
5531+
elements.push_back(transformedEltTy);
55375532
}
55385533

55395534
if (!anyChanged)
55405535
return *this;
55415536

5542-
return PackType::get(Ptr->getASTContext(), elements);
5537+
return PackType::get(Ptr->getASTContext(), elements)->flattenPackTypes();
55435538
}
55445539

55455540
case TypeKind::PackExpansion: {
55465541
auto expand = cast<PackExpansionType>(base);
5547-
struct ExpansionGatherer {
5548-
llvm::function_ref<Optional<Type>(TypeBase *, TypePosition)> baselineFn;
5549-
llvm::DenseMap<TypeBase *, PackType *> cache;
5550-
unsigned maxArity;
5551-
5552-
public:
5553-
ExpansionGatherer(
5554-
llvm::function_ref<Optional<Type>(TypeBase *, TypePosition)>
5555-
baselineFn)
5556-
: baselineFn(baselineFn), maxArity(0) {}
5557-
5558-
Optional<Type> operator()(TypeBase *input, TypePosition pos) {
5559-
auto remap = baselineFn(input, pos);
5560-
if (!remap) {
5561-
return remap;
5562-
}
5563-
5564-
if (input->is<TypeVariableType>() ||
5565-
input->isParameterPack() ||
5566-
input->is<PackArchetypeType>()) {
5567-
if (auto *PT = (*remap)->getAs<PackType>()) {
5568-
maxArity = std::max(maxArity, PT->getNumElements());
5569-
cache.insert({input, PT});
5570-
}
5571-
}
5572-
return remap;
5573-
}
55745542

5575-
std::pair<llvm::DenseMap<TypeBase *, PackType *>, unsigned>
5576-
intoExpansions() && {
5577-
return std::make_pair(cache, maxArity);
5578-
}
5579-
};
5580-
5581-
// First, substitute down the pattern type to gather the mapping from
5582-
// contained substitutable types to packs.
5583-
auto gather = ExpansionGatherer{fn};
55845543
Type transformedPat =
5585-
expand->getPatternType().transformWithPosition(pos, gather);
5544+
expand->getPatternType().transformWithPosition(pos, fn);
55865545
if (!transformedPat)
55875546
return Type();
55885547

55895548
Type transformedCount =
5590-
expand->getCountType().transformWithPosition(pos, gather);
5549+
expand->getCountType().transformWithPosition(pos, fn);
55915550
if (!transformedCount)
55925551
return Type();
55935552

55945553
if (transformedPat.getPointer() == expand->getPatternType().getPointer() &&
55955554
transformedCount.getPointer() == expand->getCountType().getPointer())
55965555
return *this;
55975556

5598-
llvm::DenseMap<TypeBase *, PackType *> expansions;
5599-
unsigned arity;
5600-
std::tie(expansions, arity) = std::move(gather).intoExpansions();
5601-
if (expansions.empty()) {
5602-
// If we didn't find any expansions, either the caller wasn't interested
5603-
// in expanding this pack, or something has gone wrong. Leave off the
5604-
// expansion and return the transformed type.
5605-
return PackExpansionType::get(transformedPat, transformedCount);
5606-
}
5607-
5608-
SmallVector<Type, 8> elts;
5609-
elts.reserve(arity);
5610-
// Perform the expansion element-wise according to the maximum arity we
5611-
// picked up during the gather step above.
5612-
//
5613-
// For a pack expansion (F<... T..., U..., ...>) and mapping
5614-
//
5615-
// T... -> <X, Y, Z>
5616-
// U... -> <A, B, C>
5617-
//
5618-
// The expected expansion is
5619-
//
5620-
// <F<... X, A, ...>, F<... Y, B, ...>, F<... Z, C, ...> ...>
5621-
for (unsigned i = 0; i < arity; ++i) {
5622-
struct ElementExpander {
5623-
const llvm::DenseMap<TypeBase *, PackType *> &expansions;
5624-
llvm::function_ref<Optional<Type>(TypeBase *, TypePosition)> outerFn;
5625-
unsigned index;
5626-
5627-
public:
5628-
Optional<Type> operator()(TypeBase *input, TypePosition pos) {
5629-
// FIXME: Does this need to do bounds checking?
5630-
if (PackType *element = expansions.lookup(input))
5631-
return element->getElementType(index);
5632-
return outerFn(input, pos);
5633-
}
5634-
};
5635-
5636-
auto expandedElt = expand->getPatternType().transformWithPosition(
5637-
pos, ElementExpander{expansions, fn, i});
5638-
if (!expandedElt)
5639-
return Type();
5640-
5641-
elts.push_back(expandedElt);
5642-
}
5643-
return PackType::get(base->getASTContext(), elts);
5557+
return PackExpansionType::get(transformedPat, transformedCount)->expand();
56445558
}
56455559

56465560
case TypeKind::Tuple: {
@@ -5670,26 +5584,14 @@ case TypeKind::Id:
56705584
anyChanged = true;
56715585
}
56725586

5673-
// "Splat" the elements of the transformed pack expansion into the tuple.
5674-
if (auto *transformedPack = transformedEltTy->getAs<PackType>()) {
5675-
auto transformedEltTypes = transformedPack->getElementTypes();
5676-
if (!transformedEltTypes.empty()) {
5677-
// Keep the label on the first element.
5678-
5679-
elements.push_back(elt.getWithType(transformedEltTypes.front()));
5680-
elements.append(transformedEltTypes.begin() + 1,
5681-
transformedEltTypes.end());
5682-
}
5683-
} else {
5684-
// Add the new tuple element, with the transformed type.
5685-
elements.push_back(elt.getWithType(transformedEltTy));
5686-
}
5587+
// Add the new tuple element, with the transformed type.
5588+
elements.push_back(elt.getWithType(transformedEltTy));
56875589
}
56885590

56895591
if (!anyChanged)
56905592
return *this;
56915593

5692-
return TupleType::get(elements, Ptr->getASTContext());
5594+
return TupleType::get(elements, Ptr->getASTContext())->flattenPackTypes();
56935595
}
56945596

56955597

@@ -5785,7 +5687,8 @@ case TypeKind::Id:
57855687
return GenericFunctionType::get(genericSig, substParams, resultTy);
57865688
return GenericFunctionType::get(genericSig, substParams, resultTy,
57875689
function->getExtInfo()
5788-
.withGlobalActor(globalActorType));
5690+
.withGlobalActor(globalActorType))
5691+
->flattenPackTypes();
57895692
}
57905693

57915694
if (isUnchanged) return *this;
@@ -5794,7 +5697,8 @@ case TypeKind::Id:
57945697
return FunctionType::get(substParams, resultTy);
57955698
return FunctionType::get(substParams, resultTy,
57965699
function->getExtInfo()
5797-
.withGlobalActor(globalActorType));
5700+
.withGlobalActor(globalActorType))
5701+
->flattenPackTypes();
57985702
}
57995703

58005704
case TypeKind::ArraySlice: {

lib/Sema/CSApply.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5865,8 +5865,10 @@ ArgumentList *ExprRewriter::coerceCallArguments(
58655865
if (!varargIndices.empty())
58665866
labelLoc = args->getLabelLoc(varargIndices[0]);
58675867

5868+
auto packExpansionType = param.getPlainType()->castTo<PackExpansionType>();
5869+
58685870
// Convert the arguments.
5869-
auto paramTuple = param.getPlainType()->castTo<PackType>();
5871+
auto paramTuple = packExpansionType->getPatternType()->castTo<PackType>();
58705872
for (auto varargIdx : indices(varargIndices)) {
58715873
auto argIdx = varargIndices[varargIdx];
58725874
auto *arg = args->getExpr(argIdx);

0 commit comments

Comments
 (0)