Skip to content

Implement the callee side of returning a tuple containing a pack expansion #64135

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 48 additions & 2 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -2451,6 +2451,20 @@ BEGIN_CAN_TYPE_WRAPPER(TupleType, Type)
CanTupleEltTypeArrayRef getElementTypes() const {
return CanTupleEltTypeArrayRef(getPointer()->getElements());
}

bool containsPackExpansionType() const {
return containsPackExpansionTypeImpl(*this);
}

/// Induce a pack type from a range of the elements of this tuple type.
inline CanTypeWrapper<PackType>
getInducedPackType(unsigned start, unsigned count) const;

private:
static bool containsPackExpansionTypeImpl(CanTupleType tuple);

static CanTypeWrapper<PackType>
getInducedPackTypeImpl(CanTupleType tuple, unsigned start, unsigned count);
END_CAN_TYPE_WRAPPER(TupleType, Type)

/// UnboundGenericType - Represents a generic type where the type arguments have
Expand Down Expand Up @@ -4233,6 +4247,11 @@ class SILResultInfo {
return !isIndirectFormalResult(getConvention());
}

/// Is this a pack result? Pack results are always indirect.
bool isPack() const {
return getConvention() == ResultConvention::Pack;
}

/// Transform this SILResultInfo by applying the user-provided
/// function to its type.
///
Expand Down Expand Up @@ -4408,8 +4427,9 @@ class SILFunctionType final

// These are *normal* results if this is not a coroutine and *yield* results
// otherwise.
unsigned NumAnyResults : 16; // Not including the ErrorResult.
unsigned NumAnyIndirectFormalResults : 16; // Subset of NumAnyResults.
unsigned NumAnyResults; // Not including the ErrorResult.
unsigned NumAnyIndirectFormalResults; // Subset of NumAnyResults.
unsigned NumPackResults; // Subset of NumAnyIndirectFormalResults.

// [NOTE: SILFunctionType-layout]
// The layout of a SILFunctionType in memory is:
Expand Down Expand Up @@ -4589,6 +4609,9 @@ class SILFunctionType final
unsigned getNumDirectFormalResults() const {
return isCoroutine() ? 0 : NumAnyResults - NumAnyIndirectFormalResults;
}
unsigned getNumPackResults() const {
return isCoroutine() ? 0 : NumPackResults;
}

struct IndirectFormalResultFilter {
bool operator()(SILResultInfo result) const {
Expand Down Expand Up @@ -4618,6 +4641,21 @@ class SILFunctionType final
return llvm::make_filter_range(getResults(), DirectFormalResultFilter());
}

struct PackResultFilter {
bool operator()(SILResultInfo result) const {
return result.isPack();
}
};
using PackResultIter =
llvm::filter_iterator<const SILResultInfo *, PackResultFilter>;
using PackResultRange = iterator_range<PackResultIter>;

/// A range of SILResultInfo for all pack results. Pack results are also
/// included in the set of indirect results.
PackResultRange getPackResults() const {
return llvm::make_filter_range(getResults(), PackResultFilter());
}

/// Get a single non-address SILType that represents all formal direct
/// results. The actual SIL result type of an apply instruction that calls
/// this function depends on the current SIL stage and is known by
Expand All @@ -4636,6 +4674,9 @@ class SILFunctionType final
unsigned getNumDirectFormalYields() const {
return isCoroutine() ? NumAnyResults - NumAnyIndirectFormalResults : 0;
}
unsigned getNumPackYields() const {
return isCoroutine() ? NumPackResults : 0;
}

struct IndirectFormalYieldFilter {
bool operator()(SILYieldInfo yield) const {
Expand Down Expand Up @@ -6798,6 +6839,11 @@ BEGIN_CAN_TYPE_WRAPPER(PackType, Type)
}
END_CAN_TYPE_WRAPPER(PackType, Type)

inline CanPackType
CanTupleType::getInducedPackType(unsigned start, unsigned end) const {
return getInducedPackTypeImpl(*this, start, end);
}

/// PackExpansionType - The interface type of the explicit expansion of a
/// corresponding set of variadic generic parameters.
///
Expand Down
4 changes: 4 additions & 0 deletions include/swift/SIL/SILFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,10 @@ class SILFunction

SILType getLoweredType(Type t) const;

CanType getLoweredRValueType(Lowering::AbstractionPattern orig, Type subst) const;

CanType getLoweredRValueType(Type t) const;

SILType getLoweredLoadableType(Type t) const;

SILType getLoweredType(SILType t) const;
Expand Down
30 changes: 19 additions & 11 deletions include/swift/SIL/SILFunctionConventions.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,30 +214,38 @@ class SILFunctionConventions {

/// Get the number of SIL results passed as address-typed arguments.
unsigned getNumIndirectSILResults() const {
return silConv.loweredAddresses ? funcTy->getNumIndirectFormalResults() : 0;
// TODO: Return packs directly in lowered-address mode
return silConv.loweredAddresses ? funcTy->getNumIndirectFormalResults()
: funcTy->getNumPackResults();
}

/// Are any SIL results passed as address-typed arguments?
bool hasIndirectSILResults() const { return getNumIndirectSILResults() != 0; }

using IndirectSILResultIter = SILFunctionType::IndirectFormalResultIter;
using IndirectSILResultRange = SILFunctionType::IndirectFormalResultRange;
struct IndirectSILResultFilter {
bool loweredAddresses;
IndirectSILResultFilter(bool loweredAddresses)
: loweredAddresses(loweredAddresses) {}
bool operator()(SILResultInfo result) const {
return (loweredAddresses ? result.isFormalIndirect() : result.isPack());
}
};
using IndirectSILResultIter =
llvm::filter_iterator<const SILResultInfo *, IndirectSILResultFilter>;
using IndirectSILResultRange = iterator_range<IndirectSILResultIter>;

/// Return a range of indirect result information for results passed as
/// address-typed SIL arguments.
IndirectSILResultRange getIndirectSILResults() const {
if (silConv.loweredAddresses)
return funcTy->getIndirectFormalResults();

return llvm::make_filter_range(
llvm::make_range((const SILResultInfo *)0, (const SILResultInfo *)0),
SILFunctionType::IndirectFormalResultFilter());
funcTy->getResults(),
IndirectSILResultFilter(silConv.loweredAddresses));
}

struct SILResultTypeFunc;

// Gratuitous template parameter is to delay instantiating `mapped_iterator`
// on the incomplete type SILParameterTypeFunc.
// on the incomplete type SILResultTypeFunc.
template<bool _ = false>
using IndirectSILResultTypeIter = typename delay_template_expansion<_,
llvm::mapped_iterator, IndirectSILResultIter, SILResultTypeFunc>::type;
Expand All @@ -253,7 +261,7 @@ class SILFunctionConventions {
/// Get the number of SIL results directly returned by SIL value.
unsigned getNumDirectSILResults() const {
return silConv.loweredAddresses ? funcTy->getNumDirectFormalResults()
: funcTy->getNumResults();
: funcTy->getNumResults() - funcTy->getNumPackResults();
}

/// Like getNumDirectSILResults but @out tuples, which are not flattened in
Expand All @@ -266,7 +274,7 @@ class SILFunctionConventions {
DirectSILResultFilter(bool loweredAddresses)
: loweredAddresses(loweredAddresses) {}
bool operator()(SILResultInfo result) const {
return !(loweredAddresses && result.isFormalIndirect());
return (loweredAddresses ? !result.isFormalIndirect() : !result.isPack());
}
};
using DirectSILResultIter =
Expand Down
27 changes: 17 additions & 10 deletions lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4342,20 +4342,27 @@ SILFunctionType::SILFunctionType(
if (coroutineKind == SILCoroutineKind::None) {
assert(yields.empty());
NumAnyResults = normalResults.size();
NumAnyIndirectFormalResults =
std::count_if(normalResults.begin(), normalResults.end(),
[](const SILResultInfo &resultInfo) {
return resultInfo.isFormalIndirect();
});
NumAnyIndirectFormalResults = 0;
NumPackResults = 0;
for (auto &resultInfo : normalResults) {
if (resultInfo.isFormalIndirect())
NumAnyIndirectFormalResults++;
if (resultInfo.isPack())
NumPackResults++;
}
memcpy(getMutableResults().data(), normalResults.data(),
normalResults.size() * sizeof(SILResultInfo));
} else {
assert(normalResults.empty());
assert(normalResults.empty());
NumAnyResults = yields.size();
NumAnyIndirectFormalResults = std::count_if(
yields.begin(), yields.end(), [](const SILYieldInfo &yieldInfo) {
return yieldInfo.isFormalIndirect();
});
NumAnyIndirectFormalResults = 0;
NumPackResults = 0;
for (auto &yieldInfo : yields) {
if (yieldInfo.isFormalIndirect())
NumAnyIndirectFormalResults++;
if (yieldInfo.isPack())
NumPackResults++;
}
memcpy(getMutableYields().data(), yields.data(),
yields.size() * sizeof(SILYieldInfo));
}
Expand Down
23 changes: 23 additions & 0 deletions lib/AST/ParameterPack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,15 @@ bool TupleType::containsPackExpansionType() const {
return false;
}

bool CanTupleType::containsPackExpansionTypeImpl(CanTupleType tuple) {
for (auto eltType : tuple.getElementTypes()) {
if (isa<PackExpansionType>(eltType))
return true;
}

return false;
}

/// (W, {X, Y}..., Z) => (W, X, Y, Z)
Type TupleType::flattenPackTypes() {
bool anyChanged = false;
Expand Down Expand Up @@ -465,3 +474,17 @@ bool SILPackType::containsPackExpansionType() const {

return false;
}

CanPackType
CanTupleType::getInducedPackTypeImpl(CanTupleType tuple, unsigned start, unsigned count) {
assert(start + count <= tuple->getNumElements() && "range out of range");

auto &ctx = tuple->getASTContext();
if (count == 0) return CanPackType::get(ctx, {});

SmallVector<CanType, 4> eltTypes;
eltTypes.reserve(count);
for (unsigned i = start, e = start + count; i != e; ++i)
eltTypes.push_back(tuple.getElementType(i));
return CanPackType::get(ctx, eltTypes);
}
10 changes: 10 additions & 0 deletions lib/SIL/IR/SILFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,16 @@ SILType SILFunction::getLoweredType(Type t) const {
return getModule().Types.getLoweredType(t, TypeExpansionContext(*this));
}

CanType
SILFunction::getLoweredRValueType(AbstractionPattern orig, Type subst) const {
return getModule().Types.getLoweredRValueType(TypeExpansionContext(*this),
orig, subst);
}

CanType SILFunction::getLoweredRValueType(Type t) const {
return getModule().Types.getLoweredRValueType(TypeExpansionContext(*this), t);
}

SILType SILFunction::getLoweredLoadableType(Type t) const {
auto &M = getModule();
return M.Types.getLoweredLoadableType(t, TypeExpansionContext(*this), M);
Expand Down
32 changes: 29 additions & 3 deletions lib/SILGen/Conversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,15 @@ class ConvertingInitialization final : public Initialization {
Finished,

/// The converted value has been extracted.
Extracted
Extracted,

/// We're doing pack initialization instead of the normal state
/// transition, and we haven't been finished yet.
PackExpanding,

/// We're doing pack initialization instead of the normal state
/// transition, and finishInitialization has been called.
FinishedPackExpanding,
};

StateTy State;
Expand Down Expand Up @@ -280,6 +288,7 @@ class ConvertingInitialization final : public Initialization {
FinalContext(SGFContext(subInitialization.get())) {
OwnedSubInitialization = std::move(subInitialization);
}


/// Return the conversion to apply to the unconverted value.
const Conversion &getConversion() const {
Expand Down Expand Up @@ -345,9 +354,26 @@ class ConvertingInitialization final : public Initialization {

// Bookkeeping.
void finishInitialization(SILGenFunction &SGF) override {
assert(getState() == Initialized);
State = Finished;
if (getState() == PackExpanding) {
FinalContext.getEmitInto()->finishInitialization(SGF);
State = FinishedPackExpanding;
} else {
assert(getState() == Initialized);
State = Finished;
}
}

// Support pack-expansion initialization.
bool canPerformPackExpansionInitialization() const override {
if (auto finalInit = FinalContext.getEmitInto())
return finalInit->canPerformPackExpansionInitialization();
return false;
}

void performPackExpansionInitialization(SILGenFunction &SGF,
SILLocation loc,
SILValue indexWithinComponent,
llvm::function_ref<void(Initialization *into)> fn) override;
};

} // end namespace Lowering
Expand Down
Loading