@@ -95,8 +95,9 @@ class DenseLevel : public SparseTensorLevel {
95
95
ValueRange getLvlBuffers () const override { return {}; }
96
96
97
97
ValuePair peekRangeAt (OpBuilder &b, Location l, ValueRange batchPrefix,
98
- ValueRange parentPos) const override {
98
+ ValueRange parentPos, Value inPadZone ) const override {
99
99
assert (parentPos.size () == 1 && " Dense level can not be non-unique." );
100
+ assert (!inPadZone && " Not implemented" );
100
101
Value p = parentPos.front ();
101
102
Value posLo = MULI (p, lvlSize);
102
103
return {posLo, lvlSize};
@@ -115,7 +116,8 @@ class BatchLevel : public SparseTensorLevel {
115
116
ValueRange getLvlBuffers () const override { return {}; }
116
117
117
118
ValuePair peekRangeAt (OpBuilder &b, Location l, ValueRange,
118
- ValueRange parentPos) const override {
119
+ ValueRange parentPos, Value inPadZone) const override {
120
+ assert (!inPadZone && " Not implemented" );
119
121
assert (parentPos.size () == 1 && " Dense level can not be non-unique." );
120
122
// No need to linearize the position for non-annotated tensors.
121
123
return {C_IDX (0 ), lvlSize};
@@ -129,18 +131,42 @@ class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
129
131
: SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
130
132
131
133
// Diff of CompressedLevel::peekRangeAt. Lines prefixed with `-` are removed,
// lines prefixed with `+` are added, bare numbers are old/new line numbers.
//
// Returns the pair {pLo, pHi} for the single parent position
// `parentPos.front()`: two index loads from the position buffer at `p` and
// `p + 1`, each addressed with `batchPrefix` followed by the position.
//
// New in this change: the extra `inPadZone` argument. When it is null the
// loads are emitted unconditionally (same IR as before). When it is non-null
// the loads are wrapped in an scf.if on `inPadZone` that yields the constant
// range [0, 0) in the then-branch and the loaded range in the else-branch;
// the two results of the scf.if are returned.
ValuePair peekRangeAt (OpBuilder &b, Location l, ValueRange batchPrefix,
132
- ValueRange parentPos) const override {
134
+ ValueRange parentPos, Value inPadZone ) const override {
133
135
134
136
assert (parentPos.size () == 1 &&
135
137
" compressed level must be the first non-unique level." );
136
- Value p = parentPos.front ();
137
138
138
- SmallVector<Value> memCrd (batchPrefix);
139
- memCrd.push_back (p);
140
- Value pLo = genIndexLoad (b, l, getPosBuf (), memCrd);
141
- memCrd.back () = ADDI (p, C_IDX (1 ));
142
- Value pHi = genIndexLoad (b, l, getPosBuf (), memCrd);
143
- return {pLo, pHi};
139
// The former function body, moved into a lambda so it can be emitted either
// directly or inside the else-branch of the scf.if below. It emits at
// whatever insertion point `b` has when the lambda is called.
+ auto loadRange = [&b, l, parentPos, batchPrefix, this ]() -> ValuePair {
140
+ Value p = parentPos.front ();
141
+ SmallVector<Value> memCrd (batchPrefix);
142
+ memCrd.push_back (p);
143
+ Value pLo = genIndexLoad (b, l, getPosBuf (), memCrd);
144
// Reuse the same index vector; only the last entry changes to `p + 1`.
+ memCrd.back () = ADDI (p, C_IDX (1 ));
145
+ Value pHi = genIndexLoad (b, l, getPosBuf (), memCrd);
146
+ return {pLo, pHi};
147
+ };
148
+
149
// No pad-zone predicate supplied: load unconditionally.
+ if (inPadZone == nullptr )
150
+ return loadRange ();
151
+
152
// The scf.if produces two results, both of index type.
+ SmallVector<Type, 2 > types{b.getIndexType (), b.getIndexType ()};
153
// NOTE(review): the trailing `true` is presumably the "with else region"
// flag of the scf::IfOp builder; confirm against the builder signature.
+ scf::IfOp posRangeIf = b.create <scf::IfOp>(l, types, inPadZone, true );
154
+ // True branch, returns a "fake" empty range [0, 0) if parent
155
+ // iterator is in pad zone.
156
+ b.setInsertionPointToStart (posRangeIf.thenBlock ());
157
+
158
+ SmallVector<Value, 2 > emptyRange{C_IDX (0 ), C_IDX (0 )};
159
+ b.create <scf::YieldOp>(l, emptyRange);
160
+
161
+ // False branch, returns the actual range.
162
+ b.setInsertionPointToStart (posRangeIf.elseBlock ());
163
+ auto [pLo, pHi] = loadRange ();
164
+ SmallVector<Value, 2 > loadedRange{pLo, pHi};
165
+ b.create <scf::YieldOp>(l, loadedRange);
166
+
167
// Continue emitting after the scf.if so callers do not end up inside a
// branch.
+ b.setInsertionPointAfter (posRangeIf);
168
+ ValueRange posRange = posRangeIf.getResults ();
169
+ return {posRange.front (), posRange.back ()};
144
170
}
145
171
};
146
172
@@ -151,9 +177,10 @@ class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
151
177
: SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
152
178
153
179
ValuePair peekRangeAt (OpBuilder &b, Location l, ValueRange batchPrefix,
154
- ValueRange parentPos) const override {
180
+ ValueRange parentPos, Value inPadZone ) const override {
155
181
assert (parentPos.size () == 1 &&
156
182
" loose-compressed level must be the first non-unique level." );
183
+ assert (!inPadZone && " Not implemented" );
157
184
SmallVector<Value> memCrd (batchPrefix);
158
185
Value p = parentPos.front ();
159
186
p = MULI (p, C_IDX (2 ));
@@ -172,8 +199,9 @@ class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
172
199
: SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
173
200
174
201
ValuePair peekRangeAt (OpBuilder &b, Location l, ValueRange batchPrefix,
175
- ValueRange parentPos) const override {
202
+ ValueRange parentPos, Value inPadZone ) const override {
176
203
assert (parentPos.size () == 1 || parentPos.size () == 2 );
204
+ assert (!inPadZone && " Not implemented" );
177
205
Value p = parentPos.front ();
178
206
Value segHi = parentPos.size () == 2 ? parentPos.back () : nullptr ;
179
207
@@ -191,9 +219,10 @@ class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
191
219
: SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
192
220
193
221
ValuePair peekRangeAt (OpBuilder &b, Location l, ValueRange batchPrefix,
194
- ValueRange parentPos) const override {
222
+ ValueRange parentPos, Value inPadZone ) const override {
195
223
assert (parentPos.size () == 1 && isUnique () &&
196
224
" n:m level can not be non-unique." );
225
+ assert (!inPadZone && " Not implemented" );
197
226
// Each n:m blk has exactly n specified elements.
198
227
auto n = getN (lt);
199
228
Value posLo = MULI (parentPos.front (), C_IDX (n));
@@ -325,23 +354,7 @@ class TrivialIterator : public ConcreteIterator {
325
354
};
326
355
327
356
// Diff of the in-class TrivialIterator::genInitImpl. The inline definition
// (all `-` lines) is removed and only a declaration remains (the final `+`
// line). The removed body computed the parent position (constant 0 unless a
// non-batch parent supplies its current position), called stl.peekRangeAt
// with it, stored the result in posLo/posHi and seeked to posLo.
void genInitImpl (OpBuilder &b, Location l,
328
- const SparseIterator *parent) override {
329
-
330
- if (isBatchIterator () && batchCrds.size () <= stl.lvl )
331
- batchCrds.resize (stl.lvl + 1 , nullptr );
332
-
333
- Value c0 = C_IDX (0 );
334
- ValueRange pPos = c0;
335
- // If the parent iterator is a batch iterator, we also start from 0 (but
336
- // on a different batch).
337
- if (parent && !parent->isBatchIterator ())
338
- pPos = parent->getCurPosition ();
339
-
340
- ValueRange batchPrefix = parent ? parent->getBatchCrds () : ValueRange{};
341
- std::tie (posLo, posHi) = stl.peekRangeAt (b, l, batchPrefix, pPos);
342
- // Seek to the lowest position.
343
- seek (posLo);
344
- }
357
// Declaration only; the body is no longer defined inline in the class.
+ const SparseIterator *parent) override ;
345
358
346
359
ValuePair genForCond (OpBuilder &b, Location l) override {
347
360
if (randomAccessible ())
@@ -465,15 +478,17 @@ class DedupIterator : public ConcreteIterator {
465
478
// A util base-iterator that delegates all methods to the wrapped iterator.
466
479
class SimpleWrapIterator : public SparseIterator {
467
480
public:
468
// Diff of the SimpleWrapIterator constructor. The new version gains an
// `extraCursorVal` parameter (default 0, so existing two-argument callers
// still compile) and forwards it to the SparseIterator base constructor.
// The base is initialized from `*wrap` before `wrap` is moved into the
// member of the same name.
// NOTE(review): `extraCursorVal` is presumably the number of additional
// cursor values the derived iterator appends beyond the wrapped iterator's;
// confirm against the SparseIterator constructor.
- SimpleWrapIterator (std::unique_ptr<SparseIterator> &&wrap, IterKind kind)
469
- : SparseIterator(kind, *wrap), wrap(std::move(wrap)) {}
481
+ SimpleWrapIterator (std::unique_ptr<SparseIterator> &&wrap, IterKind kind,
482
+ unsigned extraCursorVal = 0 )
483
+ : SparseIterator(kind, *wrap, extraCursorVal), wrap(std::move(wrap)) {}
470
484
471
485
// Pure delegation: each override below forwards to the same-named method on
// the wrapped iterator and returns its result unchanged. The only change the
// diff makes in this span is one added blank line.

// Cursor value types are exactly those of the wrapped iterator.
SmallVector<Type> getCursorValTypes (OpBuilder &b) const override {
472
486
return wrap->getCursorValTypes (b);
473
487
}
474
488
// Iterator-kind predicates, answered by the wrapped iterator.
bool isBatchIterator () const override { return wrap->isBatchIterator (); }
475
489
bool randomAccessible () const override { return wrap->randomAccessible (); };
476
490
bool iteratableByFor () const override { return wrap->iteratableByFor (); };
491
+
477
492
// State (de)serialization and current position, forwarded as-is.
SmallVector<Value> serialize () const override { return wrap->serialize (); };
478
493
void deserialize (ValueRange vs) override { wrap->deserialize (vs); };
479
494
ValueRange getCurPosition () const override { return wrap->getCurPosition (); }
@@ -586,10 +601,9 @@ class PadIterator : public SimpleWrapIterator {
586
601
public:
587
602
// Diff of the PadIterator constructor. The old version asserted
// `!randomAccessible()` with the message "Not implemented."; the new version
// drops that assertion and instead passes 1 as the extra-cursor-value count
// to SimpleWrapIterator when the wrapped iterator is random accessible, and 0
// otherwise. `padLow` and `padHigh` are stored unchanged.
//
// NOTE(review): `wrap->randomAccessible()` appears in the same argument list
// as `std::move(wrap)`. This is only safe because std::move is a cast and the
// SimpleWrapIterator constructor shown in this diff takes the pointer by
// rvalue reference, so ownership is not transferred until the base
// constructor's member initializer runs. Changing that parameter to
// by-value would make this read a moved-from pointer.
PadIterator (std::unique_ptr<SparseIterator> &&wrap, Value padLow,
588
603
Value padHigh)
589
- : SimpleWrapIterator(std::move(wrap), IterKind::kPad ), padLow(padLow),
590
- padHigh (padHigh) {
591
- assert (!randomAccessible () && " Not implemented." );
592
- }
604
+ : SimpleWrapIterator(std::move(wrap), IterKind::kPad ,
605
+ wrap->randomAccessible () ? 1 : 0),
606
+ padLow(padLow), padHigh(padHigh) {}
593
607
594
608
// For LLVM-style RTTI.
595
609
static bool classof (const SparseIterator *from) {
@@ -600,6 +614,26 @@ class PadIterator : public SimpleWrapIterator {
600
614
return std::string (" pad<" ) + wrap->getDebugInterfacePrefix () + " >" ;
601
615
}
602
616
617
+ // Returns a pair of values for *upper*, *lower* bound respectively.
618
+ ValuePair genForCond (OpBuilder &b, Location l) override {
619
+ if (randomAccessible ())
620
+ return {getCrd (), upperBound (b, l)};
621
+ return wrap->genForCond (b, l);
622
+ }
623
+
624
+ // For padded dense iterator, we append a `inPadZone: bool` in addition to
625
+ // values used by the wrapped iterator.
626
+ ValueRange getCurPosition () const override { return getCursor (); }
627
+
628
+ SmallVector<Type> getCursorValTypes (OpBuilder &b) const override {
629
+ SmallVector<Type> ret = wrap->getCursorValTypes (b);
630
+ // Need an extra boolean value `inPadZone` for padded dense iterator.
631
+ if (randomAccessible ())
632
+ ret.push_back (b.getI1Type ());
633
+
634
+ return ret;
635
+ }
636
+
603
637
// The upper bound after padding becomes `size + padLow + padHigh`.
604
638
Value upperBound (OpBuilder &b, Location l) const override {
605
639
return ADDI (ADDI (wrap->upperBound (b, l), padLow), padHigh);
@@ -613,6 +647,14 @@ class PadIterator : public SimpleWrapIterator {
613
647
614
648
// Positions a random-accessible padded iterator at coordinate `crd`.
// The wrapped iterator is located at `crd - padLow`, the last cursor value is
// set to the pad-zone predicate, and the current coordinate is updated.
void locateImpl(OpBuilder &b, Location l, Value crd) override {
  assert(randomAccessible());

  // Translate from padded coordinate space into the wrapped iterator's space.
  Value wrappedCrd = SUBI(crd, padLow);
  wrap->locate(b, l, wrappedCrd);

  // inPadZone = crd < padLow || crd >= size + padLow.
  Value belowData = CMPI(ult, crd, padLow);
  Value wrappedHi = wrap->upperBound(b, l);
  Value dataEnd = ADDI(wrappedHi, padLow);
  Value aboveData = CMPI(uge, crd, dataEnd);
  Value padZoneFlag = ORI(belowData, aboveData);

  // The flag lives in the last cursor slot.
  getMutCursorVals().back() = padZoneFlag;

  updateCrd(crd);
}
}
617
659
618
660
Value padLow, padHigh;
@@ -1227,6 +1269,33 @@ ValueRange NonEmptySubSectIterator::inflateSubSectTree(
1227
1269
return p->inflateSubSectTree (b, l, reduc, visitDenseSubSect);
1228
1270
}
1229
1271
1272
+ void TrivialIterator::genInitImpl (OpBuilder &b, Location l,
1273
+ const SparseIterator *parent) {
1274
+
1275
+ if (isBatchIterator () && batchCrds.size () <= stl.lvl )
1276
+ batchCrds.resize (stl.lvl + 1 , nullptr );
1277
+
1278
+ Value c0 = C_IDX (0 );
1279
+ ValueRange pPos = c0;
1280
+ Value inPadZone = nullptr ;
1281
+ // If the parent iterator is a batch iterator, we also start from 0 (but
1282
+ // on a different batch).
1283
+ if (parent && !parent->isBatchIterator ()) {
1284
+ pPos = parent->getCurPosition ();
1285
+ if (llvm::isa<PadIterator>(parent) && parent->randomAccessible ()) {
1286
+ // A padded dense iterator create "sparse" padded zone, which need to be
1287
+ // handled specially.
1288
+ inPadZone = pPos.back ();
1289
+ pPos = pPos.drop_back ();
1290
+ }
1291
+ }
1292
+
1293
+ ValueRange batchPrefix = parent ? parent->getBatchCrds () : ValueRange{};
1294
+ std::tie (posLo, posHi) = stl.peekRangeAt (b, l, batchPrefix, pPos, inPadZone);
1295
+ // Seek to the lowest position.
1296
+ seek (posLo);
1297
+ }
1298
+
1230
1299
void NonEmptySubSectIterator::genInitImpl (OpBuilder &b, Location l,
1231
1300
const SparseIterator *) {
1232
1301
Value c0 = C_IDX (0 );
0 commit comments