Skip to content

ConstExtract: Refactor handling of AvailabilitySpec #79449

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions include/swift/AST/ConstTypeInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,12 +217,24 @@ class BuilderValue : public CompileTimeValue {
///
class ConditionalMember : public BuilderMember {
public:
class AvailabilitySpec {
private:
AvailabilityDomain Domain;
llvm::VersionTuple Version;

public:
AvailabilitySpec(AvailabilityDomain Domain, llvm::VersionTuple Version)
: Domain(Domain), Version(Version) {}

AvailabilityDomain getDomain() const { return Domain; }
llvm::VersionTuple getVersion() const { return Version; }
};

ConditionalMember(MemberKind MemberKind,
std::vector<AvailabilitySpec> AvailabilityAttributes,
std::vector<AvailabilitySpec> AvailabilitySpecs,
std::vector<std::shared_ptr<BuilderMember>> IfElements,
std::vector<std::shared_ptr<BuilderMember>> ElseElements)
: BuilderMember(MemberKind),
AvailabilityAttributes(AvailabilityAttributes),
: BuilderMember(MemberKind), AvailabilitySpecs(AvailabilitySpecs),
IfElements(IfElements), ElseElements(ElseElements) {}

ConditionalMember(MemberKind MemberKind,
Expand All @@ -238,9 +250,8 @@ class BuilderValue : public CompileTimeValue {
(Kind == MemberKind::Optional);
}

std::optional<std::vector<AvailabilitySpec>>
getAvailabilityAttributes() const {
return AvailabilityAttributes;
std::optional<std::vector<AvailabilitySpec>> getAvailabilitySpecs() const {
return AvailabilitySpecs;
}
std::vector<std::shared_ptr<BuilderMember>> getIfElements() const {
return IfElements;
Expand All @@ -250,7 +261,7 @@ class BuilderValue : public CompileTimeValue {
}

private:
std::optional<std::vector<AvailabilitySpec>> AvailabilityAttributes;
std::optional<std::vector<AvailabilitySpec>> AvailabilitySpecs;
std::vector<std::shared_ptr<BuilderMember>> IfElements;
std::vector<std::shared_ptr<BuilderMember>> ElseElements;
};
Expand Down
120 changes: 72 additions & 48 deletions lib/ConstExtract/ConstExtract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,12 +168,15 @@ parseProtocolListFromFile(StringRef protocolListFilePath,
}

std::vector<std::shared_ptr<BuilderValue::BuilderMember>>
getResultBuilderMembersFromBraceStmt(BraceStmt *braceStmt);
getResultBuilderMembersFromBraceStmt(BraceStmt *braceStmt,
const DeclContext *declContext);

static std::shared_ptr<CompileTimeValue> extractCompileTimeValue(Expr *expr);
static std::shared_ptr<CompileTimeValue>
extractCompileTimeValue(Expr *expr, const DeclContext *declContext);

static std::vector<FunctionParameter>
extractFunctionArguments(const ArgumentList *args) {
extractFunctionArguments(const ArgumentList *args,
const DeclContext *declContext) {
std::vector<FunctionParameter> parameters;

for (auto arg : *args) {
Expand All @@ -188,7 +191,8 @@ extractFunctionArguments(const ArgumentList *args) {
} else if (auto optionalInject = dyn_cast<InjectIntoOptionalExpr>(argExpr)) {
argExpr = optionalInject->getSubExpr();
}
parameters.push_back({label, type, extractCompileTimeValue(argExpr)});
parameters.push_back(
{label, type, extractCompileTimeValue(argExpr, declContext)});
}

return parameters;
Expand Down Expand Up @@ -224,7 +228,8 @@ static std::optional<std::string> extractRawLiteral(Expr *expr) {
return std::nullopt;
}

static std::shared_ptr<CompileTimeValue> extractCompileTimeValue(Expr *expr) {
static std::shared_ptr<CompileTimeValue>
extractCompileTimeValue(Expr *expr, const DeclContext *declContext) {
if (expr) {
switch (expr->getKind()) {
case ExprKind::BooleanLiteral:
Expand All @@ -247,7 +252,8 @@ static std::shared_ptr<CompileTimeValue> extractCompileTimeValue(Expr *expr) {
auto arrayExpr = cast<ArrayExpr>(expr);
std::vector<std::shared_ptr<CompileTimeValue>> elementValues;
for (const auto elementExpr : arrayExpr->getElements()) {
elementValues.push_back(extractCompileTimeValue(elementExpr));
elementValues.push_back(
extractCompileTimeValue(elementExpr, declContext));
}
return std::make_shared<ArrayValue>(elementValues);
}
Expand All @@ -256,7 +262,7 @@ static std::shared_ptr<CompileTimeValue> extractCompileTimeValue(Expr *expr) {
auto dictionaryExpr = cast<DictionaryExpr>(expr);
std::vector<std::shared_ptr<TupleValue>> tuples;
for (auto elementExpr : dictionaryExpr->getElements()) {
auto elementValue = extractCompileTimeValue(elementExpr);
auto elementValue = extractCompileTimeValue(elementExpr, declContext);
if (isa<TupleValue>(elementValue.get())) {
tuples.push_back(std::static_pointer_cast<TupleValue>(elementValue));
}
Expand All @@ -279,13 +285,15 @@ static std::shared_ptr<CompileTimeValue> extractCompileTimeValue(Expr *expr) {
? std::nullopt
: std::optional<std::string>(elementName.str().str());

elements.push_back({label, elementExpr->getType(),
extractCompileTimeValue(elementExpr)});
elements.push_back(
{label, elementExpr->getType(),
extractCompileTimeValue(elementExpr, declContext)});
}
} else {
for (auto elementExpr : tupleExpr->getElements()) {
elements.push_back({std::nullopt, elementExpr->getType(),
extractCompileTimeValue(elementExpr)});
elements.push_back(
{std::nullopt, elementExpr->getType(),
extractCompileTimeValue(elementExpr, declContext)});
}
}
return std::make_shared<TupleValue>(elements);
Expand All @@ -301,13 +309,13 @@ static std::shared_ptr<CompileTimeValue> extractCompileTimeValue(Expr *expr) {
declRefExpr->getDecl()->getName().getBaseIdentifier().str().str();

std::vector<FunctionParameter> parameters =
extractFunctionArguments(callExpr->getArgs());
extractFunctionArguments(callExpr->getArgs(), declContext);
return std::make_shared<FunctionCallValue>(identifier, parameters);
}

if (functionKind == ExprKind::ConstructorRefCall) {
std::vector<FunctionParameter> parameters =
extractFunctionArguments(callExpr->getArgs());
extractFunctionArguments(callExpr->getArgs(), declContext);
return std::make_shared<InitCallValue>(callExpr->getType(), parameters);
}

Expand All @@ -320,7 +328,7 @@ static std::shared_ptr<CompileTimeValue> extractCompileTimeValue(Expr *expr) {
declRefExpr->getDecl()->getName().getBaseIdentifier().str().str();

std::vector<FunctionParameter> parameters =
extractFunctionArguments(callExpr->getArgs());
extractFunctionArguments(callExpr->getArgs(), declContext);

auto declRef = dotSyntaxCallExpr->getFn()->getReferencedDecl();
switch (declRef.getDecl()->getKind()) {
Expand Down Expand Up @@ -364,23 +372,23 @@ static std::shared_ptr<CompileTimeValue> extractCompileTimeValue(Expr *expr) {

case ExprKind::Erasure: {
auto erasureExpr = cast<ErasureExpr>(expr);
return extractCompileTimeValue(erasureExpr->getSubExpr());
return extractCompileTimeValue(erasureExpr->getSubExpr(), declContext);
}

case ExprKind::Paren: {
auto parenExpr = cast<ParenExpr>(expr);
return extractCompileTimeValue(parenExpr->getSubExpr());
return extractCompileTimeValue(parenExpr->getSubExpr(), declContext);
}

case ExprKind::PropertyWrapperValuePlaceholder: {
auto placeholderExpr = cast<PropertyWrapperValuePlaceholderExpr>(expr);
return extractCompileTimeValue(
placeholderExpr->getOriginalWrappedValue());
return extractCompileTimeValue(placeholderExpr->getOriginalWrappedValue(),
declContext);
}

case ExprKind::Coerce: {
auto coerceExpr = cast<CoerceExpr>(expr);
return extractCompileTimeValue(coerceExpr->getSubExpr());
return extractCompileTimeValue(coerceExpr->getSubExpr(), declContext);
}

case ExprKind::DotSelf: {
Expand All @@ -394,7 +402,8 @@ static std::shared_ptr<CompileTimeValue> extractCompileTimeValue(Expr *expr) {

case ExprKind::UnderlyingToOpaque: {
auto underlyingToOpaque = cast<UnderlyingToOpaqueExpr>(expr);
return extractCompileTimeValue(underlyingToOpaque->getSubExpr());
return extractCompileTimeValue(underlyingToOpaque->getSubExpr(),
declContext);
}

case ExprKind::DefaultArgument: {
Expand Down Expand Up @@ -445,12 +454,13 @@ static std::shared_ptr<CompileTimeValue> extractCompileTimeValue(Expr *expr) {

case ExprKind::InjectIntoOptional: {
auto injectIntoOptionalExpr = cast<InjectIntoOptionalExpr>(expr);
return extractCompileTimeValue(injectIntoOptionalExpr->getSubExpr());
return extractCompileTimeValue(injectIntoOptionalExpr->getSubExpr(),
declContext);
}

case ExprKind::Load: {
auto loadExpr = cast<LoadExpr>(expr);
return extractCompileTimeValue(loadExpr->getSubExpr());
return extractCompileTimeValue(loadExpr->getSubExpr(), declContext);
}

case ExprKind::MemberRef: {
Expand All @@ -474,7 +484,7 @@ static std::shared_ptr<CompileTimeValue> extractCompileTimeValue(Expr *expr) {
Ctx, [&](bool isInterpolation, CallExpr *segment) -> void {
auto arg = segment->getArgs()->get(0);
auto expr = arg.getExpr();
segments.push_back(extractCompileTimeValue(expr));
segments.push_back(extractCompileTimeValue(expr, declContext));
});

return std::make_shared<InterpolatedStringLiteralValue>(segments);
Expand All @@ -483,7 +493,8 @@ static std::shared_ptr<CompileTimeValue> extractCompileTimeValue(Expr *expr) {
case ExprKind::Closure: {
auto closureExpr = cast<ClosureExpr>(expr);
auto body = closureExpr->getBody();
auto resultBuilderMembers = getResultBuilderMembersFromBraceStmt(body);
auto resultBuilderMembers =
getResultBuilderMembersFromBraceStmt(body, declContext);

if (!resultBuilderMembers.empty()) {
return std::make_shared<BuilderValue>(resultBuilderMembers);
Expand All @@ -493,7 +504,7 @@ static std::shared_ptr<CompileTimeValue> extractCompileTimeValue(Expr *expr) {

case ExprKind::DerivedToBase: {
auto derivedExpr = cast<DerivedToBaseExpr>(expr);
return extractCompileTimeValue(derivedExpr->getSubExpr());
return extractCompileTimeValue(derivedExpr->getSubExpr(), declContext);
}
default: {
break;
Expand All @@ -504,8 +515,8 @@ static std::shared_ptr<CompileTimeValue> extractCompileTimeValue(Expr *expr) {
return std::make_shared<RuntimeValue>();
}

static CustomAttrValue
extractAttributeValue(const CustomAttr *attr) {
static CustomAttrValue extractAttributeValue(const CustomAttr *attr,
const DeclContext *declContext) {
std::vector<FunctionParameter> parameters;
if (const auto *args = attr->getArgs()) {
for (auto arg : *args) {
Expand All @@ -518,8 +529,8 @@ extractAttributeValue(const CustomAttr *attr) {
argExpr = decl->getTypeCheckedDefaultExpr();
}
}
parameters.push_back(
{label, argExpr->getType(), extractCompileTimeValue(argExpr)});
parameters.push_back({label, argExpr->getType(),
extractCompileTimeValue(argExpr, declContext)});
}
}
return {attr, parameters};
Expand All @@ -529,7 +540,8 @@ static AttrValueVector
extractPropertyWrapperAttrValues(VarDecl *propertyDecl) {
AttrValueVector customAttrValues;
for (auto *propertyWrapper : propertyDecl->getAttachedPropertyWrappers())
customAttrValues.push_back(extractAttributeValue(propertyWrapper));
customAttrValues.push_back(
extractAttributeValue(propertyWrapper, propertyDecl->getDeclContext()));
return customAttrValues;
}

Expand All @@ -541,7 +553,9 @@ extractTypePropertyInfo(VarDecl *propertyDecl) {

if (const auto binding = propertyDecl->getParentPatternBinding()) {
if (const auto originalInit = binding->getInit(0)) {
return {propertyDecl, extractCompileTimeValue(originalInit),
return {propertyDecl,
extractCompileTimeValue(originalInit,
propertyDecl->getInnermostDeclContext()),
propertyWrapperValues};
}
}
Expand All @@ -551,9 +565,11 @@ extractTypePropertyInfo(VarDecl *propertyDecl) {
auto node = body->getFirstElement();
if (auto *stmt = node.dyn_cast<Stmt *>()) {
if (stmt->getKind() == StmtKind::Return) {
return {propertyDecl,
extractCompileTimeValue(cast<ReturnStmt>(stmt)->getResult()),
propertyWrapperValues};
return {
propertyDecl,
extractCompileTimeValue(cast<ReturnStmt>(stmt)->getResult(),
accessorDecl->getInnermostDeclContext()),
propertyWrapperValues};
}
}
}
Expand Down Expand Up @@ -992,16 +1008,19 @@ getResultBuilderElementFromASTNode(const ASTNode node) {
if (auto *D = node.dyn_cast<Decl *>()) {
if (auto *patternBinding = dyn_cast<PatternBindingDecl>(D)) {
if (auto originalInit = patternBinding->getOriginalInit(0)) {
return extractCompileTimeValue(originalInit);
return extractCompileTimeValue(
originalInit, patternBinding->getInnermostDeclContext());
}
}
}
return std::nullopt;
}

BuilderValue::ConditionalMember
getConditionalMemberFromIfStmt(const IfStmt *ifStmt) {
std::vector<AvailabilitySpec> AvailabilityAttributes;
getConditionalMemberFromIfStmt(const IfStmt *ifStmt,
const DeclContext *declContext) {
std::vector<BuilderValue::ConditionalMember::AvailabilitySpec>
AvailabilitySpecs;
std::vector<std::shared_ptr<BuilderValue::BuilderMember>> IfElements;
std::vector<std::shared_ptr<BuilderValue::BuilderMember>> ElseElements;
if (auto thenBraceStmt = ifStmt->getThenStmt()) {
Expand All @@ -1016,7 +1035,7 @@ getConditionalMemberFromIfStmt(const IfStmt *ifStmt) {
if (auto elseStmt = ifStmt->getElseStmt()) {
if (auto *elseIfStmt = dyn_cast<IfStmt>(elseStmt)) {
ElseElements.push_back(std::make_shared<BuilderValue::ConditionalMember>(
getConditionalMemberFromIfStmt(elseIfStmt)));
getConditionalMemberFromIfStmt(elseIfStmt, declContext)));
} else if (auto *elseBraceStmt = dyn_cast<BraceStmt>(elseStmt)) {
for (auto elem : elseBraceStmt->getElements()) {
if (auto memberElement = getResultBuilderElementFromASTNode(elem)) {
Expand All @@ -1035,20 +1054,22 @@ getConditionalMemberFromIfStmt(const IfStmt *ifStmt) {
if (elt.getKind() == StmtConditionElement::CK_Availability) {
for (auto *Q : elt.getAvailability()->getQueries()) {
if (Q->getPlatform() != PlatformKind::none) {
AvailabilityAttributes.push_back(*Q);
auto spec = BuilderValue::ConditionalMember::AvailabilitySpec(
*Q->getDomain(), Q->getVersion());
AvailabilitySpecs.push_back(spec);
}
}
memberKind = BuilderValue::LimitedAvailability;
break;
}
}

if (AvailabilityAttributes.empty()) {
if (AvailabilitySpecs.empty()) {
return BuilderValue::ConditionalMember(memberKind, IfElements,
ElseElements);
}

return BuilderValue::ConditionalMember(memberKind, AvailabilityAttributes,
return BuilderValue::ConditionalMember(memberKind, AvailabilitySpecs,
IfElements, ElseElements);
}

Expand All @@ -1067,7 +1088,8 @@ getBuildArrayMemberFromForEachStmt(const ForEachStmt *forEachStmt) {
}

std::vector<std::shared_ptr<BuilderValue::BuilderMember>>
getResultBuilderMembersFromBraceStmt(BraceStmt *braceStmt) {
getResultBuilderMembersFromBraceStmt(BraceStmt *braceStmt,
const DeclContext *declContext) {
std::vector<std::shared_ptr<BuilderValue::BuilderMember>>
ResultBuilderMembers;
for (auto elem : braceStmt->getElements()) {
Expand All @@ -1079,7 +1101,7 @@ getResultBuilderMembersFromBraceStmt(BraceStmt *braceStmt) {
if (auto *ifStmt = dyn_cast<IfStmt>(stmt)) {
ResultBuilderMembers.push_back(
std::make_shared<BuilderValue::ConditionalMember>(
getConditionalMemberFromIfStmt(ifStmt)));
getConditionalMemberFromIfStmt(ifStmt, declContext)));
} else if (auto *doStmt = dyn_cast<DoStmt>(stmt)) {
if (auto body = doStmt->getBody()) {
for (auto elem : body->getElements()) {
Expand All @@ -1106,7 +1128,8 @@ createBuilderCompileTimeValue(CustomAttr *AttachedResultBuilder,
if (!VarDecl->getAllAccessors().empty()) {
if (auto accessor = VarDecl->getAllAccessors()[0]) {
if (auto braceStmt = accessor->getTypecheckedBody()) {
ResultBuilderMembers = getResultBuilderMembersFromBraceStmt(braceStmt);
ResultBuilderMembers = getResultBuilderMembersFromBraceStmt(
braceStmt, accessor->getDeclContext());
}
}
}
Expand Down Expand Up @@ -1159,12 +1182,13 @@ void writeBuilderMember(

default: {
auto member = cast<BuilderValue::ConditionalMember>(Member);
if (auto availabilityAttributes = member->getAvailabilityAttributes()) {
if (auto availabilitySpecs = member->getAvailabilitySpecs()) {
JSON.attributeArray("availabilityAttributes", [&] {
for (auto elem : *availabilityAttributes) {
for (auto elem : *availabilitySpecs) {
JSON.object([&] {
JSON.attribute("platform",
platformString(elem.getPlatform()).str());
JSON.attribute(
"platform",
platformString(elem.getDomain().getPlatformKind()).str());
JSON.attribute("minVersion", elem.getVersion().getAsString());
});
}
Expand Down