Skip to content

[Sema] Implement basic type checking for pack expansion expressions. #61678

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Oct 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 61 additions & 2 deletions include/swift/AST/Expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,11 @@ class alignas(8) Expr : public ASTAllocated<Expr> {
IsObjC : 1
);

SWIFT_INLINE_BITFIELD_FULL(PackExpansionExpr, Expr, 32,
: NumPadBits,
NumBindings : 32
);

SWIFT_INLINE_BITFIELD_FULL(SequenceExpr, Expr, 32,
: NumPadBits,
NumElements : 32
Expand Down Expand Up @@ -3519,20 +3524,54 @@ class VarargExpansionExpr : public Expr {
/// that naturally accept a comma-separated list of values, including
/// call argument lists, the elements of a tuple value, and the source
/// of a for-in loop.
class PackExpansionExpr: public Expr {
class PackExpansionExpr final : public Expr,
private llvm::TrailingObjects<PackExpansionExpr,
OpaqueValueExpr *, Expr *> {
friend TrailingObjects;

Expr *PatternExpr;
SourceLoc DotsLoc;
GenericEnvironment *Environment;

PackExpansionExpr(Expr *patternExpr,
ArrayRef<OpaqueValueExpr *> opaqueValues,
ArrayRef<Expr *> bindings,
SourceLoc dotsLoc,
GenericEnvironment *environment,
bool implicit, Type type)
: Expr(ExprKind::PackExpansion, implicit, type),
PatternExpr(patternExpr), DotsLoc(dotsLoc) {}
PatternExpr(patternExpr), DotsLoc(dotsLoc), Environment(environment) {
assert(opaqueValues.size() == bindings.size());
Bits.PackExpansionExpr.NumBindings = opaqueValues.size();

assert(Bits.PackExpansionExpr.NumBindings > 0 &&
"PackExpansionExpr must have pack references");

std::uninitialized_copy(opaqueValues.begin(), opaqueValues.end(),
getTrailingObjects<OpaqueValueExpr *>());
std::uninitialized_copy(bindings.begin(), bindings.end(),
getTrailingObjects<Expr *>());
}

size_t numTrailingObjects(OverloadToken<OpaqueValueExpr *>) const {
return getNumBindings();
}

size_t numTrailingObjects(OverloadToken<Expr *>) const {
return getNumBindings();
}

MutableArrayRef<Expr *> getMutableBindings() {
return {getTrailingObjects<Expr *>(), getNumBindings()};
}

public:
static PackExpansionExpr *create(ASTContext &ctx,
Expr *patternExpr,
ArrayRef<OpaqueValueExpr *> opaqueValues,
ArrayRef<Expr *> bindings,
SourceLoc dotsLoc,
GenericEnvironment *environment,
bool implicit = false,
Type type = Type());

Expand All @@ -3542,6 +3581,26 @@ class PackExpansionExpr: public Expr {
PatternExpr = patternExpr;
}

unsigned getNumBindings() const {
return Bits.PackExpansionExpr.NumBindings;
}

ArrayRef<OpaqueValueExpr *> getOpaqueValues() {
return {getTrailingObjects<OpaqueValueExpr *>(), getNumBindings()};
}

ArrayRef<Expr *> getBindings() {
return {getTrailingObjects<Expr *>(), getNumBindings()};
}

void setBinding(unsigned i, Expr *e) {
getMutableBindings()[i] = e;
}

GenericEnvironment *getGenericEnvironment() {
return Environment;
}

SourceLoc getStartLoc() const {
return PatternExpr->getStartLoc();
}
Expand Down
2 changes: 2 additions & 0 deletions include/swift/AST/GenericSignature.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ class GenericSignature {

RequiredProtocols protos;
LayoutConstraint layout;

Type packShape;
};

private:
Expand Down
4 changes: 4 additions & 0 deletions include/swift/AST/Identifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ class Identifier {
return is("??");
}

bool isExpansionOperator() const {
return is("...");
}

/// isOperatorStartCodePoint - Return true if the specified code point is a
/// valid start of an operator.
static bool isOperatorStartCodePoint(uint32_t C) {
Expand Down
15 changes: 12 additions & 3 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -6009,11 +6009,17 @@ BEGIN_CAN_TYPE_WRAPPER(OpenedArchetypeType, ArchetypeType)
}
END_CAN_TYPE_WRAPPER(OpenedArchetypeType, ArchetypeType)

/// A wrapper around a shape type to use in ArchetypeTrailingObjects
/// for PackArchetypeType.
struct PackShape {
Type shapeType;
};

/// An archetype that represents an opaque element of a type
/// parameter pack in context.
class PackArchetypeType final
: public ArchetypeType,
private ArchetypeTrailingObjects<PackArchetypeType> {
private ArchetypeTrailingObjects<PackArchetypeType, PackShape> {
friend TrailingObjects;
friend ArchetypeType;

Expand All @@ -6024,18 +6030,21 @@ class PackArchetypeType final
/// by this routine.
static CanTypeWrapper<PackArchetypeType>
get(const ASTContext &Ctx, GenericEnvironment *GenericEnv,
Type InterfaceType,
Type InterfaceType, Type ShapeType,
SmallVectorImpl<ProtocolDecl *> &ConformsTo, Type Superclass,
LayoutConstraint Layout);

// Returns the reduced shape type for this pack archetype.
Type getShape() const;

static bool classof(const TypeBase *T) {
return T->getKind() == TypeKind::PackArchetype;
}

private:
PackArchetypeType(const ASTContext &Ctx, GenericEnvironment *GenericEnv,
Type InterfaceType, ArrayRef<ProtocolDecl *> ConformsTo,
Type Superclass, LayoutConstraint Layout);
Type Superclass, LayoutConstraint Layout, PackShape Shape);
};
BEGIN_CAN_TYPE_WRAPPER(PackArchetypeType, ArchetypeType)
END_CAN_TYPE_WRAPPER(PackArchetypeType, ArchetypeType)
Expand Down
4 changes: 4 additions & 0 deletions include/swift/Sema/Constraint.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,9 @@ enum class ConstraintKind : char {
/// Represents an AST node contained in a body of a function/closure.
/// It only has an AST node to generate constraints and infer the type for.
SyntacticElement,
/// The first type is the opened pack element type of the second type, which
/// is the pattern of a pack expansion type.
PackElementOf,
/// Do not add new uses of this, it only exists to retain compatibility for
/// rdar://85263844.
///
Expand Down Expand Up @@ -686,6 +689,7 @@ class Constraint final : public llvm::ilist_node<Constraint>,
case ConstraintKind::OneWayBindParam:
case ConstraintKind::DefaultClosureType:
case ConstraintKind::UnresolvedMemberChainBase:
case ConstraintKind::PackElementOf:
return ConstraintClassification::Relational;

case ConstraintKind::ValueMember:
Expand Down
12 changes: 12 additions & 0 deletions include/swift/Sema/ConstraintSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -5501,6 +5501,18 @@ class ConstraintSystem {
TypeMatchOptions flags,
ConstraintLocatorBuilder locator);

/// Attempt to simplify a PackElementOf constraint.
///
/// Solving this constraint is delayed until the element type is fully
/// resolved with no type variables. The element type is then mapped out
/// of the opened element context and into the context of the surrounding
/// function, effecively substituting opened element archetypes with their
/// corresponding pack archetypes, and bound to the second type.
SolutionKind
simplifyPackElementOfConstraint(Type first, Type second,
TypeMatchOptions flags,
ConstraintLocatorBuilder locator);

/// Attempt to simplify the ApplicableFunction constraint.
SolutionKind simplifyApplicableFnConstraint(
Type type1, Type type2,
Expand Down
13 changes: 2 additions & 11 deletions lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3136,15 +3136,6 @@ TupleType *TupleType::get(ArrayRef<TupleTypeElt> Fields, const ASTContext &C) {
properties |= eltTy->getRecursiveProperties();
}

// Enforce an invariant.
for (unsigned i = 0, e = Fields.size(); i < e; ++i) {
if (Fields[i].getType()->is<PackExpansionType>()) {
assert(i == e - 1 || Fields[i + 1].hasName() &&
"Tuple element with pack expansion type cannot be followed "
"by an unlabeled element");
}
}

auto arena = getArena(properties);

void *InsertPos = nullptr;
Expand Down Expand Up @@ -5611,8 +5602,8 @@ ASTContext::getOpenedElementSignature(CanGenericSignature baseGenericSig) {

auto eraseParameterPack = [&](GenericTypeParamType *paramType) {
return GenericTypeParamType::get(
paramType->getDepth(), paramType->getIndex(),
/*isParameterPack=*/false, *this);
/*isParameterPack=*/false, paramType->getDepth(),
paramType->getIndex(), *this);
};

for (auto paramType : baseGenericSig.getGenericParams()) {
Expand Down
30 changes: 30 additions & 0 deletions lib/AST/ASTVerifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,36 @@ class Verifier : public ASTWalker {
OpenedExistentialArchetypes.erase(expr->getOpenedArchetype());
}

bool shouldVerify(PackExpansionExpr *expr) {
if (!shouldVerify(cast<Expr>(expr)))
return false;

Generics.push_back(expr->getGenericEnvironment()->getGenericSignature());

for (auto *placeholder : expr->getOpaqueValues()) {
assert(!OpaqueValues.count(placeholder));
OpaqueValues[placeholder] = 0;
}

return true;
}

void verifyCheckedAlways(PackExpansionExpr *E) {
// Remove the element generic environment before verifying
// the pack expansion type, which contains pack archetypes.
assert(Generics.back().get<GenericSignature>().getPointer() ==
E->getGenericEnvironment()->getGenericSignature().getPointer());
Generics.pop_back();
verifyCheckedAlwaysBase(E);
}

void cleanup(PackExpansionExpr *expr) {
for (auto *placeholder : expr->getOpaqueValues()) {
assert(OpaqueValues.count(placeholder));
OpaqueValues.erase(placeholder);
}
}

bool shouldVerify(MakeTemporarilyEscapableExpr *expr) {
if (!shouldVerify(cast<Expr>(expr)))
return false;
Expand Down
15 changes: 11 additions & 4 deletions lib/AST/Expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1239,10 +1239,17 @@ VarargExpansionExpr *VarargExpansionExpr::createArrayExpansion(ASTContext &ctx,

PackExpansionExpr *
PackExpansionExpr::create(ASTContext &ctx, Expr *patternExpr,
SourceLoc dotsLoc, bool implicit,
Type type) {
return new (ctx) PackExpansionExpr(patternExpr, dotsLoc,
implicit, type);
ArrayRef<OpaqueValueExpr *> opaqueValues,
ArrayRef<Expr *> bindings, SourceLoc dotsLoc,
GenericEnvironment *environment,
bool implicit, Type type) {
size_t size =
totalSizeToAlloc<OpaqueValueExpr *, Expr *>(opaqueValues.size(),
bindings.size());
void *mem = ctx.Allocate(size, alignof(PackExpansionExpr));
return ::new (mem) PackExpansionExpr(patternExpr, opaqueValues,
bindings, dotsLoc, environment,
implicit, type);
}

SequenceExpr *SequenceExpr::create(ASTContext &ctx, ArrayRef<Expr*> elements) {
Expand Down
1 change: 1 addition & 0 deletions lib/AST/GenericEnvironment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@ GenericEnvironment::getOrCreateArchetypeFromInterfaceType(Type depType) {
if (rootGP->isParameterPack()) {
assert(getKind() == Kind::Primary);
result = PackArchetypeType::get(ctx, this, requirements.anchor,
requirements.packShape,
requirements.protos, superclass,
requirements.layout);
} else {
Expand Down
28 changes: 21 additions & 7 deletions lib/AST/PackExpansionMatcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,23 @@

using namespace swift;

static PackType *gatherTupleElements(ArrayRef<TupleTypeElt> &elts,
Identifier name,
ASTContext &ctx) {
static Type createPackBinding(ASTContext &ctx, ArrayRef<Type> types) {
// If there is only one element and it's a pack expansion type,
// return the pattern type directly. Because PackType can only appear
// inside a PackExpansion, PackType(PackExpansionType()) will always
// simplify to the pattern type.
if (types.size() == 1) {
if (auto *expansion = types.front()->getAs<PackExpansionType>()) {
return expansion->getPatternType();
}
}

return PackType::get(ctx, types);
}

static Type gatherTupleElements(ArrayRef<TupleTypeElt> &elts,
Identifier name,
ASTContext &ctx) {
SmallVector<Type, 2> types;

if (!elts.empty() && elts.front().getName() == name) {
Expand All @@ -36,7 +50,7 @@ static PackType *gatherTupleElements(ArrayRef<TupleTypeElt> &elts,
} while (!elts.empty() && !elts.front().hasName());
}

return PackType::get(ctx, types);
return createPackBinding(ctx, types);
}

TuplePackMatcher::TuplePackMatcher(TupleType *lhsTuple, TupleType *rhsTuple)
Expand Down Expand Up @@ -69,7 +83,7 @@ bool TuplePackMatcher::match() {
"Tuple element with pack expansion type cannot be followed "
"by an unlabeled element");

auto *rhs = gatherTupleElements(rhsElts, lhsElt.getName(), ctx);
auto rhs = gatherTupleElements(rhsElts, lhsElt.getName(), ctx);
pairs.emplace_back(lhsExpansionType->getPatternType(), rhs, idx++);
continue;
}
Expand All @@ -89,7 +103,7 @@ bool TuplePackMatcher::match() {
"Tuple element with pack expansion type cannot be followed "
"by an unlabeled element");

auto *lhs = gatherTupleElements(lhsElts, rhsElt.getName(), ctx);
auto lhs = gatherTupleElements(lhsElts, rhsElt.getName(), ctx);
pairs.emplace_back(lhs, rhsExpansionType->getPatternType(), idx++);
continue;
}
Expand Down Expand Up @@ -299,4 +313,4 @@ bool PackMatcher::match() {
// - The prefix and suffix are mismatched, so we're left with something
// like {T..., Int} vs {Float, U...}.
return true;
}
}
8 changes: 6 additions & 2 deletions lib/AST/RequirementMachine/ConcreteContraction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,8 +371,12 @@ ConcreteContraction::substRequirement(const Requirement &req) const {
auto firstType = req.getFirstType();

switch (req.getKind()) {
case RequirementKind::SameShape:
llvm_unreachable("Same-shape requirement not supported here");
case RequirementKind::SameShape: {
auto substFirstType = substType(firstType);
auto substSecondType = substType(req.getSecondType());

return Requirement(req.getKind(), substFirstType, substSecondType);
}

case RequirementKind::Superclass:
case RequirementKind::SameType: {
Expand Down
6 changes: 5 additions & 1 deletion lib/AST/RequirementMachine/GenericSignatureQueries.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ RequirementMachine::getLocalRequirements(

GenericSignature::LocalRequirements result;
result.anchor = Map.getTypeForTerm(term, genericParams);
result.packShape = getReducedShape(depType);

auto *props = Map.lookUpProperties(term);
if (!props)
Expand Down Expand Up @@ -789,13 +790,16 @@ void RequirementMachine::verify(const MutableTerm &term) const {
erased.add(Symbol::forName(symbol.getName(), Context));
break;

case Symbol::Kind::Shape:
erased.add(symbol);
break;

case Symbol::Kind::Protocol:
case Symbol::Kind::GenericParam:
case Symbol::Kind::Layout:
case Symbol::Kind::Superclass:
case Symbol::Kind::ConcreteType:
case Symbol::Kind::ConcreteConformance:
case Symbol::Kind::Shape:
llvm::errs() << "Bad interior symbol " << symbol << " in " << term << "\n";
abort();
break;
Expand Down
Loading