@@ -462,11 +462,54 @@ class DedupIterator : public ConcreteIterator {
462
462
Value posHi;
463
463
};
464
464
465
+ // A util base-iterator that delegates all methods to the wrapped iterator.
466
+ class SimpleWrapIterator : public SparseIterator {
467
+ public:
468
+ SimpleWrapIterator (std::unique_ptr<SparseIterator> &&wrap, IterKind kind)
469
+ : SparseIterator(kind, *wrap), wrap(std::move(wrap)) {}
470
+
471
+ SmallVector<Type> getCursorValTypes (OpBuilder &b) const override {
472
+ return wrap->getCursorValTypes (b);
473
+ }
474
+ bool isBatchIterator () const override { return wrap->isBatchIterator (); }
475
+ bool randomAccessible () const override { return wrap->randomAccessible (); };
476
+ bool iteratableByFor () const override { return wrap->iteratableByFor (); };
477
+ SmallVector<Value> serialize () const override { return wrap->serialize (); };
478
+ void deserialize (ValueRange vs) override { wrap->deserialize (vs); };
479
+ ValueRange getCurPosition () const override { return wrap->getCurPosition (); }
480
+ void genInitImpl (OpBuilder &b, Location l,
481
+ const SparseIterator *parent) override {
482
+ wrap->genInit (b, l, parent);
483
+ }
484
+ Value genNotEndImpl (OpBuilder &b, Location l) override {
485
+ return wrap->genNotEndImpl (b, l);
486
+ }
487
+ ValueRange forwardImpl (OpBuilder &b, Location l) override {
488
+ return wrap->forward (b, l);
489
+ };
490
+ Value upperBound (OpBuilder &b, Location l) const override {
491
+ return wrap->upperBound (b, l);
492
+ };
493
+
494
+ Value derefImpl (OpBuilder &b, Location l) override {
495
+ return wrap->derefImpl (b, l);
496
+ }
497
+
498
+ void locateImpl (OpBuilder &b, Location l, Value crd) override {
499
+ return wrap->locate (b, l, crd);
500
+ }
501
+
502
+ SparseIterator &getWrappedIterator () const { return *wrap; }
503
+
504
+ protected:
505
+ std::unique_ptr<SparseIterator> wrap;
506
+ };
507
+
465
508
//
466
509
// A filter iterator wrapped from another iterator. The filter iterator update
467
510
// the wrapped iterator *in-place*.
468
511
//
469
- class FilterIterator : public SparseIterator {
512
+ class FilterIterator : public SimpleWrapIterator {
470
513
// Coorindate translation between crd loaded from the wrap iterator and the
471
514
// filter iterator.
472
515
Value fromWrapCrd (OpBuilder &b, Location l, Value wrapCrd) const {
@@ -487,8 +530,8 @@ class FilterIterator : public SparseIterator {
487
530
// when crd always < size.
488
531
FilterIterator (std::unique_ptr<SparseIterator> &&wrap, Value offset,
489
532
Value stride, Value size)
490
- : SparseIterator(IterKind:: kFilter , *wrap ), offset(offset),
491
- stride (stride), size(size), wrap(std::move(wrap)) {}
533
+ : SimpleWrapIterator(std::move(wrap), IterKind:: kFilter ), offset(offset),
534
+ stride (stride), size(size) {}
492
535
493
536
// For LLVM-style RTTI.
494
537
static bool classof (const SparseIterator *from) {
@@ -498,19 +541,10 @@ class FilterIterator : public SparseIterator {
498
541
std::string getDebugInterfacePrefix () const override {
499
542
return std::string (" filter<" ) + wrap->getDebugInterfacePrefix () + " >" ;
500
543
}
501
- SmallVector<Type> getCursorValTypes (OpBuilder &b) const override {
502
- return wrap->getCursorValTypes (b);
503
- }
504
544
505
- bool isBatchIterator () const override { return wrap->isBatchIterator (); }
506
- bool randomAccessible () const override { return wrap->randomAccessible (); };
507
545
bool iteratableByFor () const override { return randomAccessible (); };
508
546
Value upperBound (OpBuilder &b, Location l) const override { return size; };
509
547
510
- SmallVector<Value> serialize () const override { return wrap->serialize (); };
511
- void deserialize (ValueRange vs) override { wrap->deserialize (vs); };
512
- ValueRange getCurPosition () const override { return wrap->getCurPosition (); }
513
-
514
548
void genInitImpl (OpBuilder &b, Location l,
515
549
const SparseIterator *parent) override {
516
550
wrap->genInit (b, l, parent);
@@ -541,7 +575,47 @@ class FilterIterator : public SparseIterator {
541
575
ValueRange forwardImpl (OpBuilder &b, Location l) override ;
542
576
543
577
Value offset, stride, size;
544
- std::unique_ptr<SparseIterator> wrap;
578
+ };
579
+
580
+ //
581
+ // A pad iterator wrapped from another iterator. The pad iterator updates
582
+ // the wrapped iterator *in-place*.
583
+ //
584
+ class PadIterator : public SimpleWrapIterator {
585
+
586
+ public:
587
+ PadIterator (std::unique_ptr<SparseIterator> &&wrap, Value padLow,
588
+ Value padHigh)
589
+ : SimpleWrapIterator(std::move(wrap), IterKind::kPad ), padLow(padLow),
590
+ padHigh (padHigh) {
591
+ assert (!randomAccessible () && " Not implemented." );
592
+ }
593
+
594
+ // For LLVM-style RTTI.
595
+ static bool classof (const SparseIterator *from) {
596
+ return from->kind == IterKind::kPad ;
597
+ }
598
+
599
+ std::string getDebugInterfacePrefix () const override {
600
+ return std::string (" pad<" ) + wrap->getDebugInterfacePrefix () + " >" ;
601
+ }
602
+
603
+ // The upper bound after padding becomes `size + padLow + padHigh`.
604
+ Value upperBound (OpBuilder &b, Location l) const override {
605
+ return ADDI (ADDI (wrap->upperBound (b, l), padLow), padHigh);
606
+ };
607
+
608
+ // The pad_coord = coord + pad_lo
609
+ Value derefImpl (OpBuilder &b, Location l) override {
610
+ updateCrd (ADDI (wrap->deref (b, l), padLow));
611
+ return getCrd ();
612
+ }
613
+
614
+ void locateImpl (OpBuilder &b, Location l, Value crd) override {
615
+ assert (randomAccessible ());
616
+ }
617
+
618
+ Value padLow, padHigh;
545
619
};
546
620
547
621
class NonEmptySubSectIterator : public SparseIterator {
@@ -1408,10 +1482,19 @@ sparse_tensor::makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit,
1408
1482
return ret;
1409
1483
}
1410
1484
1485
+ std::unique_ptr<SparseIterator>
1486
+ sparse_tensor::makePaddedIterator (std::unique_ptr<SparseIterator> &&sit,
1487
+ Value padLow, Value padHigh,
1488
+ SparseEmitStrategy strategy) {
1489
+ auto ret = std::make_unique<PadIterator>(std::move (sit), padLow, padHigh);
1490
+ ret->setSparseEmitStrategy (strategy);
1491
+ return ret;
1492
+ }
1493
+
1411
1494
static const SparseIterator *tryUnwrapFilter (const SparseIterator *it) {
1412
1495
auto *filter = llvm::dyn_cast_or_null<FilterIterator>(it);
1413
1496
if (filter)
1414
- return filter->wrap . get ();
1497
+ return & filter->getWrappedIterator ();
1415
1498
return it;
1416
1499
}
1417
1500
0 commit comments