@@ -45,6 +45,19 @@ using ValueTuple = std::tuple<Value, Value, Value>;
45
45
// ===----------------------------------------------------------------------===//
46
46
47
47
namespace {
48
+ class SparseLevel : public SparseTensorLevel {
49
+ public:
50
+ SparseLevel (unsigned tid, Level lvl, LevelType lt, Value lvlSize,
51
+ Value crdBuffer)
52
+ : SparseTensorLevel(tid, lvl, lt, lvlSize), crdBuffer(crdBuffer) {}
53
+
54
+ Value peekCrdAt (OpBuilder &b, Location l, Value iv) const override {
55
+ return genIndexLoad (b, l, crdBuffer, iv);
56
+ }
57
+
58
+ protected:
59
+ const Value crdBuffer;
60
+ };
48
61
49
62
class DenseLevel : public SparseTensorLevel {
50
63
public:
@@ -60,27 +73,53 @@ class DenseLevel : public SparseTensorLevel {
60
73
Value max) const override {
61
74
assert (max == nullptr && " Dense level can not be non-unique." );
62
75
if (encoded) {
63
- Value posLo = MULI (p, getSize () );
64
- return {posLo, getSize () };
76
+ Value posLo = MULI (p, lvlSize );
77
+ return {posLo, lvlSize };
65
78
}
66
79
// No need to linearize the position for non-annotated tensors.
67
- return {C_IDX (0 ), getSize () };
80
+ return {C_IDX (0 ), lvlSize };
68
81
}
69
82
70
83
const bool encoded;
71
84
};
72
85
73
- class SparseLevel : public SparseTensorLevel {
86
+ class CompressedLevel : public SparseLevel {
74
87
public:
75
- SparseLevel (unsigned tid, Level lvl, LevelType lt, Value lvlSize,
76
- ValueRange lvlBuf)
77
- : SparseTensorLevel(tid, lvl, lt, lvlSize, lvlBuf) {
78
- assert (!lvlBuf.empty ());
88
+ CompressedLevel (unsigned tid, Level lvl, LevelType lt, Value lvlSize,
89
+ Value posBuffer, Value crdBuffer)
90
+ : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
91
+
92
+ ValuePair peekRangeAt (OpBuilder &b, Location l, Value p,
93
+ Value max) const override {
94
+ if (max == nullptr ) {
95
+ Value pLo = genIndexLoad (b, l, posBuffer, p);
96
+ Value pHi = genIndexLoad (b, l, posBuffer, ADDI (p, C_IDX (1 )));
97
+ return {pLo, pHi};
98
+ }
99
+ llvm_unreachable (" compressed-nu should be the first non-unique level." );
79
100
}
80
101
81
- Value peekCrdAt (OpBuilder &b, Location l, Value iv) const override {
82
- return genIndexLoad (b, l, getLvlBufs ().front (), iv);
102
+ private:
103
+ const Value posBuffer;
104
+ };
105
+
106
+ class LooseCompressedLevel : public SparseLevel {
107
+ public:
108
+ LooseCompressedLevel (unsigned tid, Level lvl, LevelType lt, Value lvlSize,
109
+ Value posBuffer, Value crdBuffer)
110
+ : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
111
+
112
+ ValuePair peekRangeAt (OpBuilder &b, Location l, Value p,
113
+ Value max) const override {
114
+ assert (max == nullptr && " loss compressed level can not be non-unique." );
115
+ p = MULI (p, C_IDX (2 ));
116
+ Value pLo = genIndexLoad (b, l, posBuffer, p);
117
+ Value pHi = genIndexLoad (b, l, posBuffer, ADDI (p, C_IDX (1 )));
118
+ return {pLo, pHi};
83
119
}
120
+
121
+ private:
122
+ const Value posBuffer;
84
123
};
85
124
86
125
class SingletonLevel : public SparseLevel {
@@ -102,8 +141,8 @@ class SingletonLevel : public SparseLevel {
102
141
class TwoOutFourLevel : public SparseLevel {
103
142
public:
104
143
TwoOutFourLevel (unsigned tid, Level lvl, LevelType lt, Value lvlSize,
105
- Value crdBuf )
106
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuf ) {}
144
+ Value crdBuffer )
145
+ : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer ) {}
107
146
108
147
ValuePair peekRangeAt (OpBuilder &b, Location l, Value p,
109
148
Value max) const override {
@@ -114,39 +153,6 @@ class TwoOutFourLevel : public SparseLevel {
114
153
}
115
154
};
116
155
117
- class CompressedLevel : public SparseLevel {
118
- public:
119
- CompressedLevel (unsigned tid, Level lvl, LevelType lt, Value lvlSize,
120
- Value posBuffer, Value crdBuffer)
121
- : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer, posBuffer}) {}
122
-
123
- ValuePair peekRangeAt (OpBuilder &b, Location l, Value p,
124
- Value max) const override {
125
- if (max == nullptr ) {
126
- Value pLo = genIndexLoad (b, l, getPosBuf (), p);
127
- Value pHi = genIndexLoad (b, l, getPosBuf (), ADDI (p, C_IDX (1 )));
128
- return {pLo, pHi};
129
- }
130
- llvm_unreachable (" compressed-nu should be the first non-unique level." );
131
- }
132
- };
133
-
134
- class LooseCompressedLevel : public SparseLevel {
135
- public:
136
- LooseCompressedLevel (unsigned tid, Level lvl, LevelType lt, Value lvlSize,
137
- Value posBuffer, Value crdBuffer)
138
- : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer, posBuffer}) {}
139
-
140
- ValuePair peekRangeAt (OpBuilder &b, Location l, Value p,
141
- Value max) const override {
142
- assert (max == nullptr && " loss compressed level can not be non-unique." );
143
- p = MULI (p, C_IDX (2 ));
144
- Value pLo = genIndexLoad (b, l, getPosBuf (), p);
145
- Value pHi = genIndexLoad (b, l, getPosBuf (), ADDI (p, C_IDX (1 )));
146
- return {pLo, pHi};
147
- }
148
- };
149
-
150
156
} // namespace
151
157
152
158
// ===----------------------------------------------------------------------===//
@@ -195,9 +201,7 @@ static Value offsetFromMinCrd(OpBuilder &b, Location l, Value minCrd,
195
201
// ===----------------------------------------------------------------------===//
196
202
// SparseIterator derived classes.
197
203
// ===----------------------------------------------------------------------===//
198
-
199
- namespace mlir {
200
- namespace sparse_tensor {
204
+ namespace {
201
205
202
206
// The iterator that traverses a concrete sparse tensor levels. High-level
203
207
// abstract iterators wrap it to achieve more complex goals (such as collapsing
@@ -232,11 +236,6 @@ class ConcreteIterator : public SparseIterator {
232
236
SmallVector<Value> cursorValsStorage;
233
237
};
234
238
235
- } // namespace sparse_tensor
236
- } // namespace mlir
237
-
238
- namespace {
239
-
240
239
class TrivialIterator : public ConcreteIterator {
241
240
public:
242
241
TrivialIterator (const SparseTensorLevel &stl)
0 commit comments