@@ -95,8 +95,9 @@ class DenseLevel : public SparseTensorLevel {
   ValueRange getLvlBuffers() const override { return {}; }
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        ValueRange parentPos) const override {
+                        ValueRange parentPos, Value inPadZone) const override {
     assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
+    assert(!inPadZone && "Not implemented");
     Value p = parentPos.front();
     Value posLo = MULI(p, lvlSize);
     return {posLo, lvlSize};
@@ -115,7 +116,8 @@ class BatchLevel : public SparseTensorLevel {
   ValueRange getLvlBuffers() const override { return {}; }
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange,
-                        ValueRange parentPos) const override {
+                        ValueRange parentPos, Value inPadZone) const override {
+    assert(!inPadZone && "Not implemented");
     assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
     // No need to linearize the position for non-annotated tensors.
     return {C_IDX(0), lvlSize};
@@ -129,18 +131,41 @@ class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
       : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        ValueRange parentPos) const override {
+                        ValueRange parentPos, Value inPadZone) const override {
 
     assert(parentPos.size() == 1 &&
            "compressed level must be the first non-unique level.");
-    Value p = parentPos.front();
 
-    SmallVector<Value> memCrd(batchPrefix);
-    memCrd.push_back(p);
-    Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
-    memCrd.back() = ADDI(p, C_IDX(1));
-    Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
-    return {pLo, pHi};
+    auto loadRange = [&b, l, parentPos, batchPrefix, this]() -> ValuePair {
+      Value p = parentPos.front();
+      SmallVector<Value> memCrd(batchPrefix);
+      memCrd.push_back(p);
+      Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
+      memCrd.back() = ADDI(p, C_IDX(1));
+      Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
+      return {pLo, pHi};
+    };
+
+    if (inPadZone == nullptr)
+      return loadRange();
+
+    SmallVector<Type, 2> types{b.getIndexType(), b.getIndexType()};
+    scf::IfOp posRangeIf = b.create<scf::IfOp>(l, types, inPadZone, true);
+    // True branch.
+    b.setInsertionPointToStart(posRangeIf.thenBlock());
+    // Returns a "fake" empty range [0, 0) if the parent iterator is in the pad zone.
+    SmallVector<Value, 2> emptyRange{C_IDX(0), C_IDX(0)};
+    b.create<scf::YieldOp>(l, emptyRange);
+
+    // False branch.
+    b.setInsertionPointToStart(posRangeIf.elseBlock());
+    auto [pLo, pHi] = loadRange();
+    SmallVector<Value, 2> loadedRange{pLo, pHi};
+    b.create<scf::YieldOp>(l, loadedRange);
+
+    b.setInsertionPointAfter(posRangeIf);
+    ValueRange posRange = posRangeIf.getResults();
+    return {posRange.front(), posRange.back()};
   }
 };
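For intuition, here is a plain-C++ sketch (not part of the patch; MLIR Values are modeled as plain integers, and all names are illustrative) of the runtime behavior the scf.IfOp above encodes: a compressed level maps parent position p to the half-open range [pos[p], pos[p+1]), unless the parent sits in a pad zone, in which case it reports the "fake" empty range [0, 0) without touching the buffers.

#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Illustrative analogue of CompressedLevel::peekRangeAt with a pad-zone
// guard; `posBuf` plays the role of the level's positions buffer.
std::pair<int64_t, int64_t>
peekCompressedRange(const std::vector<int64_t> &posBuf, int64_t parentPos,
                    bool inPadZone) {
  if (inPadZone)
    return {0, 0}; // "fake" empty range: padding stores no entries
  assert(parentPos + 1 < static_cast<int64_t>(posBuf.size()));
  return {posBuf[parentPos], posBuf[parentPos + 1]};
}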
@@ -151,9 +176,10 @@ class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
       : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        ValueRange parentPos) const override {
+                        ValueRange parentPos, Value inPadZone) const override {
     assert(parentPos.size() == 1 &&
            "loose-compressed level must be the first non-unique level.");
+    assert(!inPadZone && "Not implemented");
     SmallVector<Value> memCrd(batchPrefix);
     Value p = parentPos.front();
     p = MULI(p, C_IDX(2));
@@ -172,8 +198,9 @@ class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
       : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        ValueRange parentPos) const override {
+                        ValueRange parentPos, Value inPadZone) const override {
     assert(parentPos.size() == 1 || parentPos.size() == 2);
+    assert(!inPadZone && "Not implemented");
     Value p = parentPos.front();
     Value segHi = parentPos.size() == 2 ? parentPos.back() : nullptr;
@@ -191,9 +218,10 @@ class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
       : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        ValueRange parentPos) const override {
+                        ValueRange parentPos, Value inPadZone) const override {
     assert(parentPos.size() == 1 && isUnique() &&
            "n:m level can not be non-unique.");
+    assert(!inPadZone && "Not implemented");
     // Each n:m blk has exactly n specified elements.
     auto n = getN(lt);
     Value posLo = MULI(parentPos.front(), C_IDX(n));
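The range computation above is a single multiplication; here is a small plain-C++ sketch (illustrative, not patch code). Since each n:m block stores exactly n elements, parent position p owns [p * n, p * n + n); e.g. for 2:4 structured sparsity (n = 2), p = 3 yields [6, 8).

#include <cstdint>
#include <utility>

// Illustrative analogue of NOutOfMLevel::peekRangeAt on plain integers.
std::pair<int64_t, int64_t> peekNOutOfMRange(int64_t p, int64_t n) {
  int64_t posLo = p * n;     // cf. MULI(parentPos.front(), C_IDX(n))
  return {posLo, posLo + n}; // each block has exactly n stored elements
}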
@@ -325,23 +353,7 @@ class TrivialIterator : public ConcreteIterator {
   };
 
   void genInitImpl(OpBuilder &b, Location l,
-                   const SparseIterator *parent) override {
-
-    if (isBatchIterator() && batchCrds.size() <= stl.lvl)
-      batchCrds.resize(stl.lvl + 1, nullptr);
-
-    Value c0 = C_IDX(0);
-    ValueRange pPos = c0;
-    // If the parent iterator is a batch iterator, we also start from 0 (but
-    // on a different batch).
-    if (parent && !parent->isBatchIterator())
-      pPos = parent->getCurPosition();
-
-    ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
-    std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);
-    // Seek to the lowest position.
-    seek(posLo);
-  }
+                   const SparseIterator *parent) override;
 
   ValuePair genForCond(OpBuilder &b, Location l) override {
     if (randomAccessible())
@@ -465,8 +477,9 @@ class DedupIterator : public ConcreteIterator {
 // A util base-iterator that delegates all methods to the wrapped iterator.
 class SimpleWrapIterator : public SparseIterator {
 public:
-  SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind)
-      : SparseIterator(kind, *wrap), wrap(std::move(wrap)) {}
+  SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind,
+                     unsigned extraCursorVal = 0)
+      : SparseIterator(kind, *wrap, extraCursorVal), wrap(std::move(wrap)) {}
 
   SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
     return wrap->getCursorValTypes(b);
@@ -586,9 +599,10 @@ class PadIterator : public SimpleWrapIterator {
 public:
   PadIterator(std::unique_ptr<SparseIterator> &&wrap, Value padLow,
               Value padHigh)
-      : SimpleWrapIterator(std::move(wrap), IterKind::kPad), padLow(padLow),
-        padHigh(padHigh) {
-    assert(!randomAccessible() && "Not implemented.");
+      : SimpleWrapIterator(std::move(wrap), IterKind::kPad,
+                           wrap->randomAccessible() ? 1 : 0),
+        padLow(padLow), padHigh(padHigh) {
+    // assert(!randomAccessible());
   }
 
   // For LLVM-style RTTI.
@@ -600,6 +614,19 @@ class PadIterator : public SimpleWrapIterator {
     return std::string("pad<") + wrap->getDebugInterfacePrefix() + ">";
   }
 
+  // For the padded dense iterator, we append an `inPadZone: bool` in addition
+  // to the values used by the wrapped iterator.
+  ValueRange getCurPosition() const override { return getCursor(); }
+
+  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
+    SmallVector<Type> ret = wrap->getCursorValTypes(b);
+    // Need an extra boolean value `inPadZone` for the padded dense iterator.
+    if (randomAccessible())
+      ret.push_back(b.getI1Type());
+
+    return ret;
+  }
+
   // The upper bound after padding becomes `size + padLow + padHigh`.
   Value upperBound(OpBuilder &b, Location l) const override {
     return ADDI(ADDI(wrap->upperBound(b, l), padLow), padHigh);
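The two overrides above change the cursor shape: a random-accessible padded iterator now carries the wrapped iterator's cursor values plus one trailing i1, and getCurPosition deliberately returns the whole cursor, flag included, so a child level can peel the flag off. A hypothetical plain-data picture of that layout (not patch code):

#include <cstdint>
#include <vector>

// Hypothetical model of the padded iterator's cursor: the wrapped
// iterator's values first, then the `inPadZone` slot that corresponds to
// passing extraCursorVal = 1 to SimpleWrapIterator.
struct PadCursorModel {
  std::vector<int64_t> wrappedVals; // from wrap->getCursorValTypes(b)
  bool inPadZone;                   // the appended i1 value
};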
@@ -613,6 +640,14 @@ class PadIterator : public SimpleWrapIterator {
 
   void locateImpl(OpBuilder &b, Location l, Value crd) override {
     assert(randomAccessible());
+    wrap->locate(b, l, SUBI(crd, padLow));
+
+    // inPadZone = crd < padLow || crd >= size + padLow.
+    Value inPadLow = CMPI(ult, crd, padLow);
+    Value inPadHigh = CMPI(uge, crd, ADDI(wrap->upperBound(b, l), padLow));
+    getMutCursorVals().back() = ORI(inPadLow, inPadHigh);
+
+    updateCrd(crd);
   }
 
   Value padLow, padHigh;
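In plain C++, the locate logic above amounts to the following sketch (illustrative, not patch code; `size` stands for wrap->upperBound(b, l), the unpadded level size):

#include <cstdint>

// Illustrative analogue of PadIterator::locateImpl on plain integers.
struct LocateResult {
  int64_t wrappedCrd; // coordinate forwarded to the wrapped iterator
  bool inPadZone;     // crd < padLow || crd >= size + padLow
};

LocateResult locatePadded(int64_t crd, int64_t padLow, int64_t size) {
  LocateResult r;
  r.wrappedCrd = crd - padLow; // cf. SUBI(crd, padLow)
  r.inPadZone = crd < padLow || crd >= size + padLow;
  return r;
}

For example, with padLow = 2 and size = 4, crd = 1 falls in the low pad zone, crd = 3 maps to wrapped coordinate 1, and crd = 6 falls in the high pad zone.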
@@ -1227,6 +1262,33 @@ ValueRange NonEmptySubSectIterator::inflateSubSectTree(
   return p->inflateSubSectTree(b, l, reduc, visitDenseSubSect);
 }
 
+void TrivialIterator::genInitImpl(OpBuilder &b, Location l,
+                                  const SparseIterator *parent) {
+
+  if (isBatchIterator() && batchCrds.size() <= stl.lvl)
+    batchCrds.resize(stl.lvl + 1, nullptr);
+
+  Value c0 = C_IDX(0);
+  ValueRange pPos = c0;
+  Value inPadZone = nullptr;
+  // If the parent iterator is a batch iterator, we also start from 0 (but
+  // on a different batch).
+  if (parent && !parent->isBatchIterator()) {
+    pPos = parent->getCurPosition();
+    if (llvm::isa<PadIterator>(parent) && parent->randomAccessible()) {
+      // A padded dense iterator creates a "sparse" padded zone, which needs
+      // to be handled specially.
+      inPadZone = pPos.back();
+      pPos = pPos.drop_back();
+    }
+  }
+
+  ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
+  std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos, inPadZone);
+  // Seek to the lowest position.
+  seek(posLo);
+}
+
 void NonEmptySubSectIterator::genInitImpl(OpBuilder &b, Location l,
                                           const SparseIterator *) {
   Value c0 = C_IDX(0);
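Tying the pieces together, a plain-C++ sketch (hypothetical names, not patch code) of the cursor split performed above: when the parent is a random-accessible PadIterator, the last cursor value is the inPadZone flag and everything before it is the real position.

#include <cstdint>
#include <utility>
#include <vector>

// Illustrative analogue of the pPos.back() / pPos.drop_back() split in
// TrivialIterator::genInitImpl; the i1 flag is modeled as int64_t.
std::pair<std::vector<int64_t>, int64_t>
splitPaddedParentPos(std::vector<int64_t> pPos) {
  int64_t inPadZone = pPos.back(); // trailing value is the pad-zone flag
  pPos.pop_back();                 // only real positions reach peekRangeAt
  return {std::move(pPos), inPadZone};
}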