Skip to content

Preliminary steps towards support for closures that capture pack element environments #73348

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 0 additions & 19 deletions include/swift/AST/AnyFunctionRef.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,6 @@ class AnyFunctionRef {
return TheFunction.get<AbstractClosureExpr *>()->getCaptureInfo();
}


bool hasType() const {
if (auto *AFD = TheFunction.dyn_cast<AbstractFunctionDecl *>())
return AFD->hasInterfaceType();
return !TheFunction.get<AbstractClosureExpr *>()->getType().isNull();
}

ParameterList *getParameters() const {
if (auto *AFD = TheFunction.dyn_cast<AbstractFunctionDecl *>())
return AFD->getParameters();
Expand Down Expand Up @@ -174,20 +167,8 @@ class AnyFunctionRef {
return TheFunction.dyn_cast<AbstractClosureExpr*>();
}

/// Return true if this closure is passed as an argument to a function and is
/// known not to escape from that function. In this case, captures can be
/// more efficient.
bool isKnownNoEscape() const {
if (hasType() && !getType()->hasError())
return getType()->castTo<AnyFunctionType>()->isNoEscape();
return false;
}

/// Whether this function is @Sendable.
bool isSendable() const {
if (!hasType())
return false;

if (auto *fnType = getType()->getAs<AnyFunctionType>())
return fnType->isSendable();

Expand Down
45 changes: 34 additions & 11 deletions include/swift/AST/CaptureInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class ValueDecl;
class FuncDecl;
class OpaqueValueExpr;
class VarDecl;
class GenericEnvironment;

/// CapturedValue includes both the declaration being captured, along with flags
/// that indicate how it is captured.
Expand Down Expand Up @@ -140,19 +141,27 @@ class DynamicSelfType;
/// Stores information about captured variables.
class CaptureInfo {
class CaptureInfoStorage final
: public llvm::TrailingObjects<CaptureInfoStorage, CapturedValue> {
: public llvm::TrailingObjects<CaptureInfoStorage,
CapturedValue,
GenericEnvironment *> {

DynamicSelfType *DynamicSelf;
OpaqueValueExpr *OpaqueValue;
unsigned Count;
unsigned NumCapturedValues;
unsigned NumGenericEnvironments;

public:
explicit CaptureInfoStorage(unsigned count, DynamicSelfType *dynamicSelf,
OpaqueValueExpr *opaqueValue)
: DynamicSelf(dynamicSelf), OpaqueValue(opaqueValue), Count(count) { }
explicit CaptureInfoStorage(DynamicSelfType *dynamicSelf,
OpaqueValueExpr *opaqueValue,
unsigned numCapturedValues,
unsigned numGenericEnvironments)
: DynamicSelf(dynamicSelf), OpaqueValue(opaqueValue),
NumCapturedValues(numCapturedValues),
NumGenericEnvironments(numGenericEnvironments) { }

ArrayRef<CapturedValue> getCaptures() const {
return llvm::ArrayRef(this->getTrailingObjects<CapturedValue>(), Count);
}
ArrayRef<CapturedValue> getCaptures() const;

ArrayRef<GenericEnvironment *> getGenericEnvironments() const;

DynamicSelfType *getDynamicSelfType() const {
return DynamicSelf;
Expand All @@ -161,6 +170,10 @@ class CaptureInfo {
OpaqueValueExpr *getOpaqueValue() const {
return OpaqueValue;
}

unsigned numTrailingObjects(OverloadToken<CapturedValue>) const {
return NumCapturedValues;
}
};

enum class Flags : unsigned {
Expand All @@ -173,9 +186,11 @@ class CaptureInfo {
public:
/// The default-constructed CaptureInfo is "not yet computed".
CaptureInfo() = default;
CaptureInfo(ASTContext &ctx, ArrayRef<CapturedValue> captures,
CaptureInfo(ASTContext &ctx,
ArrayRef<CapturedValue> captures,
DynamicSelfType *dynamicSelf, OpaqueValueExpr *opaqueValue,
bool genericParamCaptures);
bool genericParamCaptures,
ArrayRef<GenericEnvironment *> genericEnv=ArrayRef<GenericEnvironment*>());

/// A CaptureInfo representing no captures at all.
static CaptureInfo empty();
Expand All @@ -190,12 +205,20 @@ class CaptureInfo {
!hasDynamicSelfCapture() && !hasOpaqueValueCapture();
}

/// Returns all captured values and opaque expressions.
ArrayRef<CapturedValue> getCaptures() const {
assert(hasBeenComputed());
return StorageAndFlags.getPointer()->getCaptures();
}

/// \returns true if the function captures any generic type parameters.
/// Returns all captured pack element environments.
ArrayRef<GenericEnvironment *> getGenericEnvironments() const {
assert(hasBeenComputed());
return StorageAndFlags.getPointer()->getGenericEnvironments();
}

/// \returns true if the function captures the primary generic environment
/// from its innermost declaration context.
bool hasGenericParamCaptures() const {
assert(hasBeenComputed());
return StorageAndFlags.getInt().contains(Flags::HasGenericParamCaptures);
Expand Down
17 changes: 17 additions & 0 deletions include/swift/Sema/ConstraintSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -2229,6 +2229,10 @@ class ConstraintSystem {
/// from declared parameters/result and body.
llvm::MapVector<const ClosureExpr *, FunctionType *> ClosureTypes;

/// Maps closures and local functions to the pack expansion expressions they
/// capture.
llvm::MapVector<AnyFunctionRef, SmallVector<PackExpansionExpr *, 1>> CapturedExpansions;

/// Maps expressions for implied results (e.g implicit 'then' statements,
/// implicit 'return' statements in single expression body closures) to their
/// result kind.
Expand Down Expand Up @@ -3164,6 +3168,19 @@ class ConstraintSystem {
return nullptr;
}

SmallVector<PackExpansionExpr *, 1> getCapturedExpansions(AnyFunctionRef func) const {
auto result = CapturedExpansions.find(func);
if (result == CapturedExpansions.end())
return {};

return result->second;
}

void setCapturedExpansions(AnyFunctionRef func, SmallVector<PackExpansionExpr *, 1> exprs) {
assert(CapturedExpansions.count(func) == 0 && "Cannot reset captured expansions");
CapturedExpansions.insert({func, exprs});
}

TypeVariableType *getKeyPathValueType(const KeyPathExpr *keyPath) const {
auto result = getKeyPathValueTypeIfAvailable(keyPath);
assert(result);
Expand Down
96 changes: 53 additions & 43 deletions lib/AST/ASTVerifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,20 +198,24 @@ class Verifier : public ASTWalker {
using ScopeLike = llvm::PointerUnion<DeclContext *, BraceStmt *>;
SmallVector<ScopeLike, 4> Scopes;

/// The stack of generic contexts.
using GenericLike = llvm::PointerUnion<DeclContext *, GenericEnvironment *>;
SmallVector<GenericLike, 2> Generics;
/// The stack of declaration contexts we're visiting. The primary
/// archetypes from the innermost generic environment are in scope.
SmallVector<DeclContext *, 2> Generics;

/// The set of all opened existential and opened pack element generic
/// environments that are currently in scope.
llvm::DenseSet<GenericEnvironment *> LocalGenerics;

/// We track the pack expansion expressions in ForEachStmts, because
/// their local generics remain in scope until the end of the statement.
llvm::DenseSet<PackExpansionExpr *> ForEachPatternSequences;

/// The stack of optional evaluations active at this point.
SmallVector<OptionalEvaluationExpr *, 4> OptionalEvaluations;

/// The set of opaque value expressions active at this point.
llvm::DenseMap<OpaqueValueExpr *, unsigned> OpaqueValues;

/// The set of opened existential archetypes that are currently
/// active.
llvm::DenseSet<OpenedArchetypeType *> OpenedExistentialArchetypes;

/// The set of inout to pointer expr that match the following pattern:
///
/// (call-expr
Expand Down Expand Up @@ -632,11 +636,10 @@ class Verifier : public ASTWalker {

bool foundError = type->getCanonicalType().findIf([&](Type type) -> bool {
if (auto archetype = type->getAs<ArchetypeType>()) {
auto root = archetype->getRoot();

// Opaque archetypes are globally available. We don't need to check
// them here.
if (isa<OpaqueTypeArchetypeType>(root))
if (isa<OpaqueTypeArchetypeType>(archetype))
return false;

// Only visit each archetype once.
Expand All @@ -645,11 +648,10 @@ class Verifier : public ASTWalker {

// We should know about archetypes corresponding to opened
// existential archetypes.
if (auto opened = dyn_cast<OpenedArchetypeType>(root)) {
if (OpenedExistentialArchetypes.count(opened) == 0) {
Out << "Found opened existential archetype "
<< root->getString()
<< " outside enclosing OpenExistentialExpr\n";
if (isa<LocalArchetypeType>(archetype)) {
if (LocalGenerics.count(archetype->getGenericEnvironment()) == 0) {
Out << "Found local archetype " << archetype
<< " outside its defining scope\n";
return true;
}

Expand All @@ -659,34 +661,19 @@ class Verifier : public ASTWalker {
// Otherwise, the archetype needs to be from this scope.
if (Generics.empty() || !Generics.back()) {
Out << "AST verification error: archetype outside of generic "
"context: " << root->getString() << "\n";
"context: " << archetype << "\n";
return true;
}

// Get the archetype's generic signature.
GenericEnvironment *archetypeEnv = root->getGenericEnvironment();
GenericEnvironment *archetypeEnv = archetype->getGenericEnvironment();
auto archetypeSig = archetypeEnv->getGenericSignature();

auto genericCtx = Generics.back();
GenericSignature genericSig;
if (auto *genericDC = genericCtx.dyn_cast<DeclContext *>()) {
genericSig = genericDC->getGenericSignatureOfContext();
} else {
auto *genericEnv = genericCtx.get<GenericEnvironment *>();
genericSig = genericEnv->getGenericSignature();

// Check whether this archetype is a substitution from the
// outer generic context of an opened element environment.
if (genericEnv->getKind() == GenericEnvironment::Kind::OpenedElement) {
auto contextSubs = genericEnv->getPackElementContextSubstitutions();
QuerySubstitutionMap isInContext{contextSubs};
if (isInContext(root->getInterfaceType()->castTo<GenericTypeParamType>()))
return false;
}
}
GenericSignature genericSig = genericCtx->getGenericSignatureOfContext();

if (genericSig.getPointer() != archetypeSig.getPointer()) {
Out << "Archetype " << root->getString() << " not allowed "
Out << "Archetype " << archetype->getString() << " not allowed "
<< "in this context\n";
Out << "Archetype generic signature: "
<< archetypeSig->getAsString() << "\n";
Expand Down Expand Up @@ -735,7 +722,7 @@ class Verifier : public ASTWalker {
}
void popScope(DeclContext *scope) {
assert(Scopes.back().get<DeclContext*>() == scope);
assert(Generics.back().get<DeclContext*>() == scope);
assert(Generics.back() == scope);
Scopes.pop_back();
Generics.pop_back();
}
Expand Down Expand Up @@ -808,6 +795,9 @@ class Verifier : public ASTWalker {
if (!shouldVerify(expansion)) {
return false;
}

assert(ForEachPatternSequences.count(expansion) == 0);
ForEachPatternSequences.insert(expansion);
}

if (!S->getElementExpr())
Expand All @@ -821,6 +811,10 @@ class Verifier : public ASTWalker {
void cleanup(ForEachStmt *S) {
if (auto *expansion =
dyn_cast<PackExpansionExpr>(S->getParsedSequence())) {
assert(ForEachPatternSequences.count(expansion) != 0);
ForEachPatternSequences.erase(expansion);

// Clean up for real.
cleanup(expansion);
}

Expand Down Expand Up @@ -851,6 +845,16 @@ class Verifier : public ASTWalker {
OpaqueValues.erase(expr->getInterpolationExpr());
}

void pushLocalGenerics(GenericEnvironment *env) {
assert(LocalGenerics.count(env)==0);
LocalGenerics.insert(env);
}

void popLocalGenerics(GenericEnvironment *env) {
assert(LocalGenerics.count(env)==1);
LocalGenerics.erase(env);
}

bool shouldVerify(OpenExistentialExpr *expr) {
if (!shouldVerify(cast<Expr>(expr)))
return false;
Expand All @@ -862,8 +866,8 @@ class Verifier : public ASTWalker {

assert(!OpaqueValues.count(expr->getOpaqueValue()));
OpaqueValues[expr->getOpaqueValue()] = 0;
assert(OpenedExistentialArchetypes.count(expr->getOpenedArchetype())==0);
OpenedExistentialArchetypes.insert(expr->getOpenedArchetype());

pushLocalGenerics(expr->getOpenedArchetype()->getGenericEnvironment());
return true;
}

Expand All @@ -875,22 +879,28 @@ class Verifier : public ASTWalker {

assert(OpaqueValues.count(expr->getOpaqueValue()));
OpaqueValues.erase(expr->getOpaqueValue());
assert(OpenedExistentialArchetypes.count(expr->getOpenedArchetype())==1);
OpenedExistentialArchetypes.erase(expr->getOpenedArchetype());

popLocalGenerics(expr->getOpenedArchetype()->getGenericEnvironment());
}

bool shouldVerify(PackExpansionExpr *expr) {
if (!shouldVerify(cast<Expr>(expr)))
return false;

Generics.push_back(expr->getGenericEnvironment());
// Don't push local generics again when we visit the expr inside
// the ForEachStmt.
if (auto *genericEnv = expr->getGenericEnvironment())
if (ForEachPatternSequences.count(expr) == 0)
pushLocalGenerics(genericEnv);
return true;
}

void cleanup(PackExpansionExpr *E) {
assert(Generics.back().get<GenericEnvironment *>() ==
E->getGenericEnvironment());
Generics.pop_back();
void cleanup(PackExpansionExpr *expr) {
// If this is a pack iteration pattern, don't pop local generics
// until we exit the ForEachStmt.
if (auto *genericEnv = expr->getGenericEnvironment())
if (ForEachPatternSequences.count(expr) == 0)
popLocalGenerics(genericEnv);
}

bool shouldVerify(MakeTemporarilyEscapableExpr *expr) {
Expand Down
Loading