Skip to content

[6.0] Implement pack element reference captures #74039

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
45 changes: 21 additions & 24 deletions include/swift/AST/CaptureInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,12 @@ template <> struct DenseMapInfo<swift::CapturedValue>;
namespace swift {
class ValueDecl;
class FuncDecl;
class Expr;
class OpaqueValueExpr;
class PackElementExpr;
class VarDecl;
class GenericEnvironment;
class Type;

/// CapturedValue includes both the declaration being captured, along with flags
/// that indicate how it is captured.
Expand All @@ -52,7 +55,7 @@ class CapturedValue {

public:
using Storage =
llvm::PointerIntPair<llvm::PointerUnion<ValueDecl*, OpaqueValueExpr*>, 2,
llvm::PointerIntPair<llvm::PointerUnion<ValueDecl *, Expr *>, 2,
unsigned>;

private:
Expand All @@ -78,15 +81,7 @@ class CapturedValue {
CapturedValue(ValueDecl *Val, unsigned Flags, SourceLoc Loc)
: Value(Val, Flags), Loc(Loc) {}

private:
// This is only used in TypeLowering when forming Lowered Capture
// Info. OpaqueValueExpr captured value should never show up in the AST
// itself.
//
// NOTE: AbstractClosureExpr::getIsolationCrossing relies upon this and
// asserts that it never sees one of these.
explicit CapturedValue(OpaqueValueExpr *Val, unsigned Flags)
: Value(Val, Flags), Loc(SourceLoc()) {}
CapturedValue(Expr *Val, unsigned Flags);

public:
static CapturedValue getDynamicSelfMetadata() {
Expand All @@ -97,36 +92,38 @@ class CapturedValue {
bool isNoEscape() const { return Value.getInt() & IsNoEscape; }

bool isDynamicSelfMetadata() const { return !Value.getPointer(); }
bool isOpaqueValue() const {
return Value.getPointer().is<OpaqueValueExpr *>();

bool isExpr() const {
return Value.getPointer().dyn_cast<Expr *>();
}

bool isPackElement() const;
bool isOpaqueValue() const;

/// Returns true if this captured value is a local capture.
///
/// NOTE: This implies that the value is not dynamic self metadata, since
/// values with decls are the only values that are able to be local captures.
bool isLocalCapture() const;

CapturedValue mergeFlags(CapturedValue cv) {
assert(Value.getPointer() == cv.Value.getPointer() &&
"merging flags on two different value decls");
return CapturedValue(
Storage(Value.getPointer(), getFlags() & cv.getFlags()),
Loc);
CapturedValue mergeFlags(unsigned flags) const {
return CapturedValue(Storage(Value.getPointer(), getFlags() & flags), Loc);
}

ValueDecl *getDecl() const {
assert(Value.getPointer() && "dynamic Self metadata capture does not "
"have a value");
return Value.getPointer().dyn_cast<ValueDecl *>();
}

OpaqueValueExpr *getOpaqueValue() const {
assert(Value.getPointer() && "dynamic Self metadata capture does not "
"have a value");
return Value.getPointer().dyn_cast<OpaqueValueExpr *>();
Expr *getExpr() const {
return Value.getPointer().dyn_cast<Expr *>();
}

OpaqueValueExpr *getOpaqueValue() const;

PackElementExpr *getPackElement() const;

Type getPackElementType() const;

SourceLoc getLoc() const { return Loc; }

unsigned getFlags() const { return Value.getInt(); }
Expand Down
19 changes: 19 additions & 0 deletions include/swift/Sema/ConstraintSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -2237,6 +2237,10 @@ class ConstraintSystem {
/// from declared parameters/result and body.
llvm::MapVector<const ClosureExpr *, FunctionType *> ClosureTypes;

/// Maps closures and local functions to the pack expansion expressions they
/// capture.
llvm::MapVector<AnyFunctionRef, SmallVector<PackExpansionExpr *, 1>> CapturedExpansions;

/// Maps expressions for implied results (e.g implicit 'then' statements,
/// implicit 'return' statements in single expression body closures) to their
/// result kind.
Expand Down Expand Up @@ -3172,6 +3176,19 @@ class ConstraintSystem {
return nullptr;
}

SmallVector<PackExpansionExpr *, 1> getCapturedExpansions(AnyFunctionRef func) const {
auto result = CapturedExpansions.find(func);
if (result == CapturedExpansions.end())
return {};

return result->second;
}

void setCapturedExpansions(AnyFunctionRef func, SmallVector<PackExpansionExpr *, 1> exprs) {
assert(CapturedExpansions.count(func) == 0 && "Cannot reset captured expansions");
CapturedExpansions.insert({func, exprs});
}

TypeVariableType *getKeyPathValueType(const KeyPathExpr *keyPath) const {
auto result = getKeyPathValueTypeIfAvailable(keyPath);
assert(result);
Expand Down Expand Up @@ -6439,6 +6456,7 @@ class ConjunctionElementProducer : public BindingProducer<ConjunctionElement> {
///
/// This includes:
/// - Not yet resolved outer VarDecls (including closure parameters)
/// - Outer pack expansions that are not yet fully resolved
/// - Return statements with a contextual type that has not yet been resolved
///
/// This is required because isolated conjunctions, just like single-expression
Expand All @@ -6460,6 +6478,7 @@ class TypeVarRefCollector : public ASTWalker {

/// Infer the referenced type variables from a given decl.
void inferTypeVars(Decl *D);
void inferTypeVars(PackExpansionExpr *);

MacroWalking getMacroWalkingBehavior() const override {
return MacroWalking::Arguments;
Expand Down
38 changes: 37 additions & 1 deletion lib/AST/CaptureInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,36 @@
#include "swift/AST/CaptureInfo.h"
#include "swift/AST/ASTContext.h"
#include "swift/AST/Decl.h"
#include "swift/AST/Expr.h"
#include "swift/AST/GenericEnvironment.h"
#include "llvm/Support/raw_ostream.h"

using namespace swift;

CapturedValue::CapturedValue(Expr *Val, unsigned Flags)
: Value(Val, Flags), Loc(SourceLoc()) {
assert(isa<OpaqueValueExpr>(Val) || isa<PackElementExpr>(Val));
}

bool CapturedValue::isPackElement() const {
return isExpr() && isa<PackElementExpr>(getExpr());
}
bool CapturedValue::isOpaqueValue() const {
return isExpr() && isa<OpaqueValueExpr>(getExpr());
}

OpaqueValueExpr *CapturedValue::getOpaqueValue() const {
return dyn_cast_or_null<OpaqueValueExpr>(getExpr());
}

PackElementExpr *CapturedValue::getPackElement() const {
return dyn_cast_or_null<PackElementExpr>(getExpr());
}

Type CapturedValue::getPackElementType() const {
return getPackElement()->getType();
}

ArrayRef<CapturedValue>
CaptureInfo::CaptureInfoStorage::getCaptures() const {
return llvm::ArrayRef(this->getTrailingObjects<CapturedValue>(), NumCapturedValues);
Expand Down Expand Up @@ -153,7 +178,18 @@ void CaptureInfo::print(raw_ostream &OS) const {

interleave(getCaptures(),
[&](const CapturedValue &capture) {
OS << capture.getDecl()->getBaseName();
if (capture.getDecl())
OS << capture.getDecl()->getBaseName();
else if (capture.isPackElement()) {
OS << "[pack element] ";
capture.getPackElement()->dump(OS);
} else if (capture.isOpaqueValue()) {
OS << "[opaque] ";
capture.getOpaqueValue()->dump(OS);
} else {
OS << "[unknown] ";
assert(false);
}

if (capture.isDirect())
OS << "<direct>";
Expand Down
48 changes: 29 additions & 19 deletions lib/SIL/IR/SILFunctionType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1970,33 +1970,39 @@ lowerCaptureContextParameters(TypeConverter &TC, SILDeclRef function,
continue;
}

auto *varDecl = cast<VarDecl>(capture.getDecl());
auto options = SILParameterInfo::Options();

Type type;
VarDecl *varDecl = nullptr;
if (auto *expr = capture.getPackElement()) {
type = expr->getType();
} else {
varDecl = cast<VarDecl>(capture.getDecl());
type = varDecl->getTypeInContext();

// If we're capturing a parameter pack, wrap it in a tuple.
if (type->is<PackExpansionType>()) {
assert(!cast<ParamDecl>(varDecl)->supportsMutation() &&
"Cannot capture a pack as an lvalue");

SmallVector<TupleTypeElt, 1> elts;
elts.push_back(type);
type = TupleType::get(elts, TC.Context);
}

if (isolatedParam == varDecl) {
options |= SILParameterInfo::Isolated;
isolatedParam = nullptr;
}
}

auto type = varDecl->getTypeInContext();
assert(!type->hasLocalArchetype() ||
(genericSig && origGenericSig &&
!genericSig->isEqual(origGenericSig)));
type = mapTypeOutOfContext(type);

auto canType = type->getReducedType(
genericSig ? genericSig : origGenericSig);

auto options = SILParameterInfo::Options();
if (isolatedParam == varDecl) {
options |= SILParameterInfo::Isolated;
isolatedParam = nullptr;
}

// If we're capturing a parameter pack, wrap it in a tuple.
if (isa<PackExpansionType>(canType)) {
assert(!cast<ParamDecl>(varDecl)->supportsMutation() &&
"Cannot capture a pack as an lvalue");

SmallVector<TupleTypeElt, 1> elts;
elts.push_back(canType);
canType = CanTupleType(TupleType::get(elts, TC.Context));
}

auto &loweredTL =
TC.getTypeLowering(AbstractionPattern(genericSig, canType), canType,
expansion);
Expand All @@ -2018,6 +2024,8 @@ lowerCaptureContextParameters(TypeConverter &TC, SILDeclRef function,
break;
}
case CaptureKind::Box: {
assert(varDecl);

// The type in the box is lowered in the minimal context.
auto minimalLoweredTy =
TC.getTypeLowering(AbstractionPattern(genericSig, canType), canType,
Expand All @@ -2035,6 +2043,8 @@ lowerCaptureContextParameters(TypeConverter &TC, SILDeclRef function,
break;
}
case CaptureKind::ImmutableBox: {
assert(varDecl);

// The type in the box is lowered in the minimal context.
auto minimalLoweredTy =
TC.getTypeLowering(AbstractionPattern(genericSig, canType), canType,
Expand Down
69 changes: 54 additions & 15 deletions lib/SIL/IR/TypeLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,19 @@ static bool hasSingletonMetatype(CanType instanceType) {

CaptureKind TypeConverter::getDeclCaptureKind(CapturedValue capture,
TypeExpansionContext expansion) {
if (auto *expr = capture.getPackElement()) {
auto contextTy = expr->getType();
auto &lowering = getTypeLowering(
contextTy, TypeExpansionContext::noOpaqueTypeArchetypesSubstitution(
expansion.getResilienceExpansion()));

assert(!contextTy->isNoncopyable() && "Not implemented");
if (!lowering.isAddressOnly())
return CaptureKind::Constant;

return CaptureKind::Immutable;
}

auto decl = capture.getDecl();
auto *var = cast<VarDecl>(decl);
assert(var->hasStorage() &&
Expand Down Expand Up @@ -4219,7 +4232,10 @@ TypeConverter::getLoweredLocalCaptures(SILDeclRef fn) {

// Recursively collect transitive captures from captured local functions.
llvm::DenseSet<AnyFunctionRef> visitedFunctions;
llvm::MapVector<ValueDecl*,CapturedValue> captures;

// FIXME: CapturedValue should just be a hash key
llvm::MapVector<VarDecl *, CapturedValue> varCaptures;
llvm::MapVector<PackElementExpr *, CapturedValue> packElementCaptures;

// If there is a capture of 'self' with dynamic 'Self' type, it goes last so
// that IRGen can pass dynamic 'Self' metadata.
Expand All @@ -4236,9 +4252,29 @@ TypeConverter::getLoweredLocalCaptures(SILDeclRef fn) {
std::function<void (AnyFunctionRef)> collectFunctionCaptures;
std::function<void (SILDeclRef)> collectConstantCaptures;

auto recordCapture = [&](CapturedValue capture) {
if (auto *expr = capture.getPackElement()) {
auto existing = packElementCaptures.find(expr);
if (existing != packElementCaptures.end()) {
existing->second = existing->second.mergeFlags(capture.getFlags());
} else {
packElementCaptures.insert(std::pair<PackElementExpr *, CapturedValue>(
expr, capture));
}
} else {
VarDecl *value = cast<VarDecl>(capture.getDecl());
auto existing = varCaptures.find(value);
if (existing != varCaptures.end()) {
existing->second = existing->second.mergeFlags(capture.getFlags());
} else {
varCaptures.insert(std::pair<VarDecl *, CapturedValue>(
value, capture));
}
}
};

collectCaptures = [&](CaptureInfo captureInfo, DeclContext *dc) {
assert(captureInfo.hasBeenComputed());

if (captureInfo.hasGenericParamCaptures())
capturesGenericParams = true;
if (captureInfo.hasDynamicSelfCapture())
Expand All @@ -4253,9 +4289,15 @@ TypeConverter::getLoweredLocalCaptures(SILDeclRef fn) {
genericEnv.insert(env);
}

SmallVector<CapturedValue, 4> localCaptures;
captureInfo.getLocalCaptures(localCaptures);
for (auto capture : localCaptures) {
for (auto capture : captureInfo.getCaptures()) {
if (capture.isPackElement()) {
recordCapture(capture);
continue;
}

if (!capture.isLocalCapture())
continue;

// If the capture is of another local function, grab its transitive
// captures instead.
if (auto capturedFn = getAnyFunctionRefFromCapture(capture)) {
Expand Down Expand Up @@ -4367,7 +4409,7 @@ TypeConverter::getLoweredLocalCaptures(SILDeclRef fn) {
// If we've already captured the same value already, just merge
// flags.
if (selfCapture && selfCapture->getDecl() == capture.getDecl()) {
selfCapture = selfCapture->mergeFlags(capture);
selfCapture = selfCapture->mergeFlags(capture.getFlags());
continue;

// Otherwise, record the canonical self capture. It will appear
Expand All @@ -4387,13 +4429,7 @@ TypeConverter::getLoweredLocalCaptures(SILDeclRef fn) {
}

// Collect non-function captures.
ValueDecl *value = capture.getDecl();
auto existing = captures.find(value);
if (existing != captures.end()) {
existing->second = existing->second.mergeFlags(capture);
} else {
captures.insert(std::pair<ValueDecl *, CapturedValue>(value, capture));
}
recordCapture(capture);
}
};

Expand Down Expand Up @@ -4450,7 +4486,10 @@ TypeConverter::getLoweredLocalCaptures(SILDeclRef fn) {
collectConstantCaptures(fn);

SmallVector<CapturedValue, 4> resultingCaptures;
for (auto capturePair : captures) {
for (auto capturePair : varCaptures) {
resultingCaptures.push_back(capturePair.second);
}
for (auto capturePair : packElementCaptures) {
resultingCaptures.push_back(capturePair.second);
}

Expand All @@ -4469,7 +4508,7 @@ TypeConverter::getLoweredLocalCaptures(SILDeclRef fn) {
resultingCaptures.push_back(*selfCapture);
}

// Cache the uniqued set of transitive captures.
// Cache the result.
CaptureInfo info(Context, resultingCaptures,
capturesDynamicSelf, capturesOpaqueValue,
capturesGenericParams, genericEnv.getArrayRef());
Expand Down
Loading