Skip to content

SILGen and SIL type lowering support for vanishing tuples #64887

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 4, 2023
20 changes: 15 additions & 5 deletions include/swift/SIL/AbstractionPattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -1283,8 +1283,9 @@ class AbstractionPattern {
}

/// Is the given tuple type a valid substitution of this abstraction
/// pattern?
bool matchesTuple(CanTupleType substType) const;
/// pattern? Note that the type doesn't have to be a tuple type in the
/// case of a vanishing tuple.
bool matchesTuple(CanType substType) const;

bool isTuple() const {
switch (getKind()) {
Expand Down Expand Up @@ -1344,6 +1345,14 @@ class AbstractionPattern {

bool doesTupleContainPackExpansionType() const;

/// If this type is a tuple type that vanishes (is flattened to its
/// singleton non-expansion element) under the stored substitutions,
/// return the abstraction pattern of the surviving element.
///
/// If the surviving element came from an expansion element, the
/// returned element is the pattern type of the expansion.
Optional<AbstractionPattern> getVanishingTupleElementPatternType() const;

static AbstractionPattern
projectTupleElementType(const AbstractionPattern *base, size_t index) {
return base->getTupleElementType(index);
Expand All @@ -1360,8 +1369,9 @@ class AbstractionPattern {
/// original type and how many elements of the substituted type they
/// expand to.
///
/// This pattern must be a tuple pattern.
void forEachTupleElement(CanTupleType substType,
/// This pattern must be a tuple pattern. The substituted type may be
/// a non-tuple only if this is a vanshing tuple pattern.
void forEachTupleElement(CanType substType,
llvm::function_ref<void(TupleElementGenerator &element)> fn) const;

/// Perform a parallel visitation of the elements of a tuple type,
Expand All @@ -1372,7 +1382,7 @@ class AbstractionPattern {
///
/// This pattern must match the substituted type, but it may be an
/// opaque pattern.
void forEachExpandedTupleElement(CanTupleType substType,
void forEachExpandedTupleElement(CanType substType,
llvm::function_ref<void(AbstractionPattern origEltType,
CanType substEltType,
const TupleTypeElt &elt)> handleElement) const;
Expand Down
49 changes: 40 additions & 9 deletions include/swift/SIL/AbstractionPatternGenerators.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,13 @@ class FunctionParamGenerator {
/// to the current orig parameter.
unsigned getSubstIndex() const {
assert(!isFinished());
return origParamIndex;
return substParamIndex;
}

IntRange<unsigned> getSubstIndexRange() const {
assert(!isFinished());
return IntRange<unsigned>(substParamIndex,
substParamIndex + numSubstParamsForOrigParam);
}

/// Return the parameter flags for the current orig parameter.
Expand Down Expand Up @@ -157,8 +163,9 @@ class TupleElementGenerator {
/// during construction.
AbstractionPattern origTupleType;

/// The substitute tuple type. Set once during construction.
CanTupleType substTupleType;
/// The substituted type. A tuple type unless this is a vanishing
/// tuple. Set once during construction.
CanType substType;

/// The number of orig elements to traverse. Set once during
/// construction.
Expand All @@ -176,6 +183,10 @@ class TupleElementGenerator {
/// orig element.
unsigned numSubstEltsForOrigElt;

/// Whether the orig tuple type is a vanishing tuple, i.e. substitution
/// turns it into a singleton element.
bool origTupleVanishes;

/// Whether the orig tuple type is opaque, i.e. does not permit us to
/// call getNumTupleElements() and similar accessors. Set once during
/// construction.
Expand All @@ -189,6 +200,9 @@ class TupleElementGenerator {
/// pattern type.
AbstractionPattern origEltType = AbstractionPattern::getInvalid();

/// A scratch element that is used for vanishing tuple types.
mutable TupleTypeElt scratchSubstElt;

/// Load the informaton for the current orig element into the
/// fields above for it.
void loadElement() {
Expand All @@ -202,7 +216,7 @@ class TupleElementGenerator {

public:
TupleElementGenerator(AbstractionPattern origTupleType,
CanTupleType substTupleType);
CanType substType);

/// Is the traversal finished? If so, none of the getters below
/// are allowed to be called.
Expand All @@ -228,14 +242,22 @@ class TupleElementGenerator {
/// to the current orig element.
unsigned getSubstIndex() const {
assert(!isFinished());
return origEltIndex;
return substEltIndex;
}

IntRange<unsigned> getSubstIndexRange() const {
assert(!isFinished());
return IntRange<unsigned>(substEltIndex,
substEltIndex + numSubstEltsForOrigElt);
}

/// Return a tuple element for the current orig element.
TupleTypeElt getOrigElement() const {
assert(!isFinished());
// If the orig tuple is opaque, it can't have vanished, so this
// cast of substType is okay.
return (origTupleTypeIsOpaque
? substTupleType->getElement(substEltIndex)
? cast<TupleType>(substType)->getElement(substEltIndex)
: cast<TupleType>(origTupleType.getType())
->getElement(origEltIndex));
}
Expand All @@ -257,15 +279,24 @@ class TupleElementGenerator {
/// pack expansion, this will have exactly one element.
CanTupleEltTypeArrayRef getSubstTypes() const {
assert(!isFinished());
return substTupleType.getElementTypes().slice(substEltIndex,
numSubstEltsForOrigElt);
if (!origTupleVanishes) {
return cast<TupleType>(substType)
.getElementTypes().slice(substEltIndex,
numSubstEltsForOrigElt);
} else if (numSubstEltsForOrigElt == 0) {
return CanTupleEltTypeArrayRef();
} else {
scratchSubstElt = TupleTypeElt(substType);
return CanTupleEltTypeArrayRef(scratchSubstElt);
}
}

/// Call this to finalize the traversal and assert that it was done
/// properly.
void finish() {
assert(isFinished() && "didn't finish the traversal");
assert(substEltIndex == substTupleType->getNumElements() &&
assert(substEltIndex == (origTupleVanishes ? 1 :
cast<TupleType>(substType)->getNumElements()) &&
"didn't exhaust subst elements; possible missing subs on "
"orig tuple type");
}
Expand Down
104 changes: 90 additions & 14 deletions lib/SIL/IR/AbstractionPattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ LayoutConstraint AbstractionPattern::getLayoutConstraint() const {
}
}

bool AbstractionPattern::matchesTuple(CanTupleType substType) const {
bool AbstractionPattern::matchesTuple(CanType substType) const {
switch (getKind()) {
case Kind::Invalid:
llvm_unreachable("querying invalid abstraction pattern!");
Expand Down Expand Up @@ -311,11 +311,19 @@ bool AbstractionPattern::matchesTuple(CanTupleType substType) const {
return false;
LLVM_FALLTHROUGH;
case Kind::Tuple: {
if (getVanishingTupleElementPatternType()) {
// TODO: recurse into elements.
return true;
}

auto substTupleType = dyn_cast<TupleType>(substType);
if (!substTupleType) return false;

size_t nextSubstIndex = 0;
auto nextComponentIsAcceptable = [&](bool isPackExpansion) -> bool {
if (nextSubstIndex == substType->getNumElements())
if (nextSubstIndex == substTupleType->getNumElements())
return false;
auto substComponentType = substType.getElementType(nextSubstIndex++);
auto substComponentType = substTupleType.getElementType(nextSubstIndex++);
return (isPackExpansion == isa<PackExpansionType>(substComponentType));
};
for (auto elt : getTupleElementTypes()) {
Expand All @@ -333,7 +341,7 @@ bool AbstractionPattern::matchesTuple(CanTupleType substType) const {
return false;
}
}
return nextSubstIndex == substType->getNumElements();
return nextSubstIndex == substTupleType->getNumElements();
}
}
llvm_unreachable("bad kind");
Expand Down Expand Up @@ -469,7 +477,63 @@ bool AbstractionPattern::doesTupleContainPackExpansionType() const {
llvm_unreachable("bad kind");
}

void AbstractionPattern::forEachTupleElement(CanTupleType substType,
Optional<AbstractionPattern>
AbstractionPattern::getVanishingTupleElementPatternType() const {
if (!isTuple()) return None;
if (!GenericSubs) return None;

// Substitution causes tuples to vanish when substituting the elements
// produces a singleton tuple and it didn't start that way.

auto numOrigElts = getNumTupleElements();

// Track whether we've found a single element.
Optional<AbstractionPattern> singletonEltType;
bool hadOrigExpansion = false;
for (auto index : range(numOrigElts)) {
auto eltType = getTupleElementType(index);

// If this pattern isn't a pack expansion, we've got a new candidate
// singleton. If this is the second such candidate, of course, it's
// not a singleton.
if (!eltType.isPackExpansion()) {
if (singletonEltType) return None;
singletonEltType = eltType;

// Otherwise, check what the expansion shape expands to.
} else {
hadOrigExpansion = true;

auto expansionType = cast<PackExpansionType>(eltType.getType());
auto substShape = cast<PackType>(
expansionType.getCountType().subst(GenericSubs)->getCanonicalType());
auto expansionCount = substShape->getNumElements();

// If it expands to multiple elements or to a single expansion, we
// won't have a singleton tuple. If it expands to a single scalar
// element, this is a singleton candidate.
if (expansionCount > 1) {
return None;
} else if (expansionCount == 1) {
auto substExpansion =
dyn_cast<PackExpansionType>(substShape.getElementType(0));
if (substExpansion)
return None;
if (singletonEltType)
return None;
singletonEltType = eltType.getPackExpansionPatternType();
}
}
}

// If we found a singleton scalar element, and we didn't start with
// a singleton element, that's the index we want to return.
if (singletonEltType && !(numOrigElts == 1 && !hadOrigExpansion))
return singletonEltType;
return None;
}

void AbstractionPattern::forEachTupleElement(CanType substType,
llvm::function_ref<void(TupleElementGenerator &)> handleElement) const {
TupleElementGenerator elt(*this, substType);
for (; !elt.isFinished(); elt.advance()) {
Expand All @@ -480,35 +544,46 @@ void AbstractionPattern::forEachTupleElement(CanTupleType substType,

TupleElementGenerator::TupleElementGenerator(
AbstractionPattern origTupleType,
CanTupleType substTupleType)
: origTupleType(origTupleType), substTupleType(substTupleType) {
CanType substType)
: origTupleType(origTupleType), substType(substType) {
assert(origTupleType.isTuple());
assert(origTupleType.matchesTuple(substTupleType));
assert(origTupleType.matchesTuple(substType));

origTupleVanishes =
origTupleType.getVanishingTupleElementPatternType().hasValue();
origTupleTypeIsOpaque = origTupleType.isOpaqueTuple();
numOrigElts = origTupleType.getNumTupleElements();

if (!isFinished()) loadElement();
}

void AbstractionPattern::forEachExpandedTupleElement(CanTupleType substType,
void AbstractionPattern::forEachExpandedTupleElement(CanType substType,
llvm::function_ref<void(AbstractionPattern origEltType,
CanType substEltType,
const TupleTypeElt &elt)>
handleElement) const {
assert(matchesTuple(substType));

auto substEltTypes = substType.getElementTypes();

// Handle opaque patterns by just iterating the substituted components.
if (!isTuple()) {
auto substTupleType = cast<TupleType>(substType);
auto substEltTypes = substTupleType.getElementTypes();
for (auto i : indices(substEltTypes)) {
handleElement(getTupleElementType(i), substEltTypes[i],
substType->getElement(i));
substTupleType->getElement(i));
}
return;
}

// For vanishing tuples, just call the callback once.
if (auto origEltType = getVanishingTupleElementPatternType()) {
handleElement(*origEltType, substType, TupleTypeElt(substType));
return;
}

auto substTupleType = cast<TupleType>(substType);
auto substEltTypes = substTupleType.getElementTypes();

// For non-opaque patterns, we have to iterate the original components
// in order to match things up properly, but we'll still end up calling
// once per substituted element.
Expand All @@ -517,7 +592,7 @@ void AbstractionPattern::forEachExpandedTupleElement(CanTupleType substType,
auto origEltType = getTupleElementType(origEltIndex);
if (!origEltType.isPackExpansion()) {
handleElement(origEltType, substEltTypes[substEltIndex],
substType->getElement(substEltIndex));
substTupleType->getElement(substEltIndex));
substEltIndex++;
} else {
auto origPatternType = origEltType.getPackExpansionPatternType();
Expand All @@ -532,7 +607,8 @@ void AbstractionPattern::forEachExpandedTupleElement(CanTupleType substType,
// be misleading in one way or another.
handleElement(isa<PackExpansionType>(substEltType)
? origEltType : origPatternType,
substEltType, substType->getElement(substEltIndex));
substEltType,
substTupleType->getElement(substEltIndex));
substEltIndex++;
}
}
Expand Down
7 changes: 3 additions & 4 deletions lib/SIL/IR/SILFunctionType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1270,8 +1270,7 @@ class DestructureResults {
void destructure(AbstractionPattern origType, CanType substType) {
// Recur into tuples.
if (origType.isTuple()) {
auto substTupleType = cast<TupleType>(substType);
origType.forEachTupleElement(substTupleType,
origType.forEachTupleElement(substType,
[&](TupleElementGenerator &elt) {
// If the original element type is not a pack expansion, just
// pull off the next substituted element type.
Expand Down Expand Up @@ -1646,7 +1645,7 @@ class DestructureInputs {

// Tuples get expanded unless they're inout.
if (origType.isTuple() && ownership != ValueOwnership::InOut) {
expandTuple(ownership, forSelf, origType, cast<TupleType>(substType),
expandTuple(ownership, forSelf, origType, substType,
isNonDifferentiable);
return;
}
Expand Down Expand Up @@ -1683,7 +1682,7 @@ class DestructureInputs {

/// Recursively expand a tuple type into separate parameters.
void expandTuple(ValueOwnership ownership, bool forSelf,
AbstractionPattern origType, CanTupleType substType,
AbstractionPattern origType, CanType substType,
bool isNonDifferentiable) {
assert(ownership != ValueOwnership::InOut);
assert(origType.isTuple());
Expand Down
Loading