Skip to content

[ConstraintSystem] Track AST depth information directly #22082

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions include/swift/AST/Expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -524,10 +524,10 @@ class alignas(8) Expr {
/// the parent map.
llvm::DenseMap<Expr *, Expr *> getParentMap();

/// Produce a mapping from each subexpression to its depth in the root
/// expression. The root expression has depth 0, its children have depth
/// 1, etc.
llvm::DenseMap<Expr *, unsigned> getDepthMap();
/// Produce a mapping from each subexpression to its depth and parent,
/// in the root expression. The root expression has depth 0, its children have
/// depth 1, etc.
llvm::DenseMap<Expr *, std::pair<unsigned, Expr *>> getDepthMap();

/// Produce a mapping from each expression to its index according to a
/// preorder traversal of the expressions. The parent has index 0, its first
Expand Down
13 changes: 7 additions & 6 deletions lib/AST/Expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -698,17 +698,18 @@ llvm::DenseMap<Expr *, Expr *> Expr::getParentMap() {
return parentMap;
}

llvm::DenseMap<Expr *, unsigned> Expr::getDepthMap() {
llvm::DenseMap<Expr *, std::pair<unsigned, Expr *>> Expr::getDepthMap() {
class RecordingTraversal : public ASTWalker {
public:
llvm::DenseMap<Expr *, unsigned> &DepthMap;
llvm::DenseMap<Expr *, std::pair<unsigned, Expr *>> &DepthMap;
unsigned Depth = 0;

explicit RecordingTraversal(llvm::DenseMap<Expr *, unsigned> &depthMap)
: DepthMap(depthMap) { }
explicit RecordingTraversal(
llvm::DenseMap<Expr *, std::pair<unsigned, Expr *>> &depthMap)
: DepthMap(depthMap) {}

std::pair<bool, Expr *> walkToExprPre(Expr *E) override {
DepthMap[E] = Depth;
DepthMap[E] = {Depth, Parent.getAsExpr()};
Depth++;
return { true, E };
}
Expand All @@ -719,7 +720,7 @@ llvm::DenseMap<Expr *, unsigned> Expr::getDepthMap() {
}
};

llvm::DenseMap<Expr *, unsigned> depthMap;
llvm::DenseMap<Expr *, std::pair<unsigned, Expr *>> depthMap;
RecordingTraversal traversal(depthMap);
walk(traversal);
return depthMap;
Expand Down
11 changes: 5 additions & 6 deletions lib/Sema/CSRanking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -772,7 +772,7 @@ static Type getUnlabeledType(Type type, ASTContext &ctx) {
SolutionCompareResult ConstraintSystem::compareSolutions(
ConstraintSystem &cs, ArrayRef<Solution> solutions,
const SolutionDiff &diff, unsigned idx1, unsigned idx2,
llvm::DenseMap<Expr *, unsigned> &weights) {
llvm::DenseMap<Expr *, std::pair<unsigned, Expr *>> &weights) {
if (cs.TC.getLangOpts().DebugConstraintSolver) {
auto &log = cs.getASTContext().TypeCheckerDebug->getStream();
log.indent(cs.solverState->depth * 2)
Expand Down Expand Up @@ -806,7 +806,7 @@ SolutionCompareResult ConstraintSystem::compareSolutions(
if (auto *anchor = locator->getAnchor()) {
auto weight = weights.find(anchor);
if (weight != weights.end())
return weight->getSecond() + 1;
return weight->getSecond().first + 1;
}

return 1;
Expand Down Expand Up @@ -1212,7 +1212,6 @@ SolutionCompareResult ConstraintSystem::compareSolutions(

Optional<unsigned>
ConstraintSystem::findBestSolution(SmallVectorImpl<Solution> &viable,
llvm::DenseMap<Expr *, unsigned> &weights,
bool minimize) {
if (viable.empty())
return None;
Expand All @@ -1236,7 +1235,7 @@ ConstraintSystem::findBestSolution(SmallVectorImpl<Solution> &viable,
SmallVector<bool, 16> losers(viable.size(), false);
unsigned bestIdx = 0;
for (unsigned i = 1, n = viable.size(); i != n; ++i) {
switch (compareSolutions(*this, viable, diff, i, bestIdx, weights)) {
switch (compareSolutions(*this, viable, diff, i, bestIdx, ExprWeights)) {
case SolutionCompareResult::Identical:
// FIXME: Might want to warn about this in debug builds, so we can
// find a way to eliminate the redundancy in the search space.
Expand All @@ -1260,7 +1259,7 @@ ConstraintSystem::findBestSolution(SmallVectorImpl<Solution> &viable,
if (i == bestIdx)
continue;

switch (compareSolutions(*this, viable, diff, bestIdx, i, weights)) {
switch (compareSolutions(*this, viable, diff, bestIdx, i, ExprWeights)) {
case SolutionCompareResult::Identical:
// FIXME: Might want to warn about this in debug builds, so we can
// find a way to eliminate the redundancy in the search space.
Expand Down Expand Up @@ -1312,7 +1311,7 @@ ConstraintSystem::findBestSolution(SmallVectorImpl<Solution> &viable,
if (losers[j])
continue;

switch (compareSolutions(*this, viable, diff, i, j, weights)) {
switch (compareSolutions(*this, viable, diff, i, j, ExprWeights)) {
case SolutionCompareResult::Identical:
// FIXME: Dub one of these the loser arbitrarily?
break;
Expand Down
18 changes: 7 additions & 11 deletions lib/Sema/CSSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,16 +327,12 @@ void truncate(llvm::SmallSetVector<T, N> &vec, unsigned newSize) {
} // end anonymous namespace

ConstraintSystem::SolverState::SolverState(
Expr *const expr, ConstraintSystem &cs,
FreeTypeVariableBinding allowFreeTypeVariables)
ConstraintSystem &cs, FreeTypeVariableBinding allowFreeTypeVariables)
: CS(cs), AllowFreeTypeVariables(allowFreeTypeVariables) {
assert(!CS.solverState &&
"Constraint system should not already have solver state!");
CS.solverState = this;

if (expr)
ExprWeights = expr->getDepthMap();

++NumSolutionAttempts;
SolutionAttempt = NumSolutionAttempts;

Expand Down Expand Up @@ -498,12 +494,12 @@ Optional<Solution>
ConstraintSystem::solveSingle(FreeTypeVariableBinding allowFreeTypeVariables,
bool allowFixes) {

SolverState state(nullptr, *this, allowFreeTypeVariables);
SolverState state(*this, allowFreeTypeVariables);
state.recordFixes = allowFixes;

SmallVector<Solution, 4> solutions;
solve(solutions);
filterSolutions(solutions, state.ExprWeights);
filterSolutions(solutions);

if (solutions.size() != 1)
return Optional<Solution>();
Expand Down Expand Up @@ -538,7 +534,7 @@ bool ConstraintSystem::Candidate::solve(
};

// Allocate new constraint system for sub-expression.
ConstraintSystem cs(TC, DC, None);
ConstraintSystem cs(TC, DC, None, E);
cs.baseCS = &BaseCS;

// Set up expression type checker timer for the candidate.
Expand Down Expand Up @@ -589,7 +585,7 @@ bool ConstraintSystem::Candidate::solve(
// Try to solve the system and record all available solutions.
llvm::SmallVector<Solution, 2> solutions;
{
SolverState state(E, cs, FreeTypeVariableBinding::Allow);
SolverState state(cs, FreeTypeVariableBinding::Allow);

// Use solve which doesn't try to filter solution list.
// Because we want the whole set of possible domain choices.
Expand Down Expand Up @@ -1179,7 +1175,7 @@ bool ConstraintSystem::solve(Expr *const expr,
SmallVectorImpl<Solution> &solutions,
FreeTypeVariableBinding allowFreeTypeVariables) {
// Set up solver state.
SolverState state(expr, *this, allowFreeTypeVariables);
SolverState state(*this, allowFreeTypeVariables);

// Solve the system.
solve(solutions);
Expand All @@ -1200,7 +1196,7 @@ bool ConstraintSystem::solve(Expr *const expr,
// a single best solution to use, if not explicitly disabled
// by constraint system options.
if (!retainAllSolutions())
filterSolutions(solutions, state.ExprWeights);
filterSolutions(solutions);

// We fail if there is no solution or the expression was too complex.
return solutions.empty() || getExpressionTooComplex(solutions);
Expand Down
2 changes: 1 addition & 1 deletion lib/Sema/CSStep.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ class SolverStep {

void filterSolutions(SmallVectorImpl<Solution> &solutions, bool minimize) {
if (!CS.retainAllSolutions())
CS.filterSolutions(solutions, CS.solverState->ExprWeights, minimize);
CS.filterSolutions(solutions, minimize);
}

/// Check whether constraint solver is running in "debug" mode,
Expand Down
18 changes: 11 additions & 7 deletions lib/Sema/ConstraintSystem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,15 @@ ExpressionTimer::~ExpressionTimer() {
}

ConstraintSystem::ConstraintSystem(TypeChecker &tc, DeclContext *dc,
ConstraintSystemOptions options)
ConstraintSystemOptions options,
Expr *expr)
: TC(tc), DC(dc), Options(options),
Arena(tc.Context, Allocator),
CG(*new ConstraintGraph(*this))
{
if (expr)
ExprWeights = expr->getDepthMap();

assert(DC && "context required");
}

Expand Down Expand Up @@ -2063,16 +2067,15 @@ bool ConstraintSystem::salvage(SmallVectorImpl<Solution> &viable, Expr *expr) {

{
// Set up solver state.
SolverState state(expr, *this, FreeTypeVariableBinding::Disallow);
SolverState state(*this, FreeTypeVariableBinding::Disallow);
state.recordFixes = true;

// Solve the system.
solve(viable);

// Check whether we have a best solution; this can happen if we found
// a series of fixes that worked.
if (auto best = findBestSolution(viable, state.ExprWeights,
/*minimize=*/true)) {
if (auto best = findBestSolution(viable, /*minimize=*/true)) {
if (*best != 0)
viable[0] = std::move(viable[*best]);
viable.erase(viable.begin() + 1, viable.end());
Expand Down Expand Up @@ -2289,10 +2292,11 @@ bool ConstraintSystem::diagnoseAmbiguity(Expr *expr,
if (it == indexMap.end())
continue;
unsigned index = it->second;
it = depthMap.find(anchor);
if (it == depthMap.end())

auto e = depthMap.find(anchor);
if (e == depthMap.end())
continue;
unsigned depth = it->second;
unsigned depth = e->second.first;

// If we don't have a name to hang on to, it'll be hard to diagnose this
// overload.
Expand Down
27 changes: 13 additions & 14 deletions lib/Sema/ConstraintSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -962,6 +962,8 @@ class ConstraintSystem {

private:

llvm::DenseMap<Expr *, std::pair<unsigned, Expr *>> ExprWeights;

/// Allocator used for all of the related constraint systems.
llvm::BumpPtrAllocator Allocator;

Expand Down Expand Up @@ -1163,12 +1165,10 @@ class ConstraintSystem {

/// Describes the current solver state.
struct SolverState {
SolverState(Expr *const expr, ConstraintSystem &cs,
SolverState(ConstraintSystem &cs,
FreeTypeVariableBinding allowFreeTypeVariables);
~SolverState();

llvm::DenseMap<Expr *, unsigned> ExprWeights;

/// The constraint system.
ConstraintSystem &CS;

Expand Down Expand Up @@ -1513,7 +1513,8 @@ class ConstraintSystem {
};

ConstraintSystem(TypeChecker &tc, DeclContext *dc,
ConstraintSystemOptions options);
ConstraintSystemOptions options,
Expr *expr = nullptr);
~ConstraintSystem();

/// Retrieve the type checker associated with this constraint system.
Expand Down Expand Up @@ -1563,13 +1564,13 @@ class ConstraintSystem {
/// set of solutions should be filtered even if there is
/// no single best solution, see `findBestSolution` for
/// more details.
void filterSolutions(SmallVectorImpl<Solution> &solutions,
llvm::DenseMap<Expr *, unsigned> &weights,
bool minimize = false) {
void
filterSolutions(SmallVectorImpl<Solution> &solutions,
bool minimize = false) {
if (solutions.size() < 2)
return;

if (auto best = findBestSolution(solutions, weights, minimize)) {
if (auto best = findBestSolution(solutions, minimize)) {
if (*best != 0)
solutions[0] = std::move(solutions[*best]);
solutions.erase(solutions.begin() + 1, solutions.end());
Expand Down Expand Up @@ -3144,11 +3145,10 @@ class ConstraintSystem {
/// \param diff The differences among the solutions.
/// \param idx1 The index of the first solution.
/// \param idx2 The index of the second solution.
/// \param weights The weights of the sub-expressions used for ranking.
static SolutionCompareResult
compareSolutions(ConstraintSystem &cs, ArrayRef<Solution> solutions,
const SolutionDiff &diff, unsigned idx1, unsigned idx2,
llvm::DenseMap<Expr *, unsigned> &weights);
llvm::DenseMap<Expr *, std::pair<unsigned, Expr *>> &weights);

public:
/// Increase the score of the given kind for the current (partial) solution
Expand All @@ -3163,7 +3163,6 @@ class ConstraintSystem {
/// solution.
///
/// \param solutions The set of viable solutions to consider.
/// \param weights The weights of the sub-expressions used for ranking.
///
/// \param minimize If true, then in the case where there is no single
/// best solution, minimize the set of solutions by removing any solutions
Expand All @@ -3172,9 +3171,9 @@ class ConstraintSystem {
///
/// \returns The index of the best solution, or nothing if there was no
/// best solution.
Optional<unsigned> findBestSolution(SmallVectorImpl<Solution> &solutions,
llvm::DenseMap<Expr *, unsigned> &weights,
bool minimize);
Optional<unsigned>
findBestSolution(SmallVectorImpl<Solution> &solutions,
bool minimize);

/// Apply a given solution to the expression, producing a fully
/// type-checked expression.
Expand Down
2 changes: 1 addition & 1 deletion lib/Sema/TypeCheckConstraints.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2066,7 +2066,7 @@ Type TypeChecker::typeCheckExpressionImpl(Expr *&expr, DeclContext *dc,
if (options.contains(TypeCheckExprFlags::AllowUnresolvedTypeVariables))
csOptions |= ConstraintSystemFlags::AllowUnresolvedTypeVariables;

ConstraintSystem cs(*this, dc, csOptions);
ConstraintSystem cs(*this, dc, csOptions, expr);
cs.baseCS = baseCS;

// Verify that a purpose was specified if a convertType was. Note that it is
Expand Down