Skip to content

Commit d7123c7

Browse files
authored
Merge pull request #64135 from rjmccall/variadic-generic-callee-results
Implement the callee side of returning a tuple containing a pack expansion
2 parents 9c22762 + 157be34 commit d7123c7

20 files changed

+979
-74
lines changed

include/swift/AST/Types.h

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2451,6 +2451,20 @@ BEGIN_CAN_TYPE_WRAPPER(TupleType, Type)
24512451
CanTupleEltTypeArrayRef getElementTypes() const {
24522452
return CanTupleEltTypeArrayRef(getPointer()->getElements());
24532453
}
2454+
2455+
bool containsPackExpansionType() const {
2456+
return containsPackExpansionTypeImpl(*this);
2457+
}
2458+
2459+
/// Induce a pack type from a range of the elements of this tuple type.
2460+
inline CanTypeWrapper<PackType>
2461+
getInducedPackType(unsigned start, unsigned count) const;
2462+
2463+
private:
2464+
static bool containsPackExpansionTypeImpl(CanTupleType tuple);
2465+
2466+
static CanTypeWrapper<PackType>
2467+
getInducedPackTypeImpl(CanTupleType tuple, unsigned start, unsigned count);
24542468
END_CAN_TYPE_WRAPPER(TupleType, Type)
24552469

24562470
/// UnboundGenericType - Represents a generic type where the type arguments have
@@ -4233,6 +4247,11 @@ class SILResultInfo {
42334247
return !isIndirectFormalResult(getConvention());
42344248
}
42354249

4250+
/// Is this a pack result? Pack results are always indirect.
4251+
bool isPack() const {
4252+
return getConvention() == ResultConvention::Pack;
4253+
}
4254+
42364255
/// Transform this SILResultInfo by applying the user-provided
42374256
/// function to its type.
42384257
///
@@ -4408,8 +4427,9 @@ class SILFunctionType final
44084427

44094428
// These are *normal* results if this is not a coroutine and *yield* results
44104429
// otherwise.
4411-
unsigned NumAnyResults : 16; // Not including the ErrorResult.
4412-
unsigned NumAnyIndirectFormalResults : 16; // Subset of NumAnyResults.
4430+
unsigned NumAnyResults; // Not including the ErrorResult.
4431+
unsigned NumAnyIndirectFormalResults; // Subset of NumAnyResults.
4432+
unsigned NumPackResults; // Subset of NumAnyIndirectFormalResults.
44134433

44144434
// [NOTE: SILFunctionType-layout]
44154435
// The layout of a SILFunctionType in memory is:
@@ -4589,6 +4609,9 @@ class SILFunctionType final
45894609
unsigned getNumDirectFormalResults() const {
45904610
return isCoroutine() ? 0 : NumAnyResults - NumAnyIndirectFormalResults;
45914611
}
4612+
unsigned getNumPackResults() const {
4613+
return isCoroutine() ? 0 : NumPackResults;
4614+
}
45924615

45934616
struct IndirectFormalResultFilter {
45944617
bool operator()(SILResultInfo result) const {
@@ -4618,6 +4641,21 @@ class SILFunctionType final
46184641
return llvm::make_filter_range(getResults(), DirectFormalResultFilter());
46194642
}
46204643

4644+
struct PackResultFilter {
4645+
bool operator()(SILResultInfo result) const {
4646+
return result.isPack();
4647+
}
4648+
};
4649+
using PackResultIter =
4650+
llvm::filter_iterator<const SILResultInfo *, PackResultFilter>;
4651+
using PackResultRange = iterator_range<PackResultIter>;
4652+
4653+
/// A range of SILResultInfo for all pack results. Pack results are also
4654+
/// included in the set of indirect results.
4655+
PackResultRange getPackResults() const {
4656+
return llvm::make_filter_range(getResults(), PackResultFilter());
4657+
}
4658+
46214659
/// Get a single non-address SILType that represents all formal direct
46224660
/// results. The actual SIL result type of an apply instruction that calls
46234661
/// this function depends on the current SIL stage and is known by
@@ -4636,6 +4674,9 @@ class SILFunctionType final
46364674
unsigned getNumDirectFormalYields() const {
46374675
return isCoroutine() ? NumAnyResults - NumAnyIndirectFormalResults : 0;
46384676
}
4677+
unsigned getNumPackYields() const {
4678+
return isCoroutine() ? NumPackResults : 0;
4679+
}
46394680

46404681
struct IndirectFormalYieldFilter {
46414682
bool operator()(SILYieldInfo yield) const {
@@ -6798,6 +6839,11 @@ BEGIN_CAN_TYPE_WRAPPER(PackType, Type)
67986839
}
67996840
END_CAN_TYPE_WRAPPER(PackType, Type)
68006841

6842+
inline CanPackType
6843+
CanTupleType::getInducedPackType(unsigned start, unsigned end) const {
6844+
return getInducedPackTypeImpl(*this, start, end);
6845+
}
6846+
68016847
/// PackExpansionType - The interface type of the explicit expansion of a
68026848
/// corresponding set of variadic generic parameters.
68036849
///

include/swift/SIL/SILFunction.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,6 +725,10 @@ class SILFunction
725725

726726
SILType getLoweredType(Type t) const;
727727

728+
CanType getLoweredRValueType(Lowering::AbstractionPattern orig, Type subst) const;
729+
730+
CanType getLoweredRValueType(Type t) const;
731+
728732
SILType getLoweredLoadableType(Type t) const;
729733

730734
SILType getLoweredType(SILType t) const;

include/swift/SIL/SILFunctionConventions.h

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -214,30 +214,38 @@ class SILFunctionConventions {
214214

215215
/// Get the number of SIL results passed as address-typed arguments.
216216
unsigned getNumIndirectSILResults() const {
217-
return silConv.loweredAddresses ? funcTy->getNumIndirectFormalResults() : 0;
217+
// TODO: Return packs directly in lowered-address mode
218+
return silConv.loweredAddresses ? funcTy->getNumIndirectFormalResults()
219+
: funcTy->getNumPackResults();
218220
}
219221

220222
/// Are any SIL results passed as address-typed arguments?
221223
bool hasIndirectSILResults() const { return getNumIndirectSILResults() != 0; }
222224

223-
using IndirectSILResultIter = SILFunctionType::IndirectFormalResultIter;
224-
using IndirectSILResultRange = SILFunctionType::IndirectFormalResultRange;
225+
struct IndirectSILResultFilter {
226+
bool loweredAddresses;
227+
IndirectSILResultFilter(bool loweredAddresses)
228+
: loweredAddresses(loweredAddresses) {}
229+
bool operator()(SILResultInfo result) const {
230+
return (loweredAddresses ? result.isFormalIndirect() : result.isPack());
231+
}
232+
};
233+
using IndirectSILResultIter =
234+
llvm::filter_iterator<const SILResultInfo *, IndirectSILResultFilter>;
235+
using IndirectSILResultRange = iterator_range<IndirectSILResultIter>;
225236

226237
/// Return a range of indirect result information for results passed as
227238
/// address-typed SIL arguments.
228239
IndirectSILResultRange getIndirectSILResults() const {
229-
if (silConv.loweredAddresses)
230-
return funcTy->getIndirectFormalResults();
231-
232240
return llvm::make_filter_range(
233-
llvm::make_range((const SILResultInfo *)0, (const SILResultInfo *)0),
234-
SILFunctionType::IndirectFormalResultFilter());
241+
funcTy->getResults(),
242+
IndirectSILResultFilter(silConv.loweredAddresses));
235243
}
236244

237245
struct SILResultTypeFunc;
238246

239247
// Gratuitous template parameter is to delay instantiating `mapped_iterator`
240-
// on the incomplete type SILParameterTypeFunc.
248+
// on the incomplete type SILResultTypeFunc.
241249
template<bool _ = false>
242250
using IndirectSILResultTypeIter = typename delay_template_expansion<_,
243251
llvm::mapped_iterator, IndirectSILResultIter, SILResultTypeFunc>::type;
@@ -253,7 +261,7 @@ class SILFunctionConventions {
253261
/// Get the number of SIL results directly returned by SIL value.
254262
unsigned getNumDirectSILResults() const {
255263
return silConv.loweredAddresses ? funcTy->getNumDirectFormalResults()
256-
: funcTy->getNumResults();
264+
: funcTy->getNumResults() - funcTy->getNumPackResults();
257265
}
258266

259267
/// Like getNumDirectSILResults but @out tuples, which are not flattened in
@@ -266,7 +274,7 @@ class SILFunctionConventions {
266274
DirectSILResultFilter(bool loweredAddresses)
267275
: loweredAddresses(loweredAddresses) {}
268276
bool operator()(SILResultInfo result) const {
269-
return !(loweredAddresses && result.isFormalIndirect());
277+
return (loweredAddresses ? !result.isFormalIndirect() : !result.isPack());
270278
}
271279
};
272280
using DirectSILResultIter =

lib/AST/ASTContext.cpp

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4343,20 +4343,27 @@ SILFunctionType::SILFunctionType(
43434343
if (coroutineKind == SILCoroutineKind::None) {
43444344
assert(yields.empty());
43454345
NumAnyResults = normalResults.size();
4346-
NumAnyIndirectFormalResults =
4347-
std::count_if(normalResults.begin(), normalResults.end(),
4348-
[](const SILResultInfo &resultInfo) {
4349-
return resultInfo.isFormalIndirect();
4350-
});
4346+
NumAnyIndirectFormalResults = 0;
4347+
NumPackResults = 0;
4348+
for (auto &resultInfo : normalResults) {
4349+
if (resultInfo.isFormalIndirect())
4350+
NumAnyIndirectFormalResults++;
4351+
if (resultInfo.isPack())
4352+
NumPackResults++;
4353+
}
43514354
memcpy(getMutableResults().data(), normalResults.data(),
43524355
normalResults.size() * sizeof(SILResultInfo));
43534356
} else {
4354-
assert(normalResults.empty());
4357+
assert(normalResults.empty());
43554358
NumAnyResults = yields.size();
4356-
NumAnyIndirectFormalResults = std::count_if(
4357-
yields.begin(), yields.end(), [](const SILYieldInfo &yieldInfo) {
4358-
return yieldInfo.isFormalIndirect();
4359-
});
4359+
NumAnyIndirectFormalResults = 0;
4360+
NumPackResults = 0;
4361+
for (auto &yieldInfo : yields) {
4362+
if (yieldInfo.isFormalIndirect())
4363+
NumAnyIndirectFormalResults++;
4364+
if (yieldInfo.isPack())
4365+
NumPackResults++;
4366+
}
43604367
memcpy(getMutableYields().data(), yields.data(),
43614368
yields.size() * sizeof(SILYieldInfo));
43624369
}

lib/AST/ParameterPack.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,15 @@ bool TupleType::containsPackExpansionType() const {
165165
return false;
166166
}
167167

168+
bool CanTupleType::containsPackExpansionTypeImpl(CanTupleType tuple) {
169+
for (auto eltType : tuple.getElementTypes()) {
170+
if (isa<PackExpansionType>(eltType))
171+
return true;
172+
}
173+
174+
return false;
175+
}
176+
168177
/// (W, {X, Y}..., Z) => (W, X, Y, Z)
169178
Type TupleType::flattenPackTypes() {
170179
bool anyChanged = false;
@@ -465,3 +474,17 @@ bool SILPackType::containsPackExpansionType() const {
465474

466475
return false;
467476
}
477+
478+
CanPackType
479+
CanTupleType::getInducedPackTypeImpl(CanTupleType tuple, unsigned start, unsigned count) {
480+
assert(start + count <= tuple->getNumElements() && "range out of range");
481+
482+
auto &ctx = tuple->getASTContext();
483+
if (count == 0) return CanPackType::get(ctx, {});
484+
485+
SmallVector<CanType, 4> eltTypes;
486+
eltTypes.reserve(count);
487+
for (unsigned i = start, e = start + count; i != e; ++i)
488+
eltTypes.push_back(tuple.getElementType(i));
489+
return CanPackType::get(ctx, eltTypes);
490+
}

lib/SIL/IR/SILFunction.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,16 @@ SILType SILFunction::getLoweredType(Type t) const {
472472
return getModule().Types.getLoweredType(t, TypeExpansionContext(*this));
473473
}
474474

475+
CanType
476+
SILFunction::getLoweredRValueType(AbstractionPattern orig, Type subst) const {
477+
return getModule().Types.getLoweredRValueType(TypeExpansionContext(*this),
478+
orig, subst);
479+
}
480+
481+
CanType SILFunction::getLoweredRValueType(Type t) const {
482+
return getModule().Types.getLoweredRValueType(TypeExpansionContext(*this), t);
483+
}
484+
475485
SILType SILFunction::getLoweredLoadableType(Type t) const {
476486
auto &M = getModule();
477487
return M.Types.getLoweredLoadableType(t, TypeExpansionContext(*this), M);

lib/SILGen/Conversion.h

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,15 @@ class ConvertingInitialization final : public Initialization {
250250
Finished,
251251

252252
/// The converted value has been extracted.
253-
Extracted
253+
Extracted,
254+
255+
/// We're doing pack initialization instead of the normal state
256+
/// transition, and we haven't been finished yet.
257+
PackExpanding,
258+
259+
/// We're doing pack initialization instead of the normal state
260+
/// transition, and finishInitialization has been called.
261+
FinishedPackExpanding,
254262
};
255263

256264
StateTy State;
@@ -280,6 +288,7 @@ class ConvertingInitialization final : public Initialization {
280288
FinalContext(SGFContext(subInitialization.get())) {
281289
OwnedSubInitialization = std::move(subInitialization);
282290
}
291+
283292

284293
/// Return the conversion to apply to the unconverted value.
285294
const Conversion &getConversion() const {
@@ -345,9 +354,26 @@ class ConvertingInitialization final : public Initialization {
345354

346355
// Bookkeeping.
347356
void finishInitialization(SILGenFunction &SGF) override {
348-
assert(getState() == Initialized);
349-
State = Finished;
357+
if (getState() == PackExpanding) {
358+
FinalContext.getEmitInto()->finishInitialization(SGF);
359+
State = FinishedPackExpanding;
360+
} else {
361+
assert(getState() == Initialized);
362+
State = Finished;
363+
}
350364
}
365+
366+
// Support pack-expansion initialization.
367+
bool canPerformPackExpansionInitialization() const override {
368+
if (auto finalInit = FinalContext.getEmitInto())
369+
return finalInit->canPerformPackExpansionInitialization();
370+
return false;
371+
}
372+
373+
void performPackExpansionInitialization(SILGenFunction &SGF,
374+
SILLocation loc,
375+
SILValue indexWithinComponent,
376+
llvm::function_ref<void(Initialization *into)> fn) override;
351377
};
352378

353379
} // end namespace Lowering

0 commit comments

Comments
 (0)