Skip to content

[Constraint solver] Introduce one-way binding constraints. #25983

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 14, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions include/swift/AST/Expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -5321,6 +5321,33 @@ class KeyPathDotExpr : public Expr {
}
};

/// Expression node that effects a "one-way" constraint in
/// the constraint system, allowing type information to flow from the
/// subexpression outward but not the other way.
///
/// One-way expressions are generally implicit and synthetic, introduced by
/// the type checker. However, there is a built-in expression of the
/// form \c Builtin.one_way(x) that forms a one-way constraint coming out
/// of expression `x` that can be used for testing purposes.
class OneWayExpr : public Expr {
Expr *SubExpr;

public:
/// Construct an implicit one-way expression from the given subexpression.
OneWayExpr(Expr *subExpr)
: Expr(ExprKind::OneWay, /*isImplicit=*/true), SubExpr(subExpr) { }

SourceLoc getLoc() const { return SubExpr->getLoc(); }
SourceRange getSourceRange() const { return SubExpr->getSourceRange(); }

Expr *getSubExpr() const { return SubExpr; }
void setSubExpr(Expr *subExpr) { SubExpr = subExpr; }

static bool classof(const Expr *E) {
return E->getKind() == ExprKind::OneWay;
}
};

inline bool Expr::isInfixOperator() const {
return isa<BinaryExpr>(this) || isa<IfExpr>(this) ||
isa<AssignExpr>(this) || isa<ExplicitCastExpr>(this);
Expand Down
1 change: 1 addition & 0 deletions include/swift/AST/ExprNodes.def
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ EXPR(EditorPlaceholder, Expr)
EXPR(ObjCSelector, Expr)
EXPR(KeyPath, Expr)
UNCHECKED_EXPR(KeyPathDot, Expr)
UNCHECKED_EXPR(OneWay, Expr)
EXPR(Tap, Expr)
LAST_EXPR(Tap)

Expand Down
3 changes: 3 additions & 0 deletions include/swift/Basic/LangOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,9 @@ namespace swift {
/// before termination of the shrink phrase of the constraint solver.
unsigned SolverShrinkUnsolvedThreshold = 10;

/// Enable one-way constraints in function builders.
bool FunctionBuilderOneWayConstraints = false;

/// Disable the shrink phase of the expression type checker.
bool SolverDisableShrink = false;

Expand Down
4 changes: 4 additions & 0 deletions include/swift/Option/FrontendOptions.td
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,10 @@ def Rmodule_interface_rebuild : Flag<["-"], "Rmodule-interface-rebuild">,

def solver_expression_time_threshold_EQ : Joined<["-"], "solver-expression-time-threshold=">;

def enable_function_builder_one_way_constraints : Flag<["-"],
"enable-function-builder-one-way-constraints">,
HelpText<"Enable one-way constraints in the function builder transformation">;

def solver_disable_shrink :
Flag<["-"], "solver-disable-shrink">,
HelpText<"Disable the shrink phase of expression type checking">;
Expand Down
7 changes: 7 additions & 0 deletions lib/AST/ASTDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2794,6 +2794,13 @@ class PrintExpr : public ExprVisitor<PrintExpr> {
PrintWithColorRAII(OS, ParenthesisColor) << ')';
}

void visitOneWayExpr(OneWayExpr *E) {
printCommon(E, "one_way_expr");
OS << '\n';
printRec(E->getSubExpr());
PrintWithColorRAII(OS, ParenthesisColor) << ')';
}

void visitTapExpr(TapExpr *E) {
printCommon(E, "tap_expr");
PrintWithColorRAII(OS, DeclColor) << " var=";
Expand Down
12 changes: 12 additions & 0 deletions lib/AST/ASTWalker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1114,6 +1114,18 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,

Expr *visitKeyPathDotExpr(KeyPathDotExpr *E) { return E; }

Expr *visitOneWayExpr(OneWayExpr *E) {
if (auto oldSubExpr = E->getSubExpr()) {
if (auto subExpr = doIt(oldSubExpr)) {
E->setSubExpr(subExpr);
} else {
return nullptr;
}
}

return E;
}

Expr *visitTapExpr(TapExpr *E) {
if (auto oldSubExpr = E->getSubExpr()) {
if (auto subExpr = doIt(oldSubExpr)) {
Expand Down
2 changes: 2 additions & 0 deletions lib/AST/Expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,7 @@ ConcreteDeclRef Expr::getReferencedDecl() const {
NO_REFERENCE(ObjCSelector);
NO_REFERENCE(KeyPath);
NO_REFERENCE(KeyPathDot);
PASS_THROUGH_REFERENCE(OneWay, getSubExpr);
NO_REFERENCE(Tap);

#undef SIMPLE_REFERENCE
Expand Down Expand Up @@ -539,6 +540,7 @@ bool Expr::canAppendPostfixExpression(bool appendingPostfixOperator) const {
case ExprKind::Error:
case ExprKind::CodeCompletion:
case ExprKind::LazyInitializer:
case ExprKind::OneWay:
return false;

case ExprKind::NilLiteral:
Expand Down
2 changes: 2 additions & 0 deletions lib/Frontend/CompilerInvocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,8 @@ static bool ParseLangArgs(LangOptions &Opts, ArgList &Args,

if (Args.getLastArg(OPT_solver_disable_shrink))
Opts.SolverDisableShrink = true;
if (Args.getLastArg(OPT_enable_function_builder_one_way_constraints))
Opts.FunctionBuilderOneWayConstraints = true;

if (const Arg *A = Args.getLastArg(OPT_value_recursion_threshold)) {
unsigned threshold;
Expand Down
23 changes: 17 additions & 6 deletions lib/Sema/BuilderTransform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ class BuilderClosureVisitor

private:
/// Produce a builder call to the given named function with the given arguments.
CallExpr *buildCallIfWanted(SourceLoc loc,
Identifier fnName, ArrayRef<Expr *> args,
ArrayRef<Identifier> argLabels = {}) {
Expr *buildCallIfWanted(SourceLoc loc,
Identifier fnName, ArrayRef<Expr *> args,
ArrayRef<Identifier> argLabels = {}) {
if (!wantExpr)
return nullptr;

Expand Down Expand Up @@ -81,9 +81,17 @@ class BuilderClosureVisitor
typeExpr, loc, fnName, DeclNameLoc(loc), /*implicit=*/true);
SourceLoc openLoc = args.empty() ? loc : args.front()->getStartLoc();
SourceLoc closeLoc = args.empty() ? loc : args.back()->getEndLoc();
return CallExpr::create(ctx, memberRef, openLoc, args,
argLabels, argLabelLocs, closeLoc,
/*trailing closure*/ nullptr, /*implicit*/true);
Expr *result = CallExpr::create(ctx, memberRef, openLoc, args,
argLabels, argLabelLocs, closeLoc,
/*trailing closure*/ nullptr,
/*implicit*/true);

if (ctx.LangOpts.FunctionBuilderOneWayConstraints) {
// Form a one-way constraint to prevent backward propagation.
result = new (ctx) OneWayExpr(result);
}

return result;
}

/// Check whether the builder supports the given operation.
Expand Down Expand Up @@ -160,6 +168,9 @@ class BuilderClosureVisitor
}

auto expr = node.get<Expr *>();
if (wantExpr && ctx.LangOpts.FunctionBuilderOneWayConstraints)
expr = new (ctx) OneWayExpr(expr);

expressions.push_back(expr);
}

Expand Down
4 changes: 4 additions & 0 deletions lib/Sema/CSApply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4634,6 +4634,10 @@ namespace {
llvm_unreachable("found KeyPathDotExpr in CSApply");
}

Expr *visitOneWayExpr(OneWayExpr *E) {
return E->getSubExpr();
}

Expr *visitTapExpr(TapExpr *E) {
auto type = simplifyType(cs.getType(E));

Expand Down
12 changes: 12 additions & 0 deletions lib/Sema/CSBindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,18 @@ ConstraintSystem::getPotentialBindings(TypeVariableType *typeVar) {
result.FullyBound = true;
}
break;

case ConstraintKind::OneWayBind: {
// Don't produce any bindings if this type variable is on the left-hand
// side of a one-way binding.
auto firstType = constraint->getFirstType();
if (auto *tv = firstType->getAs<TypeVariableType>()) {
if (tv->getImpl().getRepresentative(nullptr) == typeVar)
return {typeVar};
}

break;
}
}
}

Expand Down
31 changes: 23 additions & 8 deletions lib/Sema/CSGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3080,6 +3080,14 @@ namespace {
llvm_unreachable("found KeyPathDotExpr in CSGen");
}

Type visitOneWayExpr(OneWayExpr *expr) {
auto locator = CS.getConstraintLocator(expr);
auto resultTypeVar = CS.createTypeVariable(locator, 0);
CS.addConstraint(ConstraintKind::OneWayBind, resultTypeVar,
CS.getType(expr->getSubExpr()), locator);
return resultTypeVar;
}

Type visitTapExpr(TapExpr *expr) {
DeclContext *varDC = expr->getVar()->getDeclContext();
assert(varDC == CS.DC || (varDC && isa<AbstractClosureExpr>(varDC)) &&
Expand Down Expand Up @@ -3113,7 +3121,8 @@ namespace {
Join,
JoinInout,
JoinMeta,
JoinNonexistent
JoinNonexistent,
OneWay,
};

static TypeOperation getTypeOperation(UnresolvedDotExpr *UDE,
Expand All @@ -3127,6 +3136,7 @@ namespace {

return llvm::StringSwitch<TypeOperation>(
UDE->getName().getBaseIdentifier().str())
.Case("one_way", TypeOperation::OneWay)
.Case("type_join", TypeOperation::Join)
.Case("type_join_inout", TypeOperation::JoinInout)
.Case("type_join_meta", TypeOperation::JoinMeta)
Expand All @@ -3135,14 +3145,14 @@ namespace {
}

Type resultOfTypeOperation(TypeOperation op, Expr *Arg) {
auto *tuple = dyn_cast<TupleExpr>(Arg);
assert(tuple && "Expected argument tuple for join operations!");
auto *tuple = cast<TupleExpr>(Arg);

auto *lhs = tuple->getElement(0);
auto *rhs = tuple->getElement(1);

switch (op) {
case TypeOperation::None:
case TypeOperation::OneWay:
llvm_unreachable(
"We should have a valid type operation at this point!");

Expand Down Expand Up @@ -3582,18 +3592,23 @@ namespace {
/// Once we've visited the children of the given expression,
/// generate constraints from the expression.
Expr *walkToExprPost(Expr *expr) override {

// Handle the Builtin.type_join* family of calls by replacing
// them with dot_self_expr of type_expr with the type being the
// result of the join.
// Translate special type-checker Builtin calls into simpler expressions.
if (auto *apply = dyn_cast<ApplyExpr>(expr)) {
auto fnExpr = apply->getFn();
if (auto *UDE = dyn_cast<UnresolvedDotExpr>(fnExpr)) {
auto &CS = CG.getConstraintSystem();
auto typeOperation =
ConstraintGenerator::getTypeOperation(UDE, CS.getASTContext());

if (typeOperation != ConstraintGenerator::TypeOperation::None) {
if (typeOperation == ConstraintGenerator::TypeOperation::OneWay) {
// For a one-way constraint, create the OneWayExpr node.
auto *arg = cast<ParenExpr>(apply->getArg())->getSubExpr();
expr = new (CS.getASTContext()) OneWayExpr(arg);
} else if (typeOperation !=
ConstraintGenerator::TypeOperation::None) {
// Handle the Builtin.type_join* family of calls by replacing
// them with dot_self_expr of type_expr with the type being the
// result of the join.
auto joinMetaTy =
CG.resultOfTypeOperation(typeOperation, apply->getArg());
auto joinTy = joinMetaTy->castTo<MetatypeType>()->getInstanceType();
Expand Down
37 changes: 37 additions & 0 deletions lib/Sema/CSSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1146,6 +1146,7 @@ ConstraintSystem::matchTupleTypes(TupleType *tuple1, TupleType *tuple2,
case ConstraintKind::BridgingConversion:
case ConstraintKind::FunctionInput:
case ConstraintKind::FunctionResult:
case ConstraintKind::OneWayBind:
llvm_unreachable("Not a conversion");
}

Expand Down Expand Up @@ -1208,6 +1209,7 @@ static bool matchFunctionRepresentations(FunctionTypeRepresentation rep1,
case ConstraintKind::ValueMember:
case ConstraintKind::FunctionInput:
case ConstraintKind::FunctionResult:
case ConstraintKind::OneWayBind:
return false;
}

Expand Down Expand Up @@ -1385,6 +1387,7 @@ ConstraintSystem::matchFunctionTypes(FunctionType *func1, FunctionType *func2,
case ConstraintKind::BridgingConversion:
case ConstraintKind::FunctionInput:
case ConstraintKind::FunctionResult:
case ConstraintKind::OneWayBind:
llvm_unreachable("Not a relational constraint");
}

Expand Down Expand Up @@ -2768,6 +2771,7 @@ ConstraintSystem::matchTypes(Type type1, Type type2, ConstraintKind kind,
case ConstraintKind::ValueMember:
case ConstraintKind::FunctionInput:
case ConstraintKind::FunctionResult:
case ConstraintKind::OneWayBind:
llvm_unreachable("Not a relational constraint");
}
}
Expand Down Expand Up @@ -5175,6 +5179,29 @@ ConstraintSystem::simplifyDefaultableConstraint(
return SolutionKind::Solved;
}

ConstraintSystem::SolutionKind
ConstraintSystem::simplifyOneWayConstraint(
ConstraintKind kind,
Type first, Type second, TypeMatchOptions flags,
ConstraintLocatorBuilder locator) {
// Determine whether the second type can be fully simplified. Only then
// will this constraint be resolved.
Type secondSimplified = simplifyType(second);
if (secondSimplified->hasTypeVariable()) {
if (flags.contains(TMF_GenerateConstraints)) {
addUnsolvedConstraint(
Constraint::create(*this, kind, first, second,
getConstraintLocator(locator)));
return SolutionKind::Solved;
}

return SolutionKind::Unsolved;
}

// Translate this constraint into a one-way binding constraint.
return matchTypes(first, secondSimplified, ConstraintKind::Bind, flags,
locator);
}

ConstraintSystem::SolutionKind
ConstraintSystem::simplifyDynamicTypeOfConstraint(
Expand Down Expand Up @@ -7203,6 +7230,9 @@ ConstraintSystem::addConstraintImpl(ConstraintKind kind, Type first,
return simplifyFunctionComponentConstraint(kind, first, second,
subflags, locator);

case ConstraintKind::OneWayBind:
return simplifyOneWayConstraint(kind, first, second, subflags, locator);

case ConstraintKind::ValueMember:
case ConstraintKind::UnresolvedValueMember:
case ConstraintKind::BindOverload:
Expand Down Expand Up @@ -7558,6 +7588,13 @@ ConstraintSystem::simplifyConstraint(const Constraint &constraint) {
case ConstraintKind::Disjunction:
// Disjunction constraints are never solved here.
return SolutionKind::Unsolved;

case ConstraintKind::OneWayBind:
return simplifyOneWayConstraint(constraint.getKind(),
constraint.getFirstType(),
constraint.getSecondType(),
TMF_GenerateConstraints,
constraint.getLocator());
}

llvm_unreachable("Unhandled ConstraintKind in switch.");
Expand Down
3 changes: 3 additions & 0 deletions lib/Sema/CSSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,11 @@ Solution ConstraintSystem::finalize() {
// multiple entries. We should use an optimized PartialSolution
// structure for that use case, which would optimize a lot of
// stuff here.
#if false
assert((solution.OpenedTypes.count(opened.first) == 0 ||
solution.OpenedTypes[opened.first] == opened.second)
&& "Already recorded");
#endif
solution.OpenedTypes.insert(opened);
}

Expand Down Expand Up @@ -1681,6 +1683,7 @@ void ConstraintSystem::ArgumentInfoCollector::walk(Type argType) {
case ConstraintKind::SelfObjectOfProtocol:
case ConstraintKind::ConformsTo:
case ConstraintKind::Defaultable:
case ConstraintKind::OneWayBind:
break;
}
}
Expand Down
Loading