Skip to content

Commit 398599f

Browse files
Peiming LiuPeimingLiu
authored andcommitted
revert unintended change
1 parent 4dbf038 commit 398599f

File tree

2 files changed

+59
-83
lines changed

2 files changed

+59
-83
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp

Lines changed: 52 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,19 @@ using ValueTuple = std::tuple<Value, Value, Value>;
4545
//===----------------------------------------------------------------------===//
4646

4747
namespace {
48+
class SparseLevel : public SparseTensorLevel {
49+
public:
50+
SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
51+
Value crdBuffer)
52+
: SparseTensorLevel(tid, lvl, lt, lvlSize), crdBuffer(crdBuffer) {}
53+
54+
Value peekCrdAt(OpBuilder &b, Location l, Value iv) const override {
55+
return genIndexLoad(b, l, crdBuffer, iv);
56+
}
57+
58+
protected:
59+
const Value crdBuffer;
60+
};
4861

4962
class DenseLevel : public SparseTensorLevel {
5063
public:
@@ -60,27 +73,53 @@ class DenseLevel : public SparseTensorLevel {
6073
Value max) const override {
6174
assert(max == nullptr && "Dense level can not be non-unique.");
6275
if (encoded) {
63-
Value posLo = MULI(p, getSize());
64-
return {posLo, getSize()};
76+
Value posLo = MULI(p, lvlSize);
77+
return {posLo, lvlSize};
6578
}
6679
// No need to linearize the position for non-annotated tensors.
67-
return {C_IDX(0), getSize()};
80+
return {C_IDX(0), lvlSize};
6881
}
6982

7083
const bool encoded;
7184
};
7285

73-
class SparseLevel : public SparseTensorLevel {
86+
class CompressedLevel : public SparseLevel {
7487
public:
75-
SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
76-
ValueRange lvlBuf)
77-
: SparseTensorLevel(tid, lvl, lt, lvlSize, lvlBuf) {
78-
assert(!lvlBuf.empty());
88+
CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
89+
Value posBuffer, Value crdBuffer)
90+
: SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
91+
92+
ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
93+
Value max) const override {
94+
if (max == nullptr) {
95+
Value pLo = genIndexLoad(b, l, posBuffer, p);
96+
Value pHi = genIndexLoad(b, l, posBuffer, ADDI(p, C_IDX(1)));
97+
return {pLo, pHi};
98+
}
99+
llvm_unreachable("compressed-nu should be the first non-unique level.");
79100
}
80101

81-
Value peekCrdAt(OpBuilder &b, Location l, Value iv) const override {
82-
return genIndexLoad(b, l, getLvlBufs().front(), iv);
102+
private:
103+
const Value posBuffer;
104+
};
105+
106+
class LooseCompressedLevel : public SparseLevel {
107+
public:
108+
LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
109+
Value posBuffer, Value crdBuffer)
110+
: SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
111+
112+
ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
113+
Value max) const override {
114+
assert(max == nullptr && "loss compressed level can not be non-unique.");
115+
p = MULI(p, C_IDX(2));
116+
Value pLo = genIndexLoad(b, l, posBuffer, p);
117+
Value pHi = genIndexLoad(b, l, posBuffer, ADDI(p, C_IDX(1)));
118+
return {pLo, pHi};
83119
}
120+
121+
private:
122+
const Value posBuffer;
84123
};
85124

86125
class SingletonLevel : public SparseLevel {
@@ -102,8 +141,8 @@ class SingletonLevel : public SparseLevel {
102141
class TwoOutFourLevel : public SparseLevel {
103142
public:
104143
TwoOutFourLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
105-
Value crdBuf)
106-
: SparseLevel(tid, lvl, lt, lvlSize, crdBuf) {}
144+
Value crdBuffer)
145+
: SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
107146

108147
ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
109148
Value max) const override {
@@ -114,39 +153,6 @@ class TwoOutFourLevel : public SparseLevel {
114153
}
115154
};
116155

117-
class CompressedLevel : public SparseLevel {
118-
public:
119-
CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
120-
Value posBuffer, Value crdBuffer)
121-
: SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer, posBuffer}) {}
122-
123-
ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
124-
Value max) const override {
125-
if (max == nullptr) {
126-
Value pLo = genIndexLoad(b, l, getPosBuf(), p);
127-
Value pHi = genIndexLoad(b, l, getPosBuf(), ADDI(p, C_IDX(1)));
128-
return {pLo, pHi};
129-
}
130-
llvm_unreachable("compressed-nu should be the first non-unique level.");
131-
}
132-
};
133-
134-
class LooseCompressedLevel : public SparseLevel {
135-
public:
136-
LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
137-
Value posBuffer, Value crdBuffer)
138-
: SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer, posBuffer}) {}
139-
140-
ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
141-
Value max) const override {
142-
assert(max == nullptr && "loss compressed level can not be non-unique.");
143-
p = MULI(p, C_IDX(2));
144-
Value pLo = genIndexLoad(b, l, getPosBuf(), p);
145-
Value pHi = genIndexLoad(b, l, getPosBuf(), ADDI(p, C_IDX(1)));
146-
return {pLo, pHi};
147-
}
148-
};
149-
150156
} // namespace
151157

152158
//===----------------------------------------------------------------------===//
@@ -195,9 +201,7 @@ static Value offsetFromMinCrd(OpBuilder &b, Location l, Value minCrd,
195201
//===----------------------------------------------------------------------===//
196202
// SparseIterator derived classes.
197203
//===----------------------------------------------------------------------===//
198-
199-
namespace mlir {
200-
namespace sparse_tensor {
204+
namespace {
201205

202206
// The iterator that traverses a concrete sparse tensor levels. High-level
203207
// abstract iterators wrap it to achieve more complex goals (such as collapsing
@@ -232,11 +236,6 @@ class ConcreteIterator : public SparseIterator {
232236
SmallVector<Value> cursorValsStorage;
233237
};
234238

235-
} // namespace sparse_tensor
236-
} // namespace mlir
237-
238-
namespace {
239-
240239
class TrivialIterator : public ConcreteIterator {
241240
public:
242241
TrivialIterator(const SparseTensorLevel &stl)

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h

Lines changed: 7 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,9 @@
1515
namespace mlir {
1616
namespace sparse_tensor {
1717

18-
class ConcreteIterator;
19-
20-
/// The base class for all types of sparse tensor levels. It provides
21-
/// interfaces to query the loop range (see `peekRangeAt`) and look up the
22-
/// coordinates (see `peekCrdAt`).
18+
/// The base class for all types of sparse tensor levels. It provides interfaces
19+
/// to query the loop range (see `peekRangeAt`) and look up the coordinates (see
20+
/// `peekCrdAt`).
2321
class SparseTensorLevel {
2422
SparseTensorLevel(SparseTensorLevel &&) = delete;
2523
SparseTensorLevel(const SparseTensorLevel &) = delete;
@@ -33,6 +31,7 @@ class SparseTensorLevel {
3331
return std::string(toMLIRString(lt)) + "[" + std::to_string(tid) + "," +
3432
std::to_string(lvl) + "]";
3533
}
34+
3635
virtual Value peekCrdAt(OpBuilder &b, Location l, Value iv) const = 0;
3736

3837
/// Peeks the lower and upper bound to *fully* traverse the level with
@@ -53,17 +52,7 @@ class SparseTensorLevel {
5352

5453
Level getLevel() const { return lvl; }
5554
LevelType getLT() const { return lt; }
56-
Value getSize() const { return lvlVals.front(); }
57-
Value getCrdBuf() const {
58-
assert(lvlVals.size() > 1);
59-
return lvlVals[1];
60-
}
61-
Value getPosBuf() const {
62-
assert(lvlVals.size() > 2);
63-
return lvlVals[2];
64-
}
65-
ValueRange getLvlVals() const { return lvlVals; }
66-
ValueRange getLvlBufs() const { return ValueRange(lvlVals).drop_front(); }
55+
Value getSize() const { return lvlSize; }
6756

6857
//
6958
// Level properties
@@ -72,24 +61,12 @@ class SparseTensorLevel {
7261

7362
protected:
7463
SparseTensorLevel(unsigned tid, unsigned lvl, LevelType lt, Value lvlSize)
75-
: tid(tid), lvl(lvl), lt(lt), lvlVals() {
76-
lvlVals.push_back(lvlSize);
77-
};
78-
79-
SparseTensorLevel(unsigned tid, unsigned lvl, LevelType lt, Value lvlSize,
80-
ValueRange lvlBufs)
81-
: tid(tid), lvl(lvl), lt(lt), lvlVals() {
82-
lvlVals.push_back(lvlSize);
83-
lvlVals.append(lvlBufs.begin(), lvlBufs.end());
84-
};
64+
: tid(tid), lvl(lvl), lt(lt), lvlSize(lvlSize){};
8565

8666
public:
8767
const unsigned tid, lvl;
8868
const LevelType lt;
89-
// The first value in the vector is always lvlsize; for sparse levels, the
90-
// second value is always the coordinate buffer; for sparse level with
91-
// position buffers, the third value is always the position buffer.
92-
SmallVector<Value, 3> lvlVals;
69+
const Value lvlSize;
9370
};
9471

9572
enum class IterKind : uint8_t {

0 commit comments

Comments
 (0)