Skip to content

Commit 90118b3

Browse files
committed
[MLIR][analysis] Lattice: Fix automatic delegation of meet to lattice value classes
The class `Lattice` should automatically delegate invocations of the meet operator to the meet operation of the associated lattice value class if that class provides a static function called `meet`. This process fails for two reasons: 1. `Lattice::has_meet` checks for a member function `meet` without arguments of the lattice value class, although it should check for a static member function. 2. The function template `Lattice::meet<VT>()` implementing the default meet operation directly in the lattice is always present and takes precedence over the delegating function template `Lattice::meet<VT, std::integral_constant<bool, true>>()`. This change fixes the automatic delegation of the meet operation of a lattice to the lattice value class in the presence of a static `meet` function by conditionally enabling either the delegating function template or the non-delegating function template and by changing `Lattice::has_meet` so that it checks for a static `meet` member function in the lattice value type. The test from `TestSparseBackwardDataFlowAnalysis.cpp` is changed, such that the `meet` function is not provided directly in the `WrittenTo` lattice, but by the `Lattice` base class in order to trigger delegation to a lattice value class.
1 parent b8ed69e commit 90118b3

File tree

2 files changed

+47
-21
lines changed

2 files changed

+47
-21
lines changed

mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,14 +132,15 @@ class Lattice : public AbstractSparseLattice {
132132
/// analysis, lattices will only have a `join`, no `meet`, but we want to use
133133
/// the same `Lattice` class for both directions.
134134
template <typename T, typename... Args>
135-
using has_meet = decltype(std::declval<T>().meet());
135+
using has_meet = decltype(&T::meet);
136136
template <typename T>
137137
using lattice_has_meet = llvm::is_detected<has_meet, T>;
138138

139139
/// Meet (intersect) the information contained in the 'rhs' value with this
140140
/// lattice. Returns if the state of the current lattice changed. If the
141141
/// lattice elements don't have a `meet` method, this is a no-op (see below.)
142-
template <typename VT, std::enable_if_t<lattice_has_meet<VT>::value>>
142+
template <typename VT,
143+
std::enable_if_t<lattice_has_meet<VT>::value> * = nullptr>
143144
ChangeResult meet(const VT &rhs) {
144145
ValueT newValue = ValueT::meet(value, rhs);
145146
assert(ValueT::meet(newValue, value) == newValue &&
@@ -155,7 +156,8 @@ class Lattice : public AbstractSparseLattice {
155156
return ChangeResult::Change;
156157
}
157158

158-
template <typename VT>
159+
template <typename VT,
160+
std::enable_if_t<!lattice_has_meet<VT>::value> * = nullptr>
159161
ChangeResult meet(const VT &rhs) {
160162
return ChangeResult::NoChange;
161163
}

mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp

Lines changed: 42 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,33 +18,54 @@ using namespace mlir::dataflow;
1818

1919
namespace {
2020

21-
/// This lattice represents, for a given value, the set of memory resources that
22-
/// this value, or anything derived from this value, is potentially written to.
23-
struct WrittenTo : public AbstractSparseLattice {
24-
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WrittenTo)
25-
using AbstractSparseLattice::AbstractSparseLattice;
21+
/// Lattice value storing the a set of memory resources that something
22+
/// is written to.
23+
struct WrittenToLatticeValue {
24+
bool operator==(const WrittenToLatticeValue &other) {
25+
return this->writes == other.writes;
26+
}
2627

27-
void print(raw_ostream &os) const override {
28-
os << "[";
29-
llvm::interleave(
30-
writes, os, [&](const StringAttr &a) { os << a.str(); }, " ");
31-
os << "]";
28+
static WrittenToLatticeValue meet(const WrittenToLatticeValue &lhs,
29+
const WrittenToLatticeValue &rhs) {
30+
WrittenToLatticeValue res = lhs;
31+
(void)res.addWrites(rhs.writes);
32+
33+
return res;
3234
}
35+
36+
static WrittenToLatticeValue join(const WrittenToLatticeValue &lhs,
37+
const WrittenToLatticeValue &rhs) {
38+
// Should not be triggered by this test, but required by `Lattice<T>`
39+
assert(false);
40+
}
41+
3342
ChangeResult addWrites(const SetVector<StringAttr> &writes) {
3443
int sizeBefore = this->writes.size();
3544
this->writes.insert(writes.begin(), writes.end());
3645
int sizeAfter = this->writes.size();
3746
return sizeBefore == sizeAfter ? ChangeResult::NoChange
3847
: ChangeResult::Change;
3948
}
40-
ChangeResult meet(const AbstractSparseLattice &other) override {
41-
const auto *rhs = reinterpret_cast<const WrittenTo *>(&other);
42-
return addWrites(rhs->writes);
49+
50+
void print(raw_ostream &os) const {
51+
os << "[";
52+
llvm::interleave(
53+
writes, os, [&](const StringAttr &a) { os << a.str(); }, " ");
54+
os << "]";
4355
}
4456

57+
void clear() { writes.clear(); }
58+
4559
SetVector<StringAttr> writes;
4660
};
4761

62+
/// This lattice represents, for a given value, the set of memory resources that
63+
/// this value, or anything derived from this value, is potentially written to.
64+
struct WrittenTo : public Lattice<WrittenToLatticeValue> {
65+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WrittenTo)
66+
using Lattice::Lattice;
67+
};
68+
4869
/// An analysis that, by going backwards along the dataflow graph, annotates
4970
/// each value with all the memory resources it (or anything derived from it)
5071
/// is eventually written to.
@@ -65,7 +86,9 @@ class WrittenToAnalysis : public SparseBackwardDataFlowAnalysis<WrittenTo> {
6586
void visitExternalCall(CallOpInterface call, ArrayRef<WrittenTo *> operands,
6687
ArrayRef<const WrittenTo *> results) override;
6788

68-
void setToExitState(WrittenTo *lattice) override { lattice->writes.clear(); }
89+
void setToExitState(WrittenTo *lattice) override {
90+
lattice->getValue().clear();
91+
}
6992

7093
private:
7194
bool assumeFuncWrites;
@@ -77,7 +100,8 @@ void WrittenToAnalysis::visitOperation(Operation *op,
77100
if (auto store = dyn_cast<memref::StoreOp>(op)) {
78101
SetVector<StringAttr> newWrites;
79102
newWrites.insert(op->getAttrOfType<StringAttr>("tag_name"));
80-
propagateIfChanged(operands[0], operands[0]->addWrites(newWrites));
103+
propagateIfChanged(operands[0],
104+
operands[0]->getValue().addWrites(newWrites));
81105
return;
82106
} // By default, every result of an op depends on every operand.
83107
for (const WrittenTo *r : results) {
@@ -95,7 +119,7 @@ void WrittenToAnalysis::visitBranchOperand(OpOperand &operand) {
95119
newWrites.insert(
96120
StringAttr::get(operand.getOwner()->getContext(),
97121
"brancharg" + Twine(operand.getOperandNumber())));
98-
propagateIfChanged(lattice, lattice->addWrites(newWrites));
122+
propagateIfChanged(lattice, lattice->getValue().addWrites(newWrites));
99123
}
100124

101125
void WrittenToAnalysis::visitCallOperand(OpOperand &operand) {
@@ -105,7 +129,7 @@ void WrittenToAnalysis::visitCallOperand(OpOperand &operand) {
105129
newWrites.insert(
106130
StringAttr::get(operand.getOwner()->getContext(),
107131
"callarg" + Twine(operand.getOperandNumber())));
108-
propagateIfChanged(lattice, lattice->addWrites(newWrites));
132+
propagateIfChanged(lattice, lattice->getValue().addWrites(newWrites));
109133
}
110134

111135
void WrittenToAnalysis::visitExternalCall(CallOpInterface call,
@@ -124,7 +148,7 @@ void WrittenToAnalysis::visitExternalCall(CallOpInterface call,
124148
call.getOperation()->getName().getStringRef());
125149
}
126150
newWrites.insert(name);
127-
propagateIfChanged(lattice, lattice->addWrites(newWrites));
151+
propagateIfChanged(lattice, lattice->getValue().addWrites(newWrites));
128152
}
129153
}
130154

0 commit comments

Comments
 (0)