[mlir][sparse] rename files and unifies APIs #88162

Merged: 2 commits into llvm:main from the cherry-picks branch on Apr 9, 2024

Conversation

PeimingLiu (Member)

No description provided.

@llvmbot (Member) commented Apr 9, 2024

@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Peiming Liu (PeimingLiu)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/88162.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt (+1-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h (+1-1)
  • (renamed) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp (+41-23)
  • (renamed) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h (+1)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 3c0f82fc00bb9d..af3a1b48f45af9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -20,7 +20,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   Utils/IterationGraphSorter.cpp
   Utils/LoopEmitter.cpp
   Utils/SparseTensorDescriptor.cpp
-  Utils/SparseTensorLevel.cpp
+  Utils/SparseTensorIterator.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index b5a0ac8484abdd..59c3e49264dbe1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -11,7 +11,7 @@
 
 #include <vector>
 
-#include "SparseTensorLevel.h"
+#include "SparseTensorIterator.h"
 
 #include "mlir/Dialect/SparseTensor/IR/Enums.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
similarity index 96%
rename from mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
rename to mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index bc27fae5d19480..3f2ee89c854d69 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "SparseTensorLevel.h"
+#include "SparseTensorIterator.h"
 #include "CodegenUtils.h"
 
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -46,21 +46,41 @@ using ValueTuple = std::tuple<Value, Value, Value>;
 
 namespace {
 
+template <bool hasPosBuffer>
 class SparseLevel : public SparseTensorLevel {
+  // It is either an array of size 2 or size 1, depending on whether the
+  // sparse level requires a position array.
+  using BufferT = std::conditional_t<hasPosBuffer, std::array<Value, 2>,
+                                     std::array<Value, 1>>;
+
 public:
   SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
-              Value crdBuffer)
-      : SparseTensorLevel(tid, lvl, lt, lvlSize), crdBuffer(crdBuffer) {}
+              BufferT buffers)
+      : SparseTensorLevel(tid, lvl, lt, lvlSize), buffers(buffers) {}
+
+  ValueRange getLvlBuffers() const override { return buffers; }
 
   Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                   Value iv) const override {
     SmallVector<Value> memCrd(batchPrefix);
     memCrd.push_back(iv);
-    return genIndexLoad(b, l, crdBuffer, memCrd);
+    return genIndexLoad(b, l, getCrdBuf(), memCrd);
   }
 
 protected:
-  const Value crdBuffer;
+  template <typename T = void, typename = std::enable_if_t<hasPosBuffer, T>>
+  Value getPosBuf() const {
+    return buffers[0];
+  }
+
+  Value getCrdBuf() const {
+    if constexpr (hasPosBuffer)
+      return buffers[1];
+    else
+      return buffers[0];
+  }
+
+  const BufferT buffers;
 };
 
 class DenseLevel : public SparseTensorLevel {
@@ -72,6 +92,8 @@ class DenseLevel : public SparseTensorLevel {
     llvm_unreachable("locate random-accessible level instead");
   }
 
+  ValueRange getLvlBuffers() const override { return {}; }
+
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
                         Value max) const override {
     Value posLo = MULI(p, lvlSize);
@@ -88,6 +110,8 @@ class BatchLevel : public SparseTensorLevel {
     llvm_unreachable("locate random-accessible level instead");
   }
 
+  ValueRange getLvlBuffers() const override { return {}; }
+
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
                         Value max) const override {
     assert(max == nullptr && "Dense level can not be non-unique.");
@@ -96,11 +120,11 @@ class BatchLevel : public SparseTensorLevel {
   }
 };
 
-class CompressedLevel : public SparseLevel {
+class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
 public:
   CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
                   Value posBuffer, Value crdBuffer)
-      : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
+      : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                         Value p, Value max) const override {
@@ -109,21 +133,18 @@ class CompressedLevel : public SparseLevel {
 
     SmallVector<Value> memCrd(batchPrefix);
     memCrd.push_back(p);
-    Value pLo = genIndexLoad(b, l, posBuffer, memCrd);
+    Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
     memCrd.back() = ADDI(p, C_IDX(1));
-    Value pHi = genIndexLoad(b, l, posBuffer, memCrd);
+    Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
     return {pLo, pHi};
   }
-
-private:
-  const Value posBuffer;
 };
 
-class LooseCompressedLevel : public SparseLevel {
+class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
 public:
   LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
                        Value posBuffer, Value crdBuffer)
-      : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
+      : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                         Value p, Value max) const override {
@@ -133,21 +154,18 @@ class LooseCompressedLevel : public SparseLevel {
 
     p = MULI(p, C_IDX(2));
     memCrd.push_back(p);
-    Value pLo = genIndexLoad(b, l, posBuffer, memCrd);
+    Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
     memCrd.back() = ADDI(p, C_IDX(1));
-    Value pHi = genIndexLoad(b, l, posBuffer, memCrd);
+    Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
     return {pLo, pHi};
   }
-
-private:
-  const Value posBuffer;
 };
 
-class SingletonLevel : public SparseLevel {
+class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
 public:
   SingletonLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
                  Value crdBuffer)
-      : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
+      : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                         Value p, Value segHi) const override {
@@ -159,11 +177,11 @@ class SingletonLevel : public SparseLevel {
   }
 };
 
-class NOutOfMLevel : public SparseLevel {
+class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
 public:
   NOutOfMLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
                Value crdBuffer)
-      : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
+      : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                         Value p, Value max) const override {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
similarity index 99%
rename from mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
rename to mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 9f92eecdf75cb6..19c0dc942ca62f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -55,6 +55,7 @@ class SparseTensorLevel {
   Level getLevel() const { return lvl; }
   LevelType getLT() const { return lt; }
   Value getSize() const { return lvlSize; }
+  virtual ValueRange getLvlBuffers() const = 0;
 
   //
   // Level properties
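
As a minimal standalone sketch of the compile-time buffer-selection pattern used by the templated SparseLevel above (the LevelBuffers class and its int-valued Value are illustrative stand-ins, not MLIR code): std::conditional_t chooses an array of one or two buffer slots, getPosBuf is SFINAE-guarded so it only exists for levels that carry a position buffer, and getCrdBuf uses if constexpr to pick the right slot.

// Illustrative sketch only: Value is stubbed as int; in the PR it is mlir::Value.
#include <array>
#include <cstdio>
#include <type_traits>

using Value = int; // stand-in for mlir::Value

template <bool hasPosBuffer>
class LevelBuffers {
  // One slot for the coordinate buffer, plus an optional leading slot for
  // the position buffer, chosen at compile time.
  using BufferT = std::conditional_t<hasPosBuffer, std::array<Value, 2>,
                                     std::array<Value, 1>>;

public:
  explicit LevelBuffers(BufferT buffers) : buffers(buffers) {}

  // Only instantiable for levels that carry a position buffer.
  template <typename T = void, typename = std::enable_if_t<hasPosBuffer, T>>
  Value getPosBuf() const {
    return buffers[0];
  }

  // The coordinate buffer follows the position buffer, if there is one.
  Value getCrdBuf() const {
    if constexpr (hasPosBuffer)
      return buffers[1];
    else
      return buffers[0];
  }

private:
  const BufferT buffers;
};

int main() {
  LevelBuffers</*hasPosBuffer=*/true> compressed({/*pos=*/10, /*crd=*/11});
  LevelBuffers</*hasPosBuffer=*/false> singleton({/*crd=*/42});
  std::printf("%d %d %d\n", compressed.getPosBuf(), compressed.getCrdBuf(),
              singleton.getCrdBuf());
}

Compiled as C++17, this prints the three stand-in buffer handles; in the PR the corresponding accessors return mlir::Value handles to the position and coordinate buffers.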

PeimingLiu merged commit a454d92 into llvm:main on Apr 9, 2024.
PeimingLiu deleted the cherry-picks branch on Apr 9, 2024 at 17:59.
Labels: mlir:sparse (Sparse compiler in MLIR), mlir