-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][sparse] cleanup of enums header #71090
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some DLT related methods leaked into sparse_tensor.h, and this moves it back to the right header. Also, the asserts were incomplete and some DLT methods duplicated.
@llvm/pr-subscribers-mlir-sparse @llvm/pr-subscribers-mlir Author: Aart Bik (aartbik) ChangesSome DLT related methods leaked into sparse_tensor.h, and this moves it back to the right header. Also, the asserts were incomplete and some DLT methods duplicated. Full diff: https://github.com/llvm/llvm-project/pull/71090.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 1e9aa2bdf45dbdb..a867b99c3bfa5ba 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -153,24 +153,18 @@ enum class Action : uint32_t {
};
/// This enum defines all the sparse representations supportable by
-/// the SparseTensor dialect. We use a lightweight encoding to encode
-/// both the "format" per se (dense, compressed, singleton) as well as
-/// the "properties" (ordered, unique). The encoding is chosen for
-/// performance of the runtime library, and thus may change in future
-/// versions; consequently, client code should use the predicate functions
-/// defined below, rather than relying on knowledge about the particular
-/// binary encoding.
+/// the SparseTensor dialect. We use a lightweight encoding to encode
+/// both the "format" per se (dense, compressed, singleton, loose_compressed,
+/// two-out-of-four) as well as the "properties" (ordered, unique). The
+/// encoding is chosen for performance of the runtime library, and thus may
+/// change in future versions; consequently, client code should use the
+/// predicate functions defined below, rather than relying on knowledge
+/// about the particular binary encoding.
///
/// The `Undef` "format" is a special value used internally for cases
/// where we need to store an undefined or indeterminate `DimLevelType`.
/// It should not be used externally, since it does not indicate an
/// actual/representable format.
-///
-// TODO: We should generalize TwoOutOfFour to N out of M and use property to
-// encode the value of N and M.
-// TODO: Update DimLevelType to use lower 8 bits for storage formats and the
-// higher 4 bits to store level properties. Consider LooseCompressed and
-// TwoOutOfFour as properties instead of formats.
enum class DimLevelType : uint8_t {
Undef = 0, // 0b00000_00
Dense = 4, // 0b00001_00
@@ -257,20 +251,12 @@ constexpr bool isUndefDLT(DimLevelType dlt) {
return dlt == DimLevelType::Undef;
}
-/// Check if the `DimLevelType` is dense.
+/// Check if the `DimLevelType` is dense (regardless of properties).
constexpr bool isDenseDLT(DimLevelType dlt) {
- return dlt == DimLevelType::Dense;
-}
-
-/// Check if the `DimLevelType` is 2:4
-constexpr bool isTwoOutOfFourDLT(DimLevelType dlt) {
- return dlt == DimLevelType::TwoOutOfFour;
+ return (static_cast<uint8_t>(dlt) & ~3) ==
+ static_cast<uint8_t>(DimLevelType::Dense);
}
-// We use the idiom `(dlt & ~3) == format` in order to only return true
-// for valid DLTs. Whereas the `dlt & format` idiom is a bit faster but
-// can return false-positives on invalid DLTs.
-
/// Check if the `DimLevelType` is compressed (regardless of properties).
constexpr bool isCompressedDLT(DimLevelType dlt) {
return (static_cast<uint8_t>(dlt) & ~3) ==
@@ -295,6 +281,17 @@ constexpr bool is2OutOf4DLT(DimLevelType dlt) {
static_cast<uint8_t>(DimLevelType::TwoOutOfFour);
}
+/// Check if the `DimLevelType` needs positions array.
+constexpr bool isDLTWithPos(DimLevelType dlt) {
+ return isLooseCompressedDLT(dlt) || isCompressedDLT(dlt);
+}
+
+/// Check if the `DimLevelType` needs coordinates array.
+constexpr bool isDLTWithCrd(DimLevelType dlt) {
+ return isSingletonDLT(dlt) || isLooseCompressedDLT(dlt) ||
+ isCompressedDLT(dlt);
+}
+
/// Check if the `DimLevelType` is ordered (regardless of storage format).
constexpr bool isOrderedDLT(DimLevelType dlt) {
return !(static_cast<uint8_t>(dlt) & 2);
@@ -336,35 +333,52 @@ static_assert(
*getLevelFormat(DimLevelType::Singleton) == LevelFormat::Singleton &&
*getLevelFormat(DimLevelType::SingletonNu) == LevelFormat::Singleton &&
*getLevelFormat(DimLevelType::SingletonNo) == LevelFormat::Singleton &&
- *getLevelFormat(DimLevelType::SingletonNuNo) == LevelFormat::Singleton),
+ *getLevelFormat(DimLevelType::SingletonNuNo) == LevelFormat::Singleton &&
+ *getLevelFormat(DimLevelType::LooseCompressed) ==
+ LevelFormat::LooseCompressed &&
+ *getLevelFormat(DimLevelType::LooseCompressedNu) ==
+ LevelFormat::LooseCompressed &&
+ *getLevelFormat(DimLevelType::LooseCompressedNo) ==
+ LevelFormat::LooseCompressed &&
+ *getLevelFormat(DimLevelType::LooseCompressedNuNo) ==
+ LevelFormat::LooseCompressed &&
+ *getLevelFormat(DimLevelType::TwoOutOfFour) == LevelFormat::TwoOutOfFour),
"getLevelFormat conversion is broken");
static_assert(
(buildLevelType(LevelFormat::Dense, false, true) == std::nullopt &&
buildLevelType(LevelFormat::Dense, true, false) == std::nullopt &&
buildLevelType(LevelFormat::Dense, false, false) == std::nullopt &&
- *buildLevelType(LevelFormat::Dense, true, true) == DimLevelType::Dense &&
- buildLevelType(LevelFormat::TwoOutOfFour, false, true) == std::nullopt &&
- buildLevelType(LevelFormat::TwoOutOfFour, true, false) == std::nullopt &&
- buildLevelType(LevelFormat::TwoOutOfFour, false, false) == std::nullopt &&
- *buildLevelType(LevelFormat::TwoOutOfFour, true, true) ==
- DimLevelType::TwoOutOfFour &&
- *buildLevelType(LevelFormat::Compressed, true, true) ==
+ buildLevelType(LevelFormat::Dense, true, true) == DimLevelType::Dense &&
+ buildLevelType(LevelFormat::Compressed, true, true) ==
DimLevelType::Compressed &&
- *buildLevelType(LevelFormat::Compressed, true, false) ==
+ buildLevelType(LevelFormat::Compressed, true, false) ==
DimLevelType::CompressedNu &&
- *buildLevelType(LevelFormat::Compressed, false, true) ==
+ buildLevelType(LevelFormat::Compressed, false, true) ==
DimLevelType::CompressedNo &&
- *buildLevelType(LevelFormat::Compressed, false, false) ==
+ buildLevelType(LevelFormat::Compressed, false, false) ==
DimLevelType::CompressedNuNo &&
- *buildLevelType(LevelFormat::Singleton, true, true) ==
+ buildLevelType(LevelFormat::Singleton, true, true) ==
DimLevelType::Singleton &&
- *buildLevelType(LevelFormat::Singleton, true, false) ==
+ buildLevelType(LevelFormat::Singleton, true, false) ==
DimLevelType::SingletonNu &&
- *buildLevelType(LevelFormat::Singleton, false, true) ==
+ buildLevelType(LevelFormat::Singleton, false, true) ==
DimLevelType::SingletonNo &&
- *buildLevelType(LevelFormat::Singleton, false, false) ==
- DimLevelType::SingletonNuNo),
+ buildLevelType(LevelFormat::Singleton, false, false) ==
+ DimLevelType::SingletonNuNo &&
+ buildLevelType(LevelFormat::LooseCompressed, true, true) ==
+ DimLevelType::LooseCompressed &&
+ buildLevelType(LevelFormat::LooseCompressed, true, false) ==
+ DimLevelType::LooseCompressedNu &&
+ buildLevelType(LevelFormat::LooseCompressed, false, true) ==
+ DimLevelType::LooseCompressedNo &&
+ buildLevelType(LevelFormat::LooseCompressed, false, false) ==
+ DimLevelType::LooseCompressedNuNo &&
+ buildLevelType(LevelFormat::TwoOutOfFour, false, true) == std::nullopt &&
+ buildLevelType(LevelFormat::TwoOutOfFour, true, false) == std::nullopt &&
+ buildLevelType(LevelFormat::TwoOutOfFour, false, false) == std::nullopt &&
+ buildLevelType(LevelFormat::TwoOutOfFour, true, true) ==
+ DimLevelType::TwoOutOfFour),
"buildLevelType conversion is broken");
// Ensure the above predicates work as intended.
@@ -393,18 +407,28 @@ static_assert((!isCompressedDLT(DimLevelType::Dense) &&
!isCompressedDLT(DimLevelType::Singleton) &&
!isCompressedDLT(DimLevelType::SingletonNu) &&
!isCompressedDLT(DimLevelType::SingletonNo) &&
- !isCompressedDLT(DimLevelType::SingletonNuNo)),
+ !isCompressedDLT(DimLevelType::SingletonNuNo) &&
+ !isCompressedDLT(DimLevelType::LooseCompressed) &&
+ !isCompressedDLT(DimLevelType::LooseCompressedNu) &&
+ !isCompressedDLT(DimLevelType::LooseCompressedNo) &&
+ !isCompressedDLT(DimLevelType::LooseCompressedNuNo) &&
+ !isCompressedDLT(DimLevelType::TwoOutOfFour)),
"isCompressedDLT definition is broken");
static_assert((!isLooseCompressedDLT(DimLevelType::Dense) &&
+ !isLooseCompressedDLT(DimLevelType::Compressed) &&
+ !isLooseCompressedDLT(DimLevelType::CompressedNu) &&
+ !isLooseCompressedDLT(DimLevelType::CompressedNo) &&
+ !isLooseCompressedDLT(DimLevelType::CompressedNuNo) &&
+ !isLooseCompressedDLT(DimLevelType::Singleton) &&
+ !isLooseCompressedDLT(DimLevelType::SingletonNu) &&
+ !isLooseCompressedDLT(DimLevelType::SingletonNo) &&
+ !isLooseCompressedDLT(DimLevelType::SingletonNuNo) &&
isLooseCompressedDLT(DimLevelType::LooseCompressed) &&
isLooseCompressedDLT(DimLevelType::LooseCompressedNu) &&
isLooseCompressedDLT(DimLevelType::LooseCompressedNo) &&
isLooseCompressedDLT(DimLevelType::LooseCompressedNuNo) &&
- !isLooseCompressedDLT(DimLevelType::Singleton) &&
- !isLooseCompressedDLT(DimLevelType::SingletonNu) &&
- !isLooseCompressedDLT(DimLevelType::SingletonNo) &&
- !isLooseCompressedDLT(DimLevelType::SingletonNuNo)),
+ !isLooseCompressedDLT(DimLevelType::TwoOutOfFour)),
"isLooseCompressedDLT definition is broken");
static_assert((!isSingletonDLT(DimLevelType::Dense) &&
@@ -415,11 +439,15 @@ static_assert((!isSingletonDLT(DimLevelType::Dense) &&
isSingletonDLT(DimLevelType::Singleton) &&
isSingletonDLT(DimLevelType::SingletonNu) &&
isSingletonDLT(DimLevelType::SingletonNo) &&
- isSingletonDLT(DimLevelType::SingletonNuNo)),
+ isSingletonDLT(DimLevelType::SingletonNuNo) &&
+ !isSingletonDLT(DimLevelType::LooseCompressed) &&
+ !isSingletonDLT(DimLevelType::LooseCompressedNu) &&
+ !isSingletonDLT(DimLevelType::LooseCompressedNo) &&
+ !isSingletonDLT(DimLevelType::LooseCompressedNuNo) &&
+ !isSingletonDLT(DimLevelType::TwoOutOfFour)),
"isSingletonDLT definition is broken");
static_assert((isOrderedDLT(DimLevelType::Dense) &&
- isOrderedDLT(DimLevelType::TwoOutOfFour) &&
isOrderedDLT(DimLevelType::Compressed) &&
isOrderedDLT(DimLevelType::CompressedNu) &&
!isOrderedDLT(DimLevelType::CompressedNo) &&
@@ -431,11 +459,11 @@ static_assert((isOrderedDLT(DimLevelType::Dense) &&
isOrderedDLT(DimLevelType::LooseCompressed) &&
isOrderedDLT(DimLevelType::LooseCompressedNu) &&
!isOrderedDLT(DimLevelType::LooseCompressedNo) &&
- !isOrderedDLT(DimLevelType::LooseCompressedNuNo)),
+ !isOrderedDLT(DimLevelType::LooseCompressedNuNo) &&
+ isOrderedDLT(DimLevelType::TwoOutOfFour)),
"isOrderedDLT definition is broken");
static_assert((isUniqueDLT(DimLevelType::Dense) &&
- isUniqueDLT(DimLevelType::TwoOutOfFour) &&
isUniqueDLT(DimLevelType::Compressed) &&
!isUniqueDLT(DimLevelType::CompressedNu) &&
isUniqueDLT(DimLevelType::CompressedNo) &&
@@ -447,7 +475,8 @@ static_assert((isUniqueDLT(DimLevelType::Dense) &&
isUniqueDLT(DimLevelType::LooseCompressed) &&
!isUniqueDLT(DimLevelType::LooseCompressedNu) &&
isUniqueDLT(DimLevelType::LooseCompressedNo) &&
- !isUniqueDLT(DimLevelType::LooseCompressedNuNo)),
+ !isUniqueDLT(DimLevelType::LooseCompressedNuNo) &&
+ isUniqueDLT(DimLevelType::TwoOutOfFour)),
"isUniqueDLT definition is broken");
/// Bit manipulations for affine encoding.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 94e7d12b9ee915f..241d90a87165928 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -89,16 +89,6 @@ inline MemRefType getMemRefType(T &&t) {
/// Returns null-attribute for any type without an encoding.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
-/// Convenience method to query whether a given DLT needs both position and
-/// coordinates array or only coordinates array.
-constexpr inline bool isDLTWithPos(DimLevelType dlt) {
- return isLooseCompressedDLT(dlt) || isCompressedDLT(dlt);
-}
-constexpr inline bool isDLTWithCrd(DimLevelType dlt) {
- return isSingletonDLT(dlt) || isLooseCompressedDLT(dlt) ||
- isCompressedDLT(dlt);
-}
-
/// Returns true iff the given sparse tensor encoding attribute has a trailing
/// COO region starting at the given level.
bool isCOOType(SparseTensorEncodingAttr enc, Level startLvl, bool isUnique);
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 3c73b19319e588c..e7c6435e997ca00 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -371,7 +371,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
::mlir::sparse_tensor::DimLevelType getLvlType(::mlir::sparse_tensor::Level l) const;
bool isDenseLvl(::mlir::sparse_tensor::Level l) const { return isDenseDLT(getLvlType(l)); }
- bool isTwoOutOfFourLvl(::mlir::sparse_tensor::Level l) const { return isTwoOutOfFourDLT(getLvlType(l)); }
+ bool isTwoOutOfFourLvl(::mlir::sparse_tensor::Level l) const { return is2OutOf4DLT(getLvlType(l)); }
bool isCompressedLvl(::mlir::sparse_tensor::Level l) const { return isCompressedDLT(getLvlType(l)); }
bool isLooseCompressedLvl(::mlir::sparse_tensor::Level l) const { return isLooseCompressedDLT(getLvlType(l)); }
bool isSingletonLvl(::mlir::sparse_tensor::Level l) const { return isSingletonDLT(getLvlType(l)); }
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 6080317d07a64e0..97ef753aacf35b1 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -10,6 +10,7 @@
#include "Detail/DimLvlMapParser.h"
+#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
|
PeimingLiu
reviewed
Nov 2, 2023
yinying-lisa-li
approved these changes
Nov 2, 2023
PeimingLiu
approved these changes
Nov 2, 2023
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Some DLT related methods leaked into sparse_tensor.h, and this moves it back to the right header. Also, the asserts were incomplete and some DLT methods duplicated.