Skip to content

Commit 4efc35a

Browse files
committed
[AST] Expand TypeJoin expression to support joining over a type
1 parent 7ae1be4 commit 4efc35a

File tree

6 files changed

+61
-15
lines changed

6 files changed

+61
-15
lines changed

include/swift/AST/Expr.h

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6006,11 +6006,24 @@ class TypeJoinExpr final : public Expr,
60066006
return { getTrailingObjects<Expr *>(), getNumElements() };
60076007
}
60086008

6009-
TypeJoinExpr(DeclRefExpr *var, ArrayRef<Expr *> elements);
6009+
TypeJoinExpr(llvm::PointerUnion<DeclRefExpr *, TypeBase *> result,
6010+
ArrayRef<Expr *> elements);
6011+
6012+
static TypeJoinExpr *
6013+
createImpl(ASTContext &ctx,
6014+
llvm::PointerUnion<DeclRefExpr *, TypeBase *> varOrType,
6015+
ArrayRef<Expr *> elements);
60106016

60116017
public:
60126018
static TypeJoinExpr *create(ASTContext &ctx, DeclRefExpr *var,
6013-
ArrayRef<Expr *> exprs);
6019+
ArrayRef<Expr *> exprs) {
6020+
return createImpl(ctx, var, exprs);
6021+
}
6022+
6023+
static TypeJoinExpr *create(ASTContext &ctx, Type joinType,
6024+
ArrayRef<Expr *> exprs) {
6025+
return createImpl(ctx, joinType.getPointer(), exprs);
6026+
}
60146027

60156028
SourceLoc getLoc() const { return SourceLoc(); }
60166029
SourceRange getSourceRange() const { return SourceRange(); }

lib/AST/ASTDumper.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2995,8 +2995,12 @@ class PrintExpr : public ExprVisitor<PrintExpr> {
29952995
void visitTypeJoinExpr(TypeJoinExpr *E) {
29962996
printCommon(E, "type_join_expr");
29972997

2998-
PrintWithColorRAII(OS, DeclColor) << " var=";
2999-
printRec(E->getVar());
2998+
if (auto *var = E->getVar()) {
2999+
PrintWithColorRAII(OS, DeclColor) << " var=";
3000+
printRec(var);
3001+
OS << '\n';
3002+
}
3003+
30003004
OS << '\n';
30013005

30023006
for (auto *member : E->getElements()) {

lib/AST/ASTWalker.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1260,10 +1260,12 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
12601260
}
12611261

12621262
Expr *visitTypeJoinExpr(TypeJoinExpr *E) {
1263-
if (auto *newVar = dyn_cast<DeclRefExpr>(doIt(E->getVar()))) {
1264-
E->setVar(newVar);
1265-
} else {
1266-
return nullptr;
1263+
if (auto *var = E->getVar()) {
1264+
if (auto *newVar = dyn_cast<DeclRefExpr>(doIt(var))) {
1265+
E->setVar(newVar);
1266+
} else {
1267+
return nullptr;
1268+
}
12671269
}
12681270

12691271
for (unsigned i = 0, e = E->getNumElements(); i != e; ++i) {

lib/AST/Expr.cpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2516,21 +2516,31 @@ RegexLiteralExpr::createParsed(ASTContext &ctx, SourceLoc loc,
25162516
/*implicit*/ false);
25172517
}
25182518

2519-
TypeJoinExpr::TypeJoinExpr(DeclRefExpr *varRef, ArrayRef<Expr *> elements)
2520-
: Expr(ExprKind::TypeJoin, /*implicit=*/true, Type()), Var(varRef) {
2521-
assert(Var);
2519+
TypeJoinExpr::TypeJoinExpr(llvm::PointerUnion<DeclRefExpr *, TypeBase *> result,
2520+
ArrayRef<Expr *> elements)
2521+
: Expr(ExprKind::TypeJoin, /*implicit=*/true, Type()), Var(nullptr) {
2522+
2523+
if (auto *varRef = result.dyn_cast<DeclRefExpr *>()) {
2524+
assert(varRef);
2525+
Var = varRef;
2526+
} else {
2527+
auto joinType = Type(result.get<TypeBase *>());
2528+
assert(joinType && "expected non-null type");
2529+
setType(joinType);
2530+
}
25222531

25232532
Bits.TypeJoinExpr.NumElements = elements.size();
25242533
// Copy elements.
25252534
std::uninitialized_copy(elements.begin(), elements.end(),
25262535
getTrailingObjects<Expr *>());
25272536
}
25282537

2529-
TypeJoinExpr *TypeJoinExpr::create(ASTContext &ctx, DeclRefExpr *var,
2530-
ArrayRef<Expr *> elements) {
2538+
TypeJoinExpr *TypeJoinExpr::createImpl(
2539+
ASTContext &ctx, llvm::PointerUnion<DeclRefExpr *, TypeBase *> varOrType,
2540+
ArrayRef<Expr *> elements) {
25312541
size_t size = totalSizeToAlloc<Expr *>(elements.size());
25322542
void *mem = ctx.Allocate(size, alignof(TypeJoinExpr));
2533-
return new (mem) TypeJoinExpr(var, elements);
2543+
return new (mem) TypeJoinExpr(varOrType, elements);
25342544
}
25352545

25362546
SourceRange MacroExpansionExpr::getSourceRange() const {

lib/Sema/CSGen.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3725,7 +3725,16 @@ namespace {
37253725
CS.getConstraintLocator(element));
37263726
}
37273727

3728-
auto resultTy = CS.getType(expr->getVar());
3728+
Type resultTy;
3729+
3730+
if (auto *var = expr->getVar()) {
3731+
resultTy = CS.getType(var);
3732+
} else {
3733+
resultTy = expr->getType();
3734+
}
3735+
3736+
assert(resultTy);
3737+
37293738
// The type of a join expression is obtained by performing
37303739
// a "join-meet" operation on deduced types of its elements
37313740
// and the underlying variable.

lib/Sema/CSSyntacticElement.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,14 @@ class TypeVariableRefFinder : public ASTWalker {
6161
ClosureDCs.push_back(closure);
6262
}
6363

64+
if (auto *joinExpr = dyn_cast<TypeJoinExpr>(expr)) {
65+
// If this join is over a known type, let's
66+
// analyze it too because it can contain type
67+
// variables.
68+
if (!joinExpr->getVar())
69+
inferVariables(joinExpr->getType());
70+
}
71+
6472
if (auto *DRE = dyn_cast<DeclRefExpr>(expr)) {
6573
auto *decl = DRE->getDecl();
6674

0 commit comments

Comments
 (0)