Skip to content

[mlir][sparse] Support pretty print to debug sparse iteration. #80207

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 1, 2024

Conversation

PeimingLiu
Copy link
Member

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Jan 31, 2024

@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Peiming Liu (PeimingLiu)

Changes

Patch is 71.57 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/80207.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h (+17-2)
  • (modified) mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td (+7)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp (+3-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp (+1-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp (+3-2)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h (+1-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp (+8-5)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h (+12-8)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp (+203-112)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h (+108-64)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir (+48-228)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index e93e2aefb344f..8b2875a751d4a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -47,6 +47,12 @@ enum class ReinterpretMapScope {
   kExceptGeneric, // reinterprets operation other than linalg.generic
 };
 
+/// Defines a scope for reinterpret map pass.
+enum class DebugSparseIteration {
+  kNone,          // generate fully inlined (and functional) sparse iteration
+  kInterfaceOnly, // generate only place-holder for sparse iteration
+};
+
 #define GEN_PASS_DECL
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
 
@@ -74,11 +80,20 @@ std::unique_ptr<Pass> createPreSparsificationRewritePass();
 
 /// Options for the Sparsification pass.
 struct SparsificationOptions {
+  SparsificationOptions(SparseParallelizationStrategy p, DebugSparseIteration d,
+                        bool enableRT)
+      : parallelizationStrategy(p), debugSparseIteration(d),
+        enableRuntimeLibrary(enableRT) {}
+
   SparsificationOptions(SparseParallelizationStrategy p, bool enableRT)
-      : parallelizationStrategy(p), enableRuntimeLibrary(enableRT) {}
+      : SparsificationOptions(p, DebugSparseIteration::kNone, enableRT) {}
+
   SparsificationOptions()
-      : SparsificationOptions(SparseParallelizationStrategy::kNone, true) {}
+      : SparsificationOptions(SparseParallelizationStrategy::kNone,
+                              DebugSparseIteration::kNone, true) {}
+
   SparseParallelizationStrategy parallelizationStrategy;
+  DebugSparseIteration debugSparseIteration;
   bool enableRuntimeLibrary;
 };
 
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index f38779ed9ed2b..126b91510d391 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -130,6 +130,13 @@ def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
              clEnumValN(mlir::SparseParallelizationStrategy::kAnyStorageAnyLoop,
                         "any-storage-any-loop",
                         "Enable sparse parallelization for any storage and loop."))}]>,
+    Option<"debugSparseIteration", "debug-sparse-iteration", "mlir::DebugSparseIteration",
+           "mlir::DebugSparseIteration::kNone",
+           "Pretty print sparse loops to debug sparse iteration", [{llvm::cl::values(
+             clEnumValN(mlir::DebugSparseIteration::kNone, "none",
+                        "Turn off pretty printing and generates functional code."),
+             clEnumValN(mlir::DebugSparseIteration::kInterfaceOnly, "interface-only",
+                        "Generate non-functional interfaces for sparse iteration."))}]>,
     Option<"enableRuntimeLibrary", "enable-runtime-library", "bool",
            "true", "Enable runtime library for manipulating sparse tensors">,
   ];
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 375e10f9068e4..0ae9f6483588d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -82,13 +82,15 @@ struct SparsificationPass
   SparsificationPass(const SparsificationPass &pass) = default;
   SparsificationPass(const SparsificationOptions &options) {
     parallelization = options.parallelizationStrategy;
+    debugSparseIteration = options.debugSparseIteration;
     enableRuntimeLibrary = options.enableRuntimeLibrary;
   }
 
   void runOnOperation() override {
     auto *ctx = &getContext();
     // Translate strategy flags to strategy options.
-    SparsificationOptions options(parallelization, enableRuntimeLibrary);
+    SparsificationOptions options(parallelization, debugSparseIteration,
+                                  enableRuntimeLibrary);
     // Apply sparsification and cleanup rewriting.
     RewritePatternSet patterns(ctx);
     populateSparsificationPatterns(patterns, options);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 5266ca7213bfc..2ceb214052aa3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1369,7 +1369,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
       return failure();
 
     // Recursively generates code if admissible.
-    env.startEmit();
+    env.startEmit(options.debugSparseIteration);
     genBuffers(env, rewriter);
     // TODO: Constant affine expression should be handled differently when using
     // slice-based codegen, it does not matter now because we already reject the
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp
index d3de55e4d59bd..0af1cc1745f51 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp
@@ -59,7 +59,7 @@ LogicalResult CodegenEnv::initTensorExp() {
   return success();
 }
 
-void CodegenEnv::startEmit() {
+void CodegenEnv::startEmit(DebugSparseIteration emitStrategy) {
   assert(insChain == nullptr && "must only start emitting once");
   if (sparseOut) {
     insChain = sparseOut->get();
@@ -96,7 +96,8 @@ void CodegenEnv::startEmit() {
       /*dependentLvlGetter=*/
       [this](TensorId t, Level lvl) -> std::vector<LoopCoeffPair> {
         return merger().getDependentLoops(t, lvl);
-      });
+      },
+      emitStrategy);
 }
 
 std::optional<Operation *> CodegenEnv::genLoopBoundary(
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
index 728af841cc7b1..7eeddac48f4f1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
@@ -52,7 +52,7 @@ class CodegenEnv {
   Merger &merger() { return latticeMerger; }
   LoopEmitter &emitter() { return loopEmitter; }
 
-  void startEmit();
+  void startEmit(DebugSparseIteration emitStrategy);
 
   /// Generates loop boundary statements (entering/exiting loops). The function
   /// passes and updates the passed-in parameters.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index 3fa4004ae460e..8c1680a393181 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -81,17 +81,20 @@ static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
 
 LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
                          bool isSparseOut, unsigned numLoops,
-                         DependentLvlGetter dimGetter) {
+                         DependentLvlGetter dimGetter,
+                         DebugSparseIteration emitStrategy) {
   initialize(tensors, loopTag, hasOutput, isSparseOut, numLoops, dimGetter);
 }
 
 void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
                              bool isSparseOut, unsigned numLoops,
-                             DependentLvlGetter dimGetter) {
+                             DependentLvlGetter dimGetter,
+                             DebugSparseIteration emitStrategy) {
   // First initialize the top-level type of the fields.
   this->loopTag = loopTag;
   this->hasOutput = hasOutput;
   this->isSparseOut = isSparseOut;
+  SparseIterator::setDebugSparseIteration(emitStrategy);
 
   const unsigned numManifestTensors = ts.size();
   const unsigned synTensorId = numManifestTensors;
@@ -169,7 +172,7 @@ LoopEmitter::makeLevelIterator(OpBuilder &builder, Location loc, TensorId t,
     Value offset = genSliceOffset(builder, loc, tensors[t], l);
     Value stride = genSliceStride(builder, loc, tensors[t], l);
     auto slicedIt = makeSlicedLevelIterator(std::move(it), offset, stride,
-                                            lvls[t][l]->size());
+                                            lvls[t][l]->getSize());
     return slicedIt;
   }
   return it;
@@ -465,7 +468,7 @@ std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
 
   // Construct the while-loop with a parameter for each coordinate.
   for (SparseIterator *it : spIters) {
-    ValueRange itVals = it->getItVals();
+    ValueRange itVals = it->getCursor();
     ivs.append(itVals.begin(), itVals.end());
   }
 
@@ -724,7 +727,7 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
       // Forward the sparse iterator.
       Value cmp = CMPI(eq, it.getCrd(), iv);
       it.forwardIf(builder, loc, cmp);
-      operands.append(it.getItVals().begin(), it.getItVals().end());
+      operands.append(it.getCursor().begin(), it.getCursor().end());
       // const Value newPos = whileOp->getResult(o++);
       // Following loops continue iteration from the break point of the
       // current while loop.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index d0f447d926f71..e0b4f81487a68 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -15,6 +15,7 @@
 
 #include "mlir/Dialect/SparseTensor/IR/Enums.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
 #include "mlir/IR/PatternMatch.h"
 
@@ -84,14 +85,17 @@ class LoopEmitter {
   /// `isSparseOut` indicates that the sparse output tensor is empty,
   /// so the loop emitter will generate loops over it according to the
   /// level-sizes.
-  void initialize(ValueRange tensors, StringAttr loopTag = nullptr,
-                  bool hasOutput = false, bool isSparseOut = false,
-                  unsigned numLoops = 0, DependentLvlGetter getter = nullptr);
-
-  explicit LoopEmitter(ValueRange tensors, StringAttr loopTag = nullptr,
-                       bool hasOutput = false, bool isSparseOut = false,
-                       unsigned numLoops = 0,
-                       DependentLvlGetter getter = nullptr);
+  void
+  initialize(ValueRange tensors, StringAttr loopTag = nullptr,
+             bool hasOutput = false, bool isSparseOut = false,
+             unsigned numLoops = 0, DependentLvlGetter getter = nullptr,
+             DebugSparseIteration emitStrategy = DebugSparseIteration::kNone);
+
+  explicit LoopEmitter(
+      ValueRange tensors, StringAttr loopTag = nullptr, bool hasOutput = false,
+      bool isSparseOut = false, unsigned numLoops = 0,
+      DependentLvlGetter getter = nullptr,
+      DebugSparseIteration emitStrategy = DebugSparseIteration::kNone);
 
   /// Starts a loop emitting session by generating all the buffers needed
   /// for iterating over the tensors.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index 98323c2195461..bdaf794744bea 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -46,20 +46,6 @@ using ValueTuple = std::tuple<Value, Value, Value>;
 
 namespace {
 
-class SparseLevel : public SparseTensorLevel {
-public:
-  SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
-              Value crdBuffer)
-      : SparseTensorLevel(tid, lvl, lt, lvlSize), crdBuffer(crdBuffer) {}
-
-  Value peekCrdAt(OpBuilder &b, Location l, Value iv) const override {
-    return genIndexLoad(b, l, crdBuffer, iv);
-  }
-
-protected:
-  const Value crdBuffer;
-};
-
 class DenseLevel : public SparseTensorLevel {
 public:
   DenseLevel(unsigned tid, Level lvl, Value lvlSize, bool encoded)
@@ -74,53 +60,27 @@ class DenseLevel : public SparseTensorLevel {
                         Value max) const override {
     assert(max == nullptr && "Dense level can not be non-unique.");
     if (encoded) {
-      Value posLo = MULI(p, lvlSize);
-      return {posLo, lvlSize};
+      Value posLo = MULI(p, getSize());
+      return {posLo, getSize()};
     }
     // No need to linearize the position for non-annotated tensors.
-    return {C_IDX(0), lvlSize};
+    return {C_IDX(0), getSize()};
   }
 
   const bool encoded;
 };
 
-class CompressedLevel : public SparseLevel {
+class SparseLevel : public SparseTensorLevel {
 public:
-  CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
-                  Value posBuffer, Value crdBuffer)
-      : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
-
-  ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
-                        Value max) const override {
-    if (max == nullptr) {
-      Value pLo = genIndexLoad(b, l, posBuffer, p);
-      Value pHi = genIndexLoad(b, l, posBuffer, ADDI(p, C_IDX(1)));
-      return {pLo, pHi};
-    }
-    llvm_unreachable("compressed-nu should be the first non-unique level.");
+  SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
+              ValueRange lvlBuf)
+      : SparseTensorLevel(tid, lvl, lt, lvlSize, lvlBuf) {
+    assert(!lvlBuf.empty());
   }
 
-private:
-  const Value posBuffer;
-};
-
-class LooseCompressedLevel : public SparseLevel {
-public:
-  LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
-                       Value posBuffer, Value crdBuffer)
-      : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
-
-  ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
-                        Value max) const override {
-    assert(max == nullptr && "loss compressed level can not be non-unique.");
-    p = MULI(p, C_IDX(2));
-    Value pLo = genIndexLoad(b, l, posBuffer, p);
-    Value pHi = genIndexLoad(b, l, posBuffer, ADDI(p, C_IDX(1)));
-    return {pLo, pHi};
+  Value peekCrdAt(OpBuilder &b, Location l, Value iv) const override {
+    return genIndexLoad(b, l, getLvlBufs().front(), iv);
   }
-
-private:
-  const Value posBuffer;
 };
 
 class SingletonLevel : public SparseLevel {
@@ -142,8 +102,8 @@ class SingletonLevel : public SparseLevel {
 class TwoOutFourLevel : public SparseLevel {
 public:
   TwoOutFourLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
-                  Value crdBuffer)
-      : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
+                  Value crdBuf)
+      : SparseLevel(tid, lvl, lt, lvlSize, crdBuf) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
                         Value max) const override {
@@ -154,6 +114,39 @@ class TwoOutFourLevel : public SparseLevel {
   }
 };
 
+class CompressedLevel : public SparseLevel {
+public:
+  CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
+                  Value posBuffer, Value crdBuffer)
+      : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer, posBuffer}) {}
+
+  ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
+                        Value max) const override {
+    if (max == nullptr) {
+      Value pLo = genIndexLoad(b, l, getPosBuf(), p);
+      Value pHi = genIndexLoad(b, l, getPosBuf(), ADDI(p, C_IDX(1)));
+      return {pLo, pHi};
+    }
+    llvm_unreachable("compressed-nu should be the first non-unique level.");
+  }
+};
+
+class LooseCompressedLevel : public SparseLevel {
+public:
+  LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
+                       Value posBuffer, Value crdBuffer)
+      : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer, posBuffer}) {}
+
+  ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
+                        Value max) const override {
+    assert(max == nullptr && "loss compressed level can not be non-unique.");
+    p = MULI(p, C_IDX(2));
+    Value pLo = genIndexLoad(b, l, getPosBuf(), p);
+    Value pHi = genIndexLoad(b, l, getPosBuf(), ADDI(p, C_IDX(1)));
+    return {pLo, pHi};
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -203,7 +196,8 @@ static Value offsetFromMinCrd(OpBuilder &b, Location l, Value minCrd,
 // SparseIterator derived classes.
 //===----------------------------------------------------------------------===//
 
-namespace {
+namespace mlir {
+namespace sparse_tensor {
 
 // The iterator that traverses a concrete sparse tensor levels. High-level
 // abstract iterators wrap it to achieve more complex goals (such as collapsing
@@ -212,12 +206,11 @@ namespace {
 class ConcreteIterator : public SparseIterator {
 protected:
   ConcreteIterator(const SparseTensorLevel &stl, IterKind kind,
-                   unsigned itValCnt)
-      : SparseIterator(kind, stl.tid, stl.lvl, itValCnt, itValsStorage),
-        stl(stl) {
-    // Allocate enough storage for iterator values.
-    itValsStorage.resize(itValCnt);
-  }
+                   unsigned cursorValCnt)
+      : SparseIterator(kind, stl.tid, stl.lvl, cursorValCnt, cursorValsStorage),
+        stl(stl), cursorValsStorage(cursorValCnt, nullptr) {
+    assert(getCursor().size() == cursorValCnt);
+  };
 
 public:
   // For LLVM-style RTTI.
@@ -228,22 +221,34 @@ class ConcreteIterator : public SparseIterator {
   bool randomAccessible() const override { return isDenseLT(stl.getLT()); };
   bool iteratableByFor() const override { return kind != IterKind::kDedup; };
   Value upperBound(OpBuilder &b, Location l) const override {
-    return stl.size();
+    return stl.getSize();
   };
 
 protected:
+  const SparseTensorLevel &stl;
   // Owner of the storage, all wrappers build on top of a concrete iterator
   // share the same storage such that the iterator values are always
   // synchronized.
-  SmallVector<Value> itValsStorage;
-  const SparseTensorLevel &stl;
+  SmallVector<Value> cursorValsStorage;
 };
 
+} // namespace sparse_tensor
+} // namespace mlir
+
+namespace {
+
 class TrivialIterator : public ConcreteIterator {
 public:
   TrivialIterator(const SparseTensorLevel &stl)
       : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1) {}
 
+  std::string getDebugInterfacePrefix() const override {
+    return std::string("trivial<") + stl.toString() + ">";
+  }
+  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
+    return {b.getIndexType()};
+  }
+
   SmallVector<Value> serialize() const override {
     SmallVector<Value> ret;
     ret.push_back(getItPos());
@@ -286,12 +291,12 @@ class TrivialIterator : public ConcreteIterator {
     return std::make_pair(getItPos(), posHi);
   }
 
-  Value genNotEnd(OpBuilder &b, Location l) override {
+  Value genNotEndImpl(OpBuilder &b, Location l) override {
     // We used the first level bound as the bound the collapsed set of levels.
     return CMPI(ult, getItPos(), posHi);
   }
 
-  Value deref(OpBuilder &b, Location l) override {
+  Value derefImpl(OpBuilder &b, Location l) override {
     if (randomAccessible()) {
       updateCrd(SUBI(getItPos(), posLo));
     } else {
@@ -302,24 +307,24 @@ class TrivialIterator : public ConcreteIterator {
 
   ValueRange forwardImpl(OpBuilder &b, Location l) override {
     seek(ADDI(getItPos(), C_IDX(1)));
-    return getItVals();
+    return getCursor();
   }
 
   ValueRange forwardIf(OpBuilder &b, Location l, Value cond) override {
-    Value curPos = getItVals().front();
+    Value curPos = getCursor().front();
     Value nxPos = forward(b, l).front();
     seek(SELECT(cond, nxPos, curPos));
-    return getItVals();
+    return getCursor();
   }
 
-  void locate(OpBuilder &b, Location l, Value crd) override {
+  void locateImpl(OpBuilder &b, Location l, Value crd) override {
     assert(randomAccessible());
     // Seek to the linearized position.
     seek(ADDI(crd, posLo));
     updateCrd(crd);
   }
 
-  Value getItPos() const { return getItVals().front(); }
+  Value getItPos() const { return getCursor().front(); }
   Value posLo, posHi;
 };
 
@@ -337,6 +342,13 @@ class DedupIterator : public ...
[truncated]

@PeimingLiu PeimingLiu merged commit 4a653b4 into llvm:main Feb 1, 2024
@PeimingLiu PeimingLiu deleted the tensor-levels branch February 1, 2024 23:28
agozillon pushed a commit to agozillon/llvm-project that referenced this pull request Feb 5, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants