Skip to content

[mlir][sparse] cleanup of enums header #71090

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 2, 2023
Merged

[mlir][sparse] cleanup of enums header #71090

merged 5 commits into from
Nov 2, 2023

Conversation

aartbik
Copy link
Contributor

@aartbik aartbik commented Nov 2, 2023

Some DLT related methods leaked into sparse_tensor.h, and this moves it back to the right header. Also, the asserts were incomplete and some DLT methods duplicated.

Some DLT related methods leaked into sparse_tensor.h, and this
moves it back to the right header. Also, the asserts were incomplete
and some DLT methods duplicated.
@llvmbot llvmbot added mlir:sparse Sparse compiler in MLIR mlir labels Nov 2, 2023
@llvmbot
Copy link
Member

llvmbot commented Nov 2, 2023

@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Aart Bik (aartbik)

Changes

Some DLT related methods leaked into sparse_tensor.h, and this moves it back to the right header. Also, the asserts were incomplete and some DLT methods duplicated.


Full diff: https://github.com/llvm/llvm-project/pull/71090.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h (+79-50)
  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h (-10)
  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td (+1-1)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+1)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 1e9aa2bdf45dbdb..a867b99c3bfa5ba 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -153,24 +153,18 @@ enum class Action : uint32_t {
 };
 
 /// This enum defines all the sparse representations supportable by
-/// the SparseTensor dialect.  We use a lightweight encoding to encode
-/// both the "format" per se (dense, compressed, singleton) as well as
-/// the "properties" (ordered, unique).  The encoding is chosen for
-/// performance of the runtime library, and thus may change in future
-/// versions; consequently, client code should use the predicate functions
-/// defined below, rather than relying on knowledge about the particular
-/// binary encoding.
+/// the SparseTensor dialect. We use a lightweight encoding to encode
+/// both the "format" per se (dense, compressed, singleton, loose_compressed,
+/// two-out-of-four) as well as the "properties" (ordered, unique). The
+/// encoding is chosen for performance of the runtime library, and thus may
+/// change in future versions; consequently, client code should use the
+/// predicate functions defined below, rather than relying on knowledge
+/// about the particular binary encoding.
 ///
 /// The `Undef` "format" is a special value used internally for cases
 /// where we need to store an undefined or indeterminate `DimLevelType`.
 /// It should not be used externally, since it does not indicate an
 /// actual/representable format.
-///
-// TODO: We should generalize TwoOutOfFour to N out of M and use property to
-// encode the value of N and M.
-// TODO: Update DimLevelType to use lower 8 bits for storage formats and the
-// higher 4 bits to store level properties. Consider LooseCompressed and
-// TwoOutOfFour as properties instead of formats.
 enum class DimLevelType : uint8_t {
   Undef = 0,                // 0b00000_00
   Dense = 4,                // 0b00001_00
@@ -257,20 +251,12 @@ constexpr bool isUndefDLT(DimLevelType dlt) {
   return dlt == DimLevelType::Undef;
 }
 
-/// Check if the `DimLevelType` is dense.
+/// Check if the `DimLevelType` is dense (regardless of properties).
 constexpr bool isDenseDLT(DimLevelType dlt) {
-  return dlt == DimLevelType::Dense;
-}
-
-/// Check if the `DimLevelType` is 2:4
-constexpr bool isTwoOutOfFourDLT(DimLevelType dlt) {
-  return dlt == DimLevelType::TwoOutOfFour;
+  return (static_cast<uint8_t>(dlt) & ~3) ==
+         static_cast<uint8_t>(DimLevelType::Dense);
 }
 
-// We use the idiom `(dlt & ~3) == format` in order to only return true
-// for valid DLTs.  Whereas the `dlt & format` idiom is a bit faster but
-// can return false-positives on invalid DLTs.
-
 /// Check if the `DimLevelType` is compressed (regardless of properties).
 constexpr bool isCompressedDLT(DimLevelType dlt) {
   return (static_cast<uint8_t>(dlt) & ~3) ==
@@ -295,6 +281,17 @@ constexpr bool is2OutOf4DLT(DimLevelType dlt) {
          static_cast<uint8_t>(DimLevelType::TwoOutOfFour);
 }
 
+/// Check if the `DimLevelType` needs positions array.
+constexpr bool isDLTWithPos(DimLevelType dlt) {
+  return isLooseCompressedDLT(dlt) || isCompressedDLT(dlt);
+}
+
+/// Check if the `DimLevelType` needs coordinates array.
+constexpr bool isDLTWithCrd(DimLevelType dlt) {
+  return isSingletonDLT(dlt) || isLooseCompressedDLT(dlt) ||
+         isCompressedDLT(dlt);
+}
+
 /// Check if the `DimLevelType` is ordered (regardless of storage format).
 constexpr bool isOrderedDLT(DimLevelType dlt) {
   return !(static_cast<uint8_t>(dlt) & 2);
@@ -336,35 +333,52 @@ static_assert(
      *getLevelFormat(DimLevelType::Singleton) == LevelFormat::Singleton &&
      *getLevelFormat(DimLevelType::SingletonNu) == LevelFormat::Singleton &&
      *getLevelFormat(DimLevelType::SingletonNo) == LevelFormat::Singleton &&
-     *getLevelFormat(DimLevelType::SingletonNuNo) == LevelFormat::Singleton),
+     *getLevelFormat(DimLevelType::SingletonNuNo) == LevelFormat::Singleton &&
+     *getLevelFormat(DimLevelType::LooseCompressed) ==
+         LevelFormat::LooseCompressed &&
+     *getLevelFormat(DimLevelType::LooseCompressedNu) ==
+         LevelFormat::LooseCompressed &&
+     *getLevelFormat(DimLevelType::LooseCompressedNo) ==
+         LevelFormat::LooseCompressed &&
+     *getLevelFormat(DimLevelType::LooseCompressedNuNo) ==
+         LevelFormat::LooseCompressed &&
+     *getLevelFormat(DimLevelType::TwoOutOfFour) == LevelFormat::TwoOutOfFour),
     "getLevelFormat conversion is broken");
 
 static_assert(
     (buildLevelType(LevelFormat::Dense, false, true) == std::nullopt &&
      buildLevelType(LevelFormat::Dense, true, false) == std::nullopt &&
      buildLevelType(LevelFormat::Dense, false, false) == std::nullopt &&
-     *buildLevelType(LevelFormat::Dense, true, true) == DimLevelType::Dense &&
-     buildLevelType(LevelFormat::TwoOutOfFour, false, true) == std::nullopt &&
-     buildLevelType(LevelFormat::TwoOutOfFour, true, false) == std::nullopt &&
-     buildLevelType(LevelFormat::TwoOutOfFour, false, false) == std::nullopt &&
-     *buildLevelType(LevelFormat::TwoOutOfFour, true, true) ==
-         DimLevelType::TwoOutOfFour &&
-     *buildLevelType(LevelFormat::Compressed, true, true) ==
+     buildLevelType(LevelFormat::Dense, true, true) == DimLevelType::Dense &&
+     buildLevelType(LevelFormat::Compressed, true, true) ==
          DimLevelType::Compressed &&
-     *buildLevelType(LevelFormat::Compressed, true, false) ==
+     buildLevelType(LevelFormat::Compressed, true, false) ==
          DimLevelType::CompressedNu &&
-     *buildLevelType(LevelFormat::Compressed, false, true) ==
+     buildLevelType(LevelFormat::Compressed, false, true) ==
          DimLevelType::CompressedNo &&
-     *buildLevelType(LevelFormat::Compressed, false, false) ==
+     buildLevelType(LevelFormat::Compressed, false, false) ==
          DimLevelType::CompressedNuNo &&
-     *buildLevelType(LevelFormat::Singleton, true, true) ==
+     buildLevelType(LevelFormat::Singleton, true, true) ==
          DimLevelType::Singleton &&
-     *buildLevelType(LevelFormat::Singleton, true, false) ==
+     buildLevelType(LevelFormat::Singleton, true, false) ==
          DimLevelType::SingletonNu &&
-     *buildLevelType(LevelFormat::Singleton, false, true) ==
+     buildLevelType(LevelFormat::Singleton, false, true) ==
          DimLevelType::SingletonNo &&
-     *buildLevelType(LevelFormat::Singleton, false, false) ==
-         DimLevelType::SingletonNuNo),
+     buildLevelType(LevelFormat::Singleton, false, false) ==
+         DimLevelType::SingletonNuNo &&
+     buildLevelType(LevelFormat::LooseCompressed, true, true) ==
+         DimLevelType::LooseCompressed &&
+     buildLevelType(LevelFormat::LooseCompressed, true, false) ==
+         DimLevelType::LooseCompressedNu &&
+     buildLevelType(LevelFormat::LooseCompressed, false, true) ==
+         DimLevelType::LooseCompressedNo &&
+     buildLevelType(LevelFormat::LooseCompressed, false, false) ==
+         DimLevelType::LooseCompressedNuNo &&
+     buildLevelType(LevelFormat::TwoOutOfFour, false, true) == std::nullopt &&
+     buildLevelType(LevelFormat::TwoOutOfFour, true, false) == std::nullopt &&
+     buildLevelType(LevelFormat::TwoOutOfFour, false, false) == std::nullopt &&
+     buildLevelType(LevelFormat::TwoOutOfFour, true, true) ==
+         DimLevelType::TwoOutOfFour),
     "buildLevelType conversion is broken");
 
 // Ensure the above predicates work as intended.
@@ -393,18 +407,28 @@ static_assert((!isCompressedDLT(DimLevelType::Dense) &&
                !isCompressedDLT(DimLevelType::Singleton) &&
                !isCompressedDLT(DimLevelType::SingletonNu) &&
                !isCompressedDLT(DimLevelType::SingletonNo) &&
-               !isCompressedDLT(DimLevelType::SingletonNuNo)),
+               !isCompressedDLT(DimLevelType::SingletonNuNo) &&
+               !isCompressedDLT(DimLevelType::LooseCompressed) &&
+               !isCompressedDLT(DimLevelType::LooseCompressedNu) &&
+               !isCompressedDLT(DimLevelType::LooseCompressedNo) &&
+               !isCompressedDLT(DimLevelType::LooseCompressedNuNo) &&
+               !isCompressedDLT(DimLevelType::TwoOutOfFour)),
               "isCompressedDLT definition is broken");
 
 static_assert((!isLooseCompressedDLT(DimLevelType::Dense) &&
+               !isLooseCompressedDLT(DimLevelType::Compressed) &&
+               !isLooseCompressedDLT(DimLevelType::CompressedNu) &&
+               !isLooseCompressedDLT(DimLevelType::CompressedNo) &&
+               !isLooseCompressedDLT(DimLevelType::CompressedNuNo) &&
+               !isLooseCompressedDLT(DimLevelType::Singleton) &&
+               !isLooseCompressedDLT(DimLevelType::SingletonNu) &&
+               !isLooseCompressedDLT(DimLevelType::SingletonNo) &&
+               !isLooseCompressedDLT(DimLevelType::SingletonNuNo) &&
                isLooseCompressedDLT(DimLevelType::LooseCompressed) &&
                isLooseCompressedDLT(DimLevelType::LooseCompressedNu) &&
                isLooseCompressedDLT(DimLevelType::LooseCompressedNo) &&
                isLooseCompressedDLT(DimLevelType::LooseCompressedNuNo) &&
-               !isLooseCompressedDLT(DimLevelType::Singleton) &&
-               !isLooseCompressedDLT(DimLevelType::SingletonNu) &&
-               !isLooseCompressedDLT(DimLevelType::SingletonNo) &&
-               !isLooseCompressedDLT(DimLevelType::SingletonNuNo)),
+               !isLooseCompressedDLT(DimLevelType::TwoOutOfFour)),
               "isLooseCompressedDLT definition is broken");
 
 static_assert((!isSingletonDLT(DimLevelType::Dense) &&
@@ -415,11 +439,15 @@ static_assert((!isSingletonDLT(DimLevelType::Dense) &&
                isSingletonDLT(DimLevelType::Singleton) &&
                isSingletonDLT(DimLevelType::SingletonNu) &&
                isSingletonDLT(DimLevelType::SingletonNo) &&
-               isSingletonDLT(DimLevelType::SingletonNuNo)),
+               isSingletonDLT(DimLevelType::SingletonNuNo) &&
+               !isSingletonDLT(DimLevelType::LooseCompressed) &&
+               !isSingletonDLT(DimLevelType::LooseCompressedNu) &&
+               !isSingletonDLT(DimLevelType::LooseCompressedNo) &&
+               !isSingletonDLT(DimLevelType::LooseCompressedNuNo) &&
+               !isSingletonDLT(DimLevelType::TwoOutOfFour)),
               "isSingletonDLT definition is broken");
 
 static_assert((isOrderedDLT(DimLevelType::Dense) &&
-               isOrderedDLT(DimLevelType::TwoOutOfFour) &&
                isOrderedDLT(DimLevelType::Compressed) &&
                isOrderedDLT(DimLevelType::CompressedNu) &&
                !isOrderedDLT(DimLevelType::CompressedNo) &&
@@ -431,11 +459,11 @@ static_assert((isOrderedDLT(DimLevelType::Dense) &&
                isOrderedDLT(DimLevelType::LooseCompressed) &&
                isOrderedDLT(DimLevelType::LooseCompressedNu) &&
                !isOrderedDLT(DimLevelType::LooseCompressedNo) &&
-               !isOrderedDLT(DimLevelType::LooseCompressedNuNo)),
+               !isOrderedDLT(DimLevelType::LooseCompressedNuNo) &&
+               isOrderedDLT(DimLevelType::TwoOutOfFour)),
               "isOrderedDLT definition is broken");
 
 static_assert((isUniqueDLT(DimLevelType::Dense) &&
-               isUniqueDLT(DimLevelType::TwoOutOfFour) &&
                isUniqueDLT(DimLevelType::Compressed) &&
                !isUniqueDLT(DimLevelType::CompressedNu) &&
                isUniqueDLT(DimLevelType::CompressedNo) &&
@@ -447,7 +475,8 @@ static_assert((isUniqueDLT(DimLevelType::Dense) &&
                isUniqueDLT(DimLevelType::LooseCompressed) &&
                !isUniqueDLT(DimLevelType::LooseCompressedNu) &&
                isUniqueDLT(DimLevelType::LooseCompressedNo) &&
-               !isUniqueDLT(DimLevelType::LooseCompressedNuNo)),
+               !isUniqueDLT(DimLevelType::LooseCompressedNuNo) &&
+               isUniqueDLT(DimLevelType::TwoOutOfFour)),
               "isUniqueDLT definition is broken");
 
 /// Bit manipulations for affine encoding.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 94e7d12b9ee915f..241d90a87165928 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -89,16 +89,6 @@ inline MemRefType getMemRefType(T &&t) {
 /// Returns null-attribute for any type without an encoding.
 SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
 
-/// Convenience method to query whether a given DLT needs both position and
-/// coordinates array or only coordinates array.
-constexpr inline bool isDLTWithPos(DimLevelType dlt) {
-  return isLooseCompressedDLT(dlt) || isCompressedDLT(dlt);
-}
-constexpr inline bool isDLTWithCrd(DimLevelType dlt) {
-  return isSingletonDLT(dlt) || isLooseCompressedDLT(dlt) ||
-         isCompressedDLT(dlt);
-}
-
 /// Returns true iff the given sparse tensor encoding attribute has a trailing
 /// COO region starting at the given level.
 bool isCOOType(SparseTensorEncodingAttr enc, Level startLvl, bool isUnique);
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 3c73b19319e588c..e7c6435e997ca00 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -371,7 +371,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     ::mlir::sparse_tensor::DimLevelType getLvlType(::mlir::sparse_tensor::Level l) const;
 
     bool isDenseLvl(::mlir::sparse_tensor::Level l) const { return isDenseDLT(getLvlType(l)); }
-    bool isTwoOutOfFourLvl(::mlir::sparse_tensor::Level l) const { return isTwoOutOfFourDLT(getLvlType(l)); }
+    bool isTwoOutOfFourLvl(::mlir::sparse_tensor::Level l) const { return is2OutOf4DLT(getLvlType(l)); }
     bool isCompressedLvl(::mlir::sparse_tensor::Level l) const { return isCompressedDLT(getLvlType(l)); }
     bool isLooseCompressedLvl(::mlir::sparse_tensor::Level l) const { return isLooseCompressedDLT(getLvlType(l)); }
     bool isSingletonLvl(::mlir::sparse_tensor::Level l) const { return isSingletonDLT(getLvlType(l)); }
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 6080317d07a64e0..97ef753aacf35b1 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -10,6 +10,7 @@
 
 #include "Detail/DimLvlMapParser.h"
 
+#include "mlir/Dialect/SparseTensor/IR/Enums.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"

@aartbik aartbik merged commit ee3ee13 into llvm:main Nov 2, 2023
@aartbik aartbik deleted the bik branch November 2, 2023 20:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants