[mlir][sparse] rename files and unifies APIs #88162
Merged
Conversation
@llvm/pr-subscribers-mlir-sparse @llvm/pr-subscribers-mlir

Author: Peiming Liu (PeimingLiu)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/88162.diff

4 Files Affected:
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 3c0f82fc00bb9d..af3a1b48f45af9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -20,7 +20,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
Utils/IterationGraphSorter.cpp
Utils/LoopEmitter.cpp
Utils/SparseTensorDescriptor.cpp
- Utils/SparseTensorLevel.cpp
+ Utils/SparseTensorIterator.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index b5a0ac8484abdd..59c3e49264dbe1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -11,7 +11,7 @@
#include <vector>
-#include "SparseTensorLevel.h"
+#include "SparseTensorIterator.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
similarity index 96%
rename from mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
rename to mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index bc27fae5d19480..3f2ee89c854d69 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "SparseTensorLevel.h"
+#include "SparseTensorIterator.h"
#include "CodegenUtils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -46,21 +46,41 @@ using ValueTuple = std::tuple<Value, Value, Value>;
namespace {
+template <bool hasPosBuffer>
class SparseLevel : public SparseTensorLevel {
+ // It is either an array of size 2 or size 1, depending on whether the sparse
+ // level requires a position array.
+ using BufferT = std::conditional_t<hasPosBuffer, std::array<Value, 2>,
+ std::array<Value, 1>>;
+
public:
SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
- Value crdBuffer)
- : SparseTensorLevel(tid, lvl, lt, lvlSize), crdBuffer(crdBuffer) {}
+ BufferT buffers)
+ : SparseTensorLevel(tid, lvl, lt, lvlSize), buffers(buffers) {}
+
+ ValueRange getLvlBuffers() const override { return buffers; }
Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value iv) const override {
SmallVector<Value> memCrd(batchPrefix);
memCrd.push_back(iv);
- return genIndexLoad(b, l, crdBuffer, memCrd);
+ return genIndexLoad(b, l, getCrdBuf(), memCrd);
}
protected:
- const Value crdBuffer;
+ template <typename T = void, typename = std::enable_if_t<hasPosBuffer, T>>
+ Value getPosBuf() const {
+ return buffers[0];
+ }
+
+ Value getCrdBuf() const {
+ if constexpr (hasPosBuffer)
+ return buffers[1];
+ else
+ return buffers[0];
+ }
+
+ const BufferT buffers;
};
class DenseLevel : public SparseTensorLevel {
@@ -72,6 +92,8 @@ class DenseLevel : public SparseTensorLevel {
llvm_unreachable("locate random-accessible level instead");
}
+ ValueRange getLvlBuffers() const override { return {}; }
+
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
Value max) const override {
Value posLo = MULI(p, lvlSize);
@@ -88,6 +110,8 @@ class BatchLevel : public SparseTensorLevel {
llvm_unreachable("locate random-accessible level instead");
}
+ ValueRange getLvlBuffers() const override { return {}; }
+
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
Value max) const override {
assert(max == nullptr && "Dense level can not be non-unique.");
@@ -96,11 +120,11 @@ class BatchLevel : public SparseTensorLevel {
}
};
-class CompressedLevel : public SparseLevel {
+class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
public:
CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
Value posBuffer, Value crdBuffer)
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
+ : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value p, Value max) const override {
@@ -109,21 +133,18 @@ class CompressedLevel : public SparseLevel {
SmallVector<Value> memCrd(batchPrefix);
memCrd.push_back(p);
- Value pLo = genIndexLoad(b, l, posBuffer, memCrd);
+ Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
memCrd.back() = ADDI(p, C_IDX(1));
- Value pHi = genIndexLoad(b, l, posBuffer, memCrd);
+ Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
return {pLo, pHi};
}
-
-private:
- const Value posBuffer;
};
-class LooseCompressedLevel : public SparseLevel {
+class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
public:
LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
Value posBuffer, Value crdBuffer)
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
+ : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value p, Value max) const override {
@@ -133,21 +154,18 @@ class LooseCompressedLevel : public SparseLevel {
p = MULI(p, C_IDX(2));
memCrd.push_back(p);
- Value pLo = genIndexLoad(b, l, posBuffer, memCrd);
+ Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
memCrd.back() = ADDI(p, C_IDX(1));
- Value pHi = genIndexLoad(b, l, posBuffer, memCrd);
+ Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
return {pLo, pHi};
}
-
-private:
- const Value posBuffer;
};
-class SingletonLevel : public SparseLevel {
+class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
public:
SingletonLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
Value crdBuffer)
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
+ : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value p, Value segHi) const override {
@@ -159,11 +177,11 @@ class SingletonLevel : public SparseLevel {
}
};
-class NOutOfMLevel : public SparseLevel {
+class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
public:
NOutOfMLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
Value crdBuffer)
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
+ : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value p, Value max) const override {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
similarity index 99%
rename from mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
rename to mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 9f92eecdf75cb6..19c0dc942ca62f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -55,6 +55,7 @@ class SparseTensorLevel {
Level getLevel() const { return lvl; }
LevelType getLT() const { return lt; }
Value getSize() const { return lvlSize; }
+ virtual ValueRange getLvlBuffers() const = 0;
//
// Level properties
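For readers less familiar with the pattern the patch adopts, the following is a minimal standalone sketch of the compile-time buffer selection used by the reworked SparseLevel template: std::conditional_t sizes the backing array, enable_if gates the position accessor, and if constexpr picks the coordinate slot. It is an illustration only and assumes nothing beyond the diff above: LevelSketch is a made-up name, and std::string stands in for mlir::Value so the snippet compiles without any MLIR dependency.

```cpp
#include <array>
#include <iostream>
#include <string>
#include <type_traits>
#include <utility>

// Stand-in for the level abstraction; std::string replaces mlir::Value so
// the example needs no MLIR headers.
template <bool hasPosBuffer>
class LevelSketch {
  // One backing array, sized at compile time: {pos, crd} when the level
  // format carries a position buffer, {crd} otherwise.
  using BufferT = std::conditional_t<hasPosBuffer, std::array<std::string, 2>,
                                     std::array<std::string, 1>>;

public:
  explicit LevelSketch(BufferT buffers) : buffers(std::move(buffers)) {}

  // Analogue of the new getLvlBuffers() hook: every level exposes its
  // buffers uniformly, whatever their number.
  const BufferT &getLvlBuffers() const { return buffers; }

  // Only usable when hasPosBuffer is true; calling it on a position-free
  // level is a compile-time error rather than a silent out-of-range read.
  template <typename T = void, typename = std::enable_if_t<hasPosBuffer, T>>
  const std::string &getPosBuf() const {
    return buffers[0];
  }

  // The coordinate buffer sits after the optional position buffer, so its
  // index is chosen at compile time.
  const std::string &getCrdBuf() const {
    if constexpr (hasPosBuffer)
      return buffers[1];
    else
      return buffers[0];
  }

private:
  const BufferT buffers;
};

int main() {
  LevelSketch</*hasPosBuffer=*/true> compressed({"positions", "coordinates"});
  LevelSketch</*hasPosBuffer=*/false> singleton({"coordinates"});

  std::cout << compressed.getPosBuf() << " + " << compressed.getCrdBuf() << "\n";
  std::cout << singleton.getCrdBuf() << "\n";
  // singleton.getPosBuf();  // would not compile: no position buffer exists.
}
```

This layout is what lets CompressedLevel and LooseCompressedLevel drop their private posBuffer members in the diff, while SingletonLevel and NOutOfMLevel keep a one-element array, and the new pure-virtual getLvlBuffers() can return all backing buffers of a level through a single unified accessor.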
aartbik reviewed Apr 9, 2024
mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp (comment outdated, resolved)
yinying-lisa-li approved these changes Apr 9, 2024