-
Notifications
You must be signed in to change notification settings - Fork 14.3k
Reapply "[mlir][sparse] remove LevelType enum, construct LevelType from LevelFormat and Properties" (#81923) #81934
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…om LevelF…" (llvm#81923) This reverts commit 513448d.
@llvm/pr-subscribers-mlir Author: Peiming Liu (PeimingLiu) ChangesPatch is 36.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81934.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 74cc0dee554a17..c7db5beb2015a6 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -153,45 +153,9 @@ enum class Action : uint32_t {
kSortCOOInPlace = 8,
};
-/// This enum defines all the sparse representations supportable by
-/// the SparseTensor dialect. We use a lightweight encoding to encode
-/// the "format" per se (dense, compressed, singleton, loose_compressed,
-/// n-out-of-m), the "properties" (ordered, unique) as well as n and m when
-/// the format is NOutOfM.
-/// The encoding is chosen for performance of the runtime library, and thus may
-/// change in future versions; consequently, client code should use the
-/// predicate functions defined below, rather than relying on knowledge
-/// about the particular binary encoding.
-///
-/// The `Undef` "format" is a special value used internally for cases
-/// where we need to store an undefined or indeterminate `LevelType`.
-/// It should not be used externally, since it does not indicate an
-/// actual/representable format.
-///
-/// Bit manipulations for LevelType:
-///
-/// | 8-bit n | 8-bit m | 16-bit LevelFormat | 16-bit LevelProperty |
-///
-enum class LevelType : uint64_t {
- Undef = 0x000000000000,
- Dense = 0x000000010000,
- Compressed = 0x000000020000,
- CompressedNu = 0x000000020001,
- CompressedNo = 0x000000020002,
- CompressedNuNo = 0x000000020003,
- Singleton = 0x000000040000,
- SingletonNu = 0x000000040001,
- SingletonNo = 0x000000040002,
- SingletonNuNo = 0x000000040003,
- LooseCompressed = 0x000000080000,
- LooseCompressedNu = 0x000000080001,
- LooseCompressedNo = 0x000000080002,
- LooseCompressedNuNo = 0x000000080003,
- NOutOfM = 0x000000100000,
-};
-
/// This enum defines all supported storage format without the level properties.
enum class LevelFormat : uint64_t {
+ Undef = 0x00000000,
Dense = 0x00010000,
Compressed = 0x00020000,
Singleton = 0x00040000,
@@ -199,327 +163,240 @@ enum class LevelFormat : uint64_t {
NOutOfM = 0x00100000,
};
+template <LevelFormat... targets>
+constexpr bool isAnyOfFmt(LevelFormat fmt) {
+ return (... || (targets == fmt));
+}
+
+/// Returns string representation of the given level format.
+constexpr const char *toFormatString(LevelFormat lvlFmt) {
+ switch (lvlFmt) {
+ case LevelFormat::Undef:
+ return "undef";
+ case LevelFormat::Dense:
+ return "dense";
+ case LevelFormat::Compressed:
+ return "compressed";
+ case LevelFormat::Singleton:
+ return "singleton";
+ case LevelFormat::LooseCompressed:
+ return "loose_compressed";
+ case LevelFormat::NOutOfM:
+ return "structured";
+ }
+ return "";
+}
+
/// This enum defines all the nondefault properties for storage formats.
-enum class LevelPropertyNondefault : uint64_t {
+enum class LevelPropNonDefault : uint64_t {
Nonunique = 0x0001,
Nonordered = 0x0002,
};
-/// Get N of NOutOfM level type.
-constexpr uint64_t getN(LevelType lt) {
- return (static_cast<uint64_t>(lt) >> 32) & 0xff;
+/// Returns string representation of the given level properties.
+constexpr const char *toPropString(LevelPropNonDefault lvlProp) {
+ switch (lvlProp) {
+ case LevelPropNonDefault::Nonunique:
+ return "nonunique";
+ case LevelPropNonDefault::Nonordered:
+ return "nonordered";
+ }
+ return "";
}
-/// Get M of NOutOfM level type.
-constexpr uint64_t getM(LevelType lt) {
- return (static_cast<uint64_t>(lt) >> 40) & 0xff;
-}
+/// This enum defines all the sparse representations supportable by
+/// the SparseTensor dialect. We use a lightweight encoding to encode
+/// the "format" per se (dense, compressed, singleton, loose_compressed,
+/// n-out-of-m), the "properties" (ordered, unique) as well as n and m when
+/// the format is NOutOfM.
+/// The encoding is chosen for performance of the runtime library, and thus may
+/// change in future versions; consequently, client code should use the
+/// predicate functions defined below, rather than relying on knowledge
+/// about the particular binary encoding.
+///
+/// The `Undef` "format" is a special value used internally for cases
+/// where we need to store an undefined or indeterminate `LevelType`.
+/// It should not be used externally, since it does not indicate an
+/// actual/representable format.
-/// Convert N of NOutOfM level type to the stored bits.
-constexpr uint64_t nToBits(uint64_t n) { return n << 32; }
+struct LevelType {
+public:
+ /// Check that the `LevelType` contains a valid (possibly undefined) value.
+ static constexpr bool isValidLvlBits(uint64_t lvlBits) {
+ auto fmt = static_cast<LevelFormat>(lvlBits & 0xffff0000);
+ const uint64_t propertyBits = lvlBits & 0xffff;
+ // If undefined/dense/NOutOfM, then must be unique and ordered.
+ // Otherwise, the format must be one of the known ones.
+ return (isAnyOfFmt<LevelFormat::Undef, LevelFormat::Dense,
+ LevelFormat::NOutOfM>(fmt))
+ ? (propertyBits == 0)
+ : (isAnyOfFmt<LevelFormat::Compressed, LevelFormat::Singleton,
+ LevelFormat::LooseCompressed>(fmt));
+ }
-/// Convert M of NOutOfM level type to the stored bits.
-constexpr uint64_t mToBits(uint64_t m) { return m << 40; }
+ /// Convert a LevelFormat to its corresponding LevelType with the given
+ /// properties. Returns std::nullopt when the properties are not applicable
+ /// for the input level format.
+ static std::optional<LevelType>
+ buildLvlType(LevelFormat lf,
+ const std::vector<LevelPropNonDefault> &properties,
+ uint64_t n = 0, uint64_t m = 0) {
+ assert((n & 0xff) == n && (m & 0xff) == m);
+ uint64_t newN = n << 32;
+ uint64_t newM = m << 40;
+ uint64_t ltBits = static_cast<uint64_t>(lf) | newN | newM;
+ for (auto p : properties)
+ ltBits |= static_cast<uint64_t>(p);
+
+ return isValidLvlBits(ltBits) ? std::optional(LevelType(ltBits))
+ : std::nullopt;
+ }
+ static std::optional<LevelType> buildLvlType(LevelFormat lf, bool ordered,
+ bool unique, uint64_t n = 0,
+ uint64_t m = 0) {
+ std::vector<LevelPropNonDefault> properties;
+ if (!ordered)
+ properties.push_back(LevelPropNonDefault::Nonordered);
+ if (!unique)
+ properties.push_back(LevelPropNonDefault::Nonunique);
+ return buildLvlType(lf, properties, n, m);
+ }
-/// Check if the `LevelType` is NOutOfM (regardless of
-/// properties and block sizes).
-constexpr bool isNOutOfMLT(LevelType lt) {
- return ((static_cast<uint64_t>(lt) & 0x100000) ==
- static_cast<uint64_t>(LevelType::NOutOfM));
-}
+ /// Explicit conversion from uint64_t.
+ constexpr explicit LevelType(uint64_t bits) : lvlBits(bits) {
+ assert(isValidLvlBits(bits));
+ };
-/// Check if the `LevelType` is NOutOfM with the correct block sizes.
-constexpr bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) {
- return isNOutOfMLT(lt) && getN(lt) == n && getM(lt) == m;
-}
+ /// Constructs a LevelType with the given format using all default properties.
+ /*implicit*/ LevelType(LevelFormat f) : lvlBits(static_cast<uint64_t>(f)) {
+ assert(isValidLvlBits(lvlBits) && !isa<LevelFormat::NOutOfM>());
+ };
-/// Returns string representation of the given dimension level type.
-constexpr const char *toMLIRString(LevelType lvlType) {
- auto lt = static_cast<LevelType>(static_cast<uint64_t>(lvlType) & 0xffffffff);
- switch (lt) {
- case LevelType::Undef:
- return "undef";
- case LevelType::Dense:
- return "dense";
- case LevelType::Compressed:
- return "compressed";
- case LevelType::CompressedNu:
- return "compressed(nonunique)";
- case LevelType::CompressedNo:
- return "compressed(nonordered)";
- case LevelType::CompressedNuNo:
- return "compressed(nonunique, nonordered)";
- case LevelType::Singleton:
- return "singleton";
- case LevelType::SingletonNu:
- return "singleton(nonunique)";
- case LevelType::SingletonNo:
- return "singleton(nonordered)";
- case LevelType::SingletonNuNo:
- return "singleton(nonunique, nonordered)";
- case LevelType::LooseCompressed:
- return "loose_compressed";
- case LevelType::LooseCompressedNu:
- return "loose_compressed(nonunique)";
- case LevelType::LooseCompressedNo:
- return "loose_compressed(nonordered)";
- case LevelType::LooseCompressedNuNo:
- return "loose_compressed(nonunique, nonordered)";
- case LevelType::NOutOfM:
- return "structured";
+ /// Converts to uint64_t
+ explicit operator uint64_t() const { return lvlBits; }
+
+ bool operator==(const LevelType lhs) const {
+ return static_cast<uint64_t>(lhs) == lvlBits;
}
- return "";
-}
+ bool operator!=(const LevelType lhs) const { return !(*this == lhs); }
-/// Check that the `LevelType` contains a valid (possibly undefined) value.
-constexpr bool isValidLT(LevelType lt) {
- const uint64_t formatBits = static_cast<uint64_t>(lt) & 0xffff0000;
- const uint64_t propertyBits = static_cast<uint64_t>(lt) & 0xffff;
- // If undefined/dense/NOutOfM, then must be unique and ordered.
- // Otherwise, the format must be one of the known ones.
- return (formatBits <= 0x10000 || formatBits == 0x100000)
- ? (propertyBits == 0)
- : (formatBits == 0x20000 || formatBits == 0x40000 ||
- formatBits == 0x80000);
-}
+ LevelType stripProperties() const { return LevelType(lvlBits & ~0xffff); }
-/// Check if the `LevelType` is the special undefined value.
-constexpr bool isUndefLT(LevelType lt) { return lt == LevelType::Undef; }
+ /// Get N of NOutOfM level type.
+ constexpr uint64_t getN() const {
+ assert(isa<LevelFormat::NOutOfM>());
+ return (lvlBits >> 32) & 0xff;
+ }
-/// Check if the `LevelType` is dense (regardless of properties).
-constexpr bool isDenseLT(LevelType lt) {
- return (static_cast<uint64_t>(lt) & ~0xffff) ==
- static_cast<uint64_t>(LevelType::Dense);
-}
+ /// Get M of NOutOfM level type.
+ constexpr uint64_t getM() const {
+ assert(isa<LevelFormat::NOutOfM>());
+ return (lvlBits >> 40) & 0xff;
+ }
-/// Check if the `LevelType` is compressed (regardless of properties).
-constexpr bool isCompressedLT(LevelType lt) {
- return (static_cast<uint64_t>(lt) & ~0xffff) ==
- static_cast<uint64_t>(LevelType::Compressed);
-}
+ /// Get the `LevelFormat` of the `LevelType`.
+ LevelFormat getLvlFmt() const {
+ return static_cast<LevelFormat>(lvlBits & 0xffff0000);
+ }
-/// Check if the `LevelType` is singleton (regardless of properties).
-constexpr bool isSingletonLT(LevelType lt) {
- return (static_cast<uint64_t>(lt) & ~0xffff) ==
- static_cast<uint64_t>(LevelType::Singleton);
-}
+ /// Check if the `LevelType` is in the `LevelFormat`.
+ template <LevelFormat fmt>
+ constexpr bool isa() const {
+ return getLvlFmt() == fmt;
+ }
-/// Check if the `LevelType` is loose compressed (regardless of properties).
-constexpr bool isLooseCompressedLT(LevelType lt) {
- return (static_cast<uint64_t>(lt) & ~0xffff) ==
- static_cast<uint64_t>(LevelType::LooseCompressed);
-}
+ /// Check if the `LevelType` has the properties
+ template <LevelPropNonDefault p>
+ constexpr bool isa() const {
+ return lvlBits & static_cast<uint64_t>(p);
+ }
-/// Check if the `LevelType` needs positions array.
-constexpr bool isWithPosLT(LevelType lt) {
- return isCompressedLT(lt) || isLooseCompressedLT(lt);
-}
+ /// Check if the `LevelType` needs positions array.
+ constexpr bool isWithPosLT() const {
+ return isa<LevelFormat::Compressed>() ||
+ isa<LevelFormat::LooseCompressed>();
+ }
-/// Check if the `LevelType` needs coordinates array.
-constexpr bool isWithCrdLT(LevelType lt) {
- return isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
- isNOutOfMLT(lt);
-}
+ /// Check if the `LevelType` needs coordinates array.
+ constexpr bool isWithCrdLT() const {
+ // All sparse levels has coordinate array.
+ return !isa<LevelFormat::Dense>();
+ }
-/// Check if the `LevelType` is ordered (regardless of storage format).
-constexpr bool isOrderedLT(LevelType lt) {
- return !(static_cast<uint64_t>(lt) & 2);
- return !(static_cast<uint64_t>(lt) & 2);
-}
+ std::string toMLIRString() const {
+ std::string lvlStr = toFormatString(getLvlFmt());
+ std::string propStr = "";
+ if (isa<LevelPropNonDefault::Nonunique>())
+ propStr += toPropString(LevelPropNonDefault::Nonunique);
+
+ if (isa<LevelPropNonDefault::Nonordered>()) {
+ if (!propStr.empty())
+ propStr += ", ";
+ propStr += toPropString(LevelPropNonDefault::Nonordered);
+ }
+ if (!propStr.empty())
+ lvlStr += ("(" + propStr + ")");
+ return lvlStr;
+ }
-/// Check if the `LevelType` is unique (regardless of storage format).
-constexpr bool isUniqueLT(LevelType lt) {
- return !(static_cast<uint64_t>(lt) & 1);
- return !(static_cast<uint64_t>(lt) & 1);
-}
+private:
+ /// Bit manipulations for LevelType:
+ ///
+ /// | 8-bit n | 8-bit m | 16-bit LevelFormat | 16-bit LevelProperty |
+ ///
+ uint64_t lvlBits;
+};
-/// Convert a LevelType to its corresponding LevelFormat.
-/// Returns std::nullopt when input lt is Undef.
-constexpr std::optional<LevelFormat> getLevelFormat(LevelType lt) {
- if (lt == LevelType::Undef)
- return std::nullopt;
- return static_cast<LevelFormat>(static_cast<uint64_t>(lt) & 0xffff0000);
-}
+// For backward-compatibility. TODO: remove below after fully migration.
+constexpr uint64_t nToBits(uint64_t n) { return n << 32; }
+constexpr uint64_t mToBits(uint64_t m) { return m << 40; }
-/// Convert a LevelFormat to its corresponding LevelType with the given
-/// properties. Returns std::nullopt when the properties are not applicable
-/// for the input level format.
inline std::optional<LevelType>
buildLevelType(LevelFormat lf,
- const std::vector<LevelPropertyNondefault> &properties,
+ const std::vector<LevelPropNonDefault> &properties,
uint64_t n = 0, uint64_t m = 0) {
- uint64_t newN = n << 32;
- uint64_t newM = m << 40;
- uint64_t ltInt = static_cast<uint64_t>(lf) | newN | newM;
- for (auto p : properties) {
- ltInt |= static_cast<uint64_t>(p);
- }
- auto lt = static_cast<LevelType>(ltInt);
- return isValidLT(lt) ? std::optional(lt) : std::nullopt;
+ return LevelType::buildLvlType(lf, properties, n, m);
}
-
inline std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
bool unique, uint64_t n = 0,
uint64_t m = 0) {
- std::vector<LevelPropertyNondefault> properties;
- if (!ordered)
- properties.push_back(LevelPropertyNondefault::Nonordered);
- if (!unique)
- properties.push_back(LevelPropertyNondefault::Nonunique);
- return buildLevelType(lf, properties, n, m);
+ return LevelType::buildLvlType(lf, ordered, unique, n, m);
}
-
-//
-// Ensure the above methods work as intended.
-//
-
-static_assert(
- (getLevelFormat(LevelType::Undef) == std::nullopt &&
- *getLevelFormat(LevelType::Dense) == LevelFormat::Dense &&
- *getLevelFormat(LevelType::Compressed) == LevelFormat::Compressed &&
- *getLevelFormat(LevelType::CompressedNu) == LevelFormat::Compressed &&
- *getLevelFormat(LevelType::CompressedNo) == LevelFormat::Compressed &&
- *getLevelFormat(LevelType::CompressedNuNo) == LevelFormat::Compressed &&
- *getLevelFormat(LevelType::Singleton) == LevelFormat::Singleton &&
- *getLevelFormat(LevelType::SingletonNu) == LevelFormat::Singleton &&
- *getLevelFormat(LevelType::SingletonNo) == LevelFormat::Singleton &&
- *getLevelFormat(LevelType::SingletonNuNo) == LevelFormat::Singleton &&
- *getLevelFormat(LevelType::LooseCompressed) ==
- LevelFormat::LooseCompressed &&
- *getLevelFormat(LevelType::LooseCompressedNu) ==
- LevelFormat::LooseCompressed &&
- *getLevelFormat(LevelType::LooseCompressedNo) ==
- LevelFormat::LooseCompressed &&
- *getLevelFormat(LevelType::LooseCompressedNuNo) ==
- LevelFormat::LooseCompressed &&
- *getLevelFormat(LevelType::NOutOfM) == LevelFormat::NOutOfM),
- "getLevelFormat conversion is broken");
-
-static_assert(
- (isValidLT(LevelType::Undef) && isValidLT(LevelType::Dense) &&
- isValidLT(LevelType::Compressed) && isValidLT(LevelType::CompressedNu) &&
- isValidLT(LevelType::CompressedNo) &&
- isValidLT(LevelType::CompressedNuNo) && isValidLT(LevelType::Singleton) &&
- isValidLT(LevelType::SingletonNu) && isValidLT(LevelType::SingletonNo) &&
- isValidLT(LevelType::SingletonNuNo) &&
- isValidLT(LevelType::LooseCompressed) &&
- isValidLT(LevelType::LooseCompressedNu) &&
- isValidLT(LevelType::LooseCompressedNo) &&
- isValidLT(LevelType::LooseCompressedNuNo) &&
- isValidLT(LevelType::NOutOfM)),
- "isValidLT definition is broken");
-
-static_assert((isDenseLT(LevelType::Dense) &&
- !isDenseLT(LevelType::Compressed) &&
- !isDenseLT(LevelType::CompressedNu) &&
- !isDenseLT(LevelType::CompressedNo) &&
- !isDenseLT(LevelType::CompressedNuNo) &&
- !isDenseLT(LevelType::Singleton) &&
- !isDenseLT(LevelType::SingletonNu) &&
- !isDenseLT(LevelType::SingletonNo) &&
- !isDenseLT(LevelType::SingletonNuNo) &&
- !isDenseLT(LevelType::LooseCompressed) &&
- !isDenseLT(LevelType::LooseCompressedNu) &&
- !isDenseLT(LevelType::LooseCompressedNo) &&
- !isDenseLT(LevelType::LooseCompressedNuNo) &&
- !isDenseLT(LevelType::NOutOfM)),
- "isDenseLT definition is broken");
-
-static_assert((!isCompressedLT(LevelType::Dense) &&
- isCompressedLT(LevelType::Compressed) &&
- isCompressedLT(LevelType::CompressedNu) &&
- isCompressedLT(LevelType::CompressedNo) &&
- isCompressedLT(LevelType::CompressedNuNo) &&
- !isCompressedLT(LevelType::Singleton) &&
- !isCompressedLT(LevelType::SingletonNu) &&
- !isCompressedLT(LevelType::SingletonNo) &&
- !isCompressedLT(LevelType::SingletonNuNo) &&
- !isCompressedLT(LevelType::LooseCompressed) &&
- !isCompressedLT(LevelType::LooseCompressedNu) &&
- !isCompressedLT(LevelType::LooseCompressedNo) &&
- !isCompressedLT(LevelType::LooseCompressedNuNo) &&
- !isCompressedLT(LevelType::NOutOfM)),
- "isCompressedLT definition is broken");
-
-static_assert((!isSingletonLT(LevelType::Dense) &&
- !isSingletonLT(LevelType::Compressed) &&
- !isSingletonLT(LevelType::CompressedNu) &&
- !isSingletonLT(LevelType::CompressedNo) &&
- !isSingletonLT(LevelType::CompressedNuNo) &&
- isSingletonLT(LevelType::Singleton) &&
- isSingletonLT(LevelType::SingletonNu) &&
- isSingletonLT(LevelType::SingletonNo) &&
- isSingletonLT(LevelType::SingletonNuNo) &&
- !isSingletonLT(LevelType::LooseCompressed) &&
- !isSingletonLT(LevelType::LooseCompressedNu) &&
- !isSingletonLT(LevelType::LooseCompressedNo) &&
- !isSingletonLT(LevelType::LooseCompressedNuNo) &&
- !isSingletonLT(LevelType::NOutOfM)),
- "isSingletonLT definition is broken");
-
-static_assert((!isLooseCompressedLT(LevelType::Dense) &&
- !isLooseCompressedLT(LevelType::Compressed) &&
- !isLooseCompressedLT(LevelType::CompressedNu) &&
- !isLooseCompressedLT(LevelType::CompressedNo) &&
- !isLooseCompressedLT(LevelType::CompressedNuNo) &&
- !isLooseCompressedLT(LevelType::Singleton) &&
- !isLooseCompressedLT(LevelType::SingletonNu) &&
- !isLooseCompressedLT(LevelType::SingletonNo) &&
- !isLooseCompressedLT(LevelType::SingletonNuNo) &&
- isLooseCompressedLT(LevelType::LooseCompressed) &&
- ...
[truncated]
|
@llvm/pr-subscribers-mlir-sparse Author: Peiming Liu (PeimingLiu) ChangesPatch is 36.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81934.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 74cc0dee554a17..c7db5beb2015a6 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -153,45 +153,9 @@ enum class Action : uint32_t {
kSortCOOInPlace = 8,
};
-/// This enum defines all the sparse representations supportable by
-/// the SparseTensor dialect. We use a lightweight encoding to encode
-/// the "format" per se (dense, compressed, singleton, loose_compressed,
-/// n-out-of-m), the "properties" (ordered, unique) as well as n and m when
-/// the format is NOutOfM.
-/// The encoding is chosen for performance of the runtime library, and thus may
-/// change in future versions; consequently, client code should use the
-/// predicate functions defined below, rather than relying on knowledge
-/// about the particular binary encoding.
-///
-/// The `Undef` "format" is a special value used internally for cases
-/// where we need to store an undefined or indeterminate `LevelType`.
-/// It should not be used externally, since it does not indicate an
-/// actual/representable format.
-///
-/// Bit manipulations for LevelType:
-///
-/// | 8-bit n | 8-bit m | 16-bit LevelFormat | 16-bit LevelProperty |
-///
-enum class LevelType : uint64_t {
- Undef = 0x000000000000,
- Dense = 0x000000010000,
- Compressed = 0x000000020000,
- CompressedNu = 0x000000020001,
- CompressedNo = 0x000000020002,
- CompressedNuNo = 0x000000020003,
- Singleton = 0x000000040000,
- SingletonNu = 0x000000040001,
- SingletonNo = 0x000000040002,
- SingletonNuNo = 0x000000040003,
- LooseCompressed = 0x000000080000,
- LooseCompressedNu = 0x000000080001,
- LooseCompressedNo = 0x000000080002,
- LooseCompressedNuNo = 0x000000080003,
- NOutOfM = 0x000000100000,
-};
-
/// This enum defines all supported storage format without the level properties.
enum class LevelFormat : uint64_t {
+ Undef = 0x00000000,
Dense = 0x00010000,
Compressed = 0x00020000,
Singleton = 0x00040000,
@@ -199,327 +163,240 @@ enum class LevelFormat : uint64_t {
NOutOfM = 0x00100000,
};
+template <LevelFormat... targets>
+constexpr bool isAnyOfFmt(LevelFormat fmt) {
+ return (... || (targets == fmt));
+}
+
+/// Returns string representation of the given level format.
+constexpr const char *toFormatString(LevelFormat lvlFmt) {
+ switch (lvlFmt) {
+ case LevelFormat::Undef:
+ return "undef";
+ case LevelFormat::Dense:
+ return "dense";
+ case LevelFormat::Compressed:
+ return "compressed";
+ case LevelFormat::Singleton:
+ return "singleton";
+ case LevelFormat::LooseCompressed:
+ return "loose_compressed";
+ case LevelFormat::NOutOfM:
+ return "structured";
+ }
+ return "";
+}
+
/// This enum defines all the nondefault properties for storage formats.
-enum class LevelPropertyNondefault : uint64_t {
+enum class LevelPropNonDefault : uint64_t {
Nonunique = 0x0001,
Nonordered = 0x0002,
};
-/// Get N of NOutOfM level type.
-constexpr uint64_t getN(LevelType lt) {
- return (static_cast<uint64_t>(lt) >> 32) & 0xff;
+/// Returns string representation of the given level properties.
+constexpr const char *toPropString(LevelPropNonDefault lvlProp) {
+ switch (lvlProp) {
+ case LevelPropNonDefault::Nonunique:
+ return "nonunique";
+ case LevelPropNonDefault::Nonordered:
+ return "nonordered";
+ }
+ return "";
}
-/// Get M of NOutOfM level type.
-constexpr uint64_t getM(LevelType lt) {
- return (static_cast<uint64_t>(lt) >> 40) & 0xff;
-}
+/// This enum defines all the sparse representations supportable by
+/// the SparseTensor dialect. We use a lightweight encoding to encode
+/// the "format" per se (dense, compressed, singleton, loose_compressed,
+/// n-out-of-m), the "properties" (ordered, unique) as well as n and m when
+/// the format is NOutOfM.
+/// The encoding is chosen for performance of the runtime library, and thus may
+/// change in future versions; consequently, client code should use the
+/// predicate functions defined below, rather than relying on knowledge
+/// about the particular binary encoding.
+///
+/// The `Undef` "format" is a special value used internally for cases
+/// where we need to store an undefined or indeterminate `LevelType`.
+/// It should not be used externally, since it does not indicate an
+/// actual/representable format.
-/// Convert N of NOutOfM level type to the stored bits.
-constexpr uint64_t nToBits(uint64_t n) { return n << 32; }
+struct LevelType {
+public:
+ /// Check that the `LevelType` contains a valid (possibly undefined) value.
+ static constexpr bool isValidLvlBits(uint64_t lvlBits) {
+ auto fmt = static_cast<LevelFormat>(lvlBits & 0xffff0000);
+ const uint64_t propertyBits = lvlBits & 0xffff;
+ // If undefined/dense/NOutOfM, then must be unique and ordered.
+ // Otherwise, the format must be one of the known ones.
+ return (isAnyOfFmt<LevelFormat::Undef, LevelFormat::Dense,
+ LevelFormat::NOutOfM>(fmt))
+ ? (propertyBits == 0)
+ : (isAnyOfFmt<LevelFormat::Compressed, LevelFormat::Singleton,
+ LevelFormat::LooseCompressed>(fmt));
+ }
-/// Convert M of NOutOfM level type to the stored bits.
-constexpr uint64_t mToBits(uint64_t m) { return m << 40; }
+ /// Convert a LevelFormat to its corresponding LevelType with the given
+ /// properties. Returns std::nullopt when the properties are not applicable
+ /// for the input level format.
+ static std::optional<LevelType>
+ buildLvlType(LevelFormat lf,
+ const std::vector<LevelPropNonDefault> &properties,
+ uint64_t n = 0, uint64_t m = 0) {
+ assert((n & 0xff) == n && (m & 0xff) == m);
+ uint64_t newN = n << 32;
+ uint64_t newM = m << 40;
+ uint64_t ltBits = static_cast<uint64_t>(lf) | newN | newM;
+ for (auto p : properties)
+ ltBits |= static_cast<uint64_t>(p);
+
+ return isValidLvlBits(ltBits) ? std::optional(LevelType(ltBits))
+ : std::nullopt;
+ }
+ static std::optional<LevelType> buildLvlType(LevelFormat lf, bool ordered,
+ bool unique, uint64_t n = 0,
+ uint64_t m = 0) {
+ std::vector<LevelPropNonDefault> properties;
+ if (!ordered)
+ properties.push_back(LevelPropNonDefault::Nonordered);
+ if (!unique)
+ properties.push_back(LevelPropNonDefault::Nonunique);
+ return buildLvlType(lf, properties, n, m);
+ }
-/// Check if the `LevelType` is NOutOfM (regardless of
-/// properties and block sizes).
-constexpr bool isNOutOfMLT(LevelType lt) {
- return ((static_cast<uint64_t>(lt) & 0x100000) ==
- static_cast<uint64_t>(LevelType::NOutOfM));
-}
+ /// Explicit conversion from uint64_t.
+ constexpr explicit LevelType(uint64_t bits) : lvlBits(bits) {
+ assert(isValidLvlBits(bits));
+ };
-/// Check if the `LevelType` is NOutOfM with the correct block sizes.
-constexpr bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) {
- return isNOutOfMLT(lt) && getN(lt) == n && getM(lt) == m;
-}
+ /// Constructs a LevelType with the given format using all default properties.
+ /*implicit*/ LevelType(LevelFormat f) : lvlBits(static_cast<uint64_t>(f)) {
+ assert(isValidLvlBits(lvlBits) && !isa<LevelFormat::NOutOfM>());
+ };
-/// Returns string representation of the given dimension level type.
-constexpr const char *toMLIRString(LevelType lvlType) {
- auto lt = static_cast<LevelType>(static_cast<uint64_t>(lvlType) & 0xffffffff);
- switch (lt) {
- case LevelType::Undef:
- return "undef";
- case LevelType::Dense:
- return "dense";
- case LevelType::Compressed:
- return "compressed";
- case LevelType::CompressedNu:
- return "compressed(nonunique)";
- case LevelType::CompressedNo:
- return "compressed(nonordered)";
- case LevelType::CompressedNuNo:
- return "compressed(nonunique, nonordered)";
- case LevelType::Singleton:
- return "singleton";
- case LevelType::SingletonNu:
- return "singleton(nonunique)";
- case LevelType::SingletonNo:
- return "singleton(nonordered)";
- case LevelType::SingletonNuNo:
- return "singleton(nonunique, nonordered)";
- case LevelType::LooseCompressed:
- return "loose_compressed";
- case LevelType::LooseCompressedNu:
- return "loose_compressed(nonunique)";
- case LevelType::LooseCompressedNo:
- return "loose_compressed(nonordered)";
- case LevelType::LooseCompressedNuNo:
- return "loose_compressed(nonunique, nonordered)";
- case LevelType::NOutOfM:
- return "structured";
+ /// Converts to uint64_t
+ explicit operator uint64_t() const { return lvlBits; }
+
+ bool operator==(const LevelType lhs) const {
+ return static_cast<uint64_t>(lhs) == lvlBits;
}
- return "";
-}
+ bool operator!=(const LevelType lhs) const { return !(*this == lhs); }
-/// Check that the `LevelType` contains a valid (possibly undefined) value.
-constexpr bool isValidLT(LevelType lt) {
- const uint64_t formatBits = static_cast<uint64_t>(lt) & 0xffff0000;
- const uint64_t propertyBits = static_cast<uint64_t>(lt) & 0xffff;
- // If undefined/dense/NOutOfM, then must be unique and ordered.
- // Otherwise, the format must be one of the known ones.
- return (formatBits <= 0x10000 || formatBits == 0x100000)
- ? (propertyBits == 0)
- : (formatBits == 0x20000 || formatBits == 0x40000 ||
- formatBits == 0x80000);
-}
+ LevelType stripProperties() const { return LevelType(lvlBits & ~0xffff); }
-/// Check if the `LevelType` is the special undefined value.
-constexpr bool isUndefLT(LevelType lt) { return lt == LevelType::Undef; }
+ /// Get N of NOutOfM level type.
+ constexpr uint64_t getN() const {
+ assert(isa<LevelFormat::NOutOfM>());
+ return (lvlBits >> 32) & 0xff;
+ }
-/// Check if the `LevelType` is dense (regardless of properties).
-constexpr bool isDenseLT(LevelType lt) {
- return (static_cast<uint64_t>(lt) & ~0xffff) ==
- static_cast<uint64_t>(LevelType::Dense);
-}
+ /// Get M of NOutOfM level type.
+ constexpr uint64_t getM() const {
+ assert(isa<LevelFormat::NOutOfM>());
+ return (lvlBits >> 40) & 0xff;
+ }
-/// Check if the `LevelType` is compressed (regardless of properties).
-constexpr bool isCompressedLT(LevelType lt) {
- return (static_cast<uint64_t>(lt) & ~0xffff) ==
- static_cast<uint64_t>(LevelType::Compressed);
-}
+ /// Get the `LevelFormat` of the `LevelType`.
+ LevelFormat getLvlFmt() const {
+ return static_cast<LevelFormat>(lvlBits & 0xffff0000);
+ }
-/// Check if the `LevelType` is singleton (regardless of properties).
-constexpr bool isSingletonLT(LevelType lt) {
- return (static_cast<uint64_t>(lt) & ~0xffff) ==
- static_cast<uint64_t>(LevelType::Singleton);
-}
+ /// Check if the `LevelType` is in the `LevelFormat`.
+ template <LevelFormat fmt>
+ constexpr bool isa() const {
+ return getLvlFmt() == fmt;
+ }
-/// Check if the `LevelType` is loose compressed (regardless of properties).
-constexpr bool isLooseCompressedLT(LevelType lt) {
- return (static_cast<uint64_t>(lt) & ~0xffff) ==
- static_cast<uint64_t>(LevelType::LooseCompressed);
-}
+ /// Check if the `LevelType` has the properties
+ template <LevelPropNonDefault p>
+ constexpr bool isa() const {
+ return lvlBits & static_cast<uint64_t>(p);
+ }
-/// Check if the `LevelType` needs positions array.
-constexpr bool isWithPosLT(LevelType lt) {
- return isCompressedLT(lt) || isLooseCompressedLT(lt);
-}
+ /// Check if the `LevelType` needs positions array.
+ constexpr bool isWithPosLT() const {
+ return isa<LevelFormat::Compressed>() ||
+ isa<LevelFormat::LooseCompressed>();
+ }
-/// Check if the `LevelType` needs coordinates array.
-constexpr bool isWithCrdLT(LevelType lt) {
- return isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
- isNOutOfMLT(lt);
-}
+ /// Check if the `LevelType` needs coordinates array.
+ constexpr bool isWithCrdLT() const {
+ // All sparse levels has coordinate array.
+ return !isa<LevelFormat::Dense>();
+ }
-/// Check if the `LevelType` is ordered (regardless of storage format).
-constexpr bool isOrderedLT(LevelType lt) {
- return !(static_cast<uint64_t>(lt) & 2);
- return !(static_cast<uint64_t>(lt) & 2);
-}
+ std::string toMLIRString() const {
+ std::string lvlStr = toFormatString(getLvlFmt());
+ std::string propStr = "";
+ if (isa<LevelPropNonDefault::Nonunique>())
+ propStr += toPropString(LevelPropNonDefault::Nonunique);
+
+ if (isa<LevelPropNonDefault::Nonordered>()) {
+ if (!propStr.empty())
+ propStr += ", ";
+ propStr += toPropString(LevelPropNonDefault::Nonordered);
+ }
+ if (!propStr.empty())
+ lvlStr += ("(" + propStr + ")");
+ return lvlStr;
+ }
-/// Check if the `LevelType` is unique (regardless of storage format).
-constexpr bool isUniqueLT(LevelType lt) {
- return !(static_cast<uint64_t>(lt) & 1);
- return !(static_cast<uint64_t>(lt) & 1);
-}
+private:
+ /// Bit manipulations for LevelType:
+ ///
+ /// | 8-bit n | 8-bit m | 16-bit LevelFormat | 16-bit LevelProperty |
+ ///
+ uint64_t lvlBits;
+};
-/// Convert a LevelType to its corresponding LevelFormat.
-/// Returns std::nullopt when input lt is Undef.
-constexpr std::optional<LevelFormat> getLevelFormat(LevelType lt) {
- if (lt == LevelType::Undef)
- return std::nullopt;
- return static_cast<LevelFormat>(static_cast<uint64_t>(lt) & 0xffff0000);
-}
+// For backward-compatibility. TODO: remove below after fully migration.
+constexpr uint64_t nToBits(uint64_t n) { return n << 32; }
+constexpr uint64_t mToBits(uint64_t m) { return m << 40; }
-/// Convert a LevelFormat to its corresponding LevelType with the given
-/// properties. Returns std::nullopt when the properties are not applicable
-/// for the input level format.
inline std::optional<LevelType>
buildLevelType(LevelFormat lf,
- const std::vector<LevelPropertyNondefault> &properties,
+ const std::vector<LevelPropNonDefault> &properties,
uint64_t n = 0, uint64_t m = 0) {
- uint64_t newN = n << 32;
- uint64_t newM = m << 40;
- uint64_t ltInt = static_cast<uint64_t>(lf) | newN | newM;
- for (auto p : properties) {
- ltInt |= static_cast<uint64_t>(p);
- }
- auto lt = static_cast<LevelType>(ltInt);
- return isValidLT(lt) ? std::optional(lt) : std::nullopt;
+ return LevelType::buildLvlType(lf, properties, n, m);
}
-
inline std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
bool unique, uint64_t n = 0,
uint64_t m = 0) {
- std::vector<LevelPropertyNondefault> properties;
- if (!ordered)
- properties.push_back(LevelPropertyNondefault::Nonordered);
- if (!unique)
- properties.push_back(LevelPropertyNondefault::Nonunique);
- return buildLevelType(lf, properties, n, m);
+ return LevelType::buildLvlType(lf, ordered, unique, n, m);
}
-
-//
-// Ensure the above methods work as intended.
-//
-
-static_assert(
- (getLevelFormat(LevelType::Undef) == std::nullopt &&
- *getLevelFormat(LevelType::Dense) == LevelFormat::Dense &&
- *getLevelFormat(LevelType::Compressed) == LevelFormat::Compressed &&
- *getLevelFormat(LevelType::CompressedNu) == LevelFormat::Compressed &&
- *getLevelFormat(LevelType::CompressedNo) == LevelFormat::Compressed &&
- *getLevelFormat(LevelType::CompressedNuNo) == LevelFormat::Compressed &&
- *getLevelFormat(LevelType::Singleton) == LevelFormat::Singleton &&
- *getLevelFormat(LevelType::SingletonNu) == LevelFormat::Singleton &&
- *getLevelFormat(LevelType::SingletonNo) == LevelFormat::Singleton &&
- *getLevelFormat(LevelType::SingletonNuNo) == LevelFormat::Singleton &&
- *getLevelFormat(LevelType::LooseCompressed) ==
- LevelFormat::LooseCompressed &&
- *getLevelFormat(LevelType::LooseCompressedNu) ==
- LevelFormat::LooseCompressed &&
- *getLevelFormat(LevelType::LooseCompressedNo) ==
- LevelFormat::LooseCompressed &&
- *getLevelFormat(LevelType::LooseCompressedNuNo) ==
- LevelFormat::LooseCompressed &&
- *getLevelFormat(LevelType::NOutOfM) == LevelFormat::NOutOfM),
- "getLevelFormat conversion is broken");
-
-static_assert(
- (isValidLT(LevelType::Undef) && isValidLT(LevelType::Dense) &&
- isValidLT(LevelType::Compressed) && isValidLT(LevelType::CompressedNu) &&
- isValidLT(LevelType::CompressedNo) &&
- isValidLT(LevelType::CompressedNuNo) && isValidLT(LevelType::Singleton) &&
- isValidLT(LevelType::SingletonNu) && isValidLT(LevelType::SingletonNo) &&
- isValidLT(LevelType::SingletonNuNo) &&
- isValidLT(LevelType::LooseCompressed) &&
- isValidLT(LevelType::LooseCompressedNu) &&
- isValidLT(LevelType::LooseCompressedNo) &&
- isValidLT(LevelType::LooseCompressedNuNo) &&
- isValidLT(LevelType::NOutOfM)),
- "isValidLT definition is broken");
-
-static_assert((isDenseLT(LevelType::Dense) &&
- !isDenseLT(LevelType::Compressed) &&
- !isDenseLT(LevelType::CompressedNu) &&
- !isDenseLT(LevelType::CompressedNo) &&
- !isDenseLT(LevelType::CompressedNuNo) &&
- !isDenseLT(LevelType::Singleton) &&
- !isDenseLT(LevelType::SingletonNu) &&
- !isDenseLT(LevelType::SingletonNo) &&
- !isDenseLT(LevelType::SingletonNuNo) &&
- !isDenseLT(LevelType::LooseCompressed) &&
- !isDenseLT(LevelType::LooseCompressedNu) &&
- !isDenseLT(LevelType::LooseCompressedNo) &&
- !isDenseLT(LevelType::LooseCompressedNuNo) &&
- !isDenseLT(LevelType::NOutOfM)),
- "isDenseLT definition is broken");
-
-static_assert((!isCompressedLT(LevelType::Dense) &&
- isCompressedLT(LevelType::Compressed) &&
- isCompressedLT(LevelType::CompressedNu) &&
- isCompressedLT(LevelType::CompressedNo) &&
- isCompressedLT(LevelType::CompressedNuNo) &&
- !isCompressedLT(LevelType::Singleton) &&
- !isCompressedLT(LevelType::SingletonNu) &&
- !isCompressedLT(LevelType::SingletonNo) &&
- !isCompressedLT(LevelType::SingletonNuNo) &&
- !isCompressedLT(LevelType::LooseCompressed) &&
- !isCompressedLT(LevelType::LooseCompressedNu) &&
- !isCompressedLT(LevelType::LooseCompressedNo) &&
- !isCompressedLT(LevelType::LooseCompressedNuNo) &&
- !isCompressedLT(LevelType::NOutOfM)),
- "isCompressedLT definition is broken");
-
-static_assert((!isSingletonLT(LevelType::Dense) &&
- !isSingletonLT(LevelType::Compressed) &&
- !isSingletonLT(LevelType::CompressedNu) &&
- !isSingletonLT(LevelType::CompressedNo) &&
- !isSingletonLT(LevelType::CompressedNuNo) &&
- isSingletonLT(LevelType::Singleton) &&
- isSingletonLT(LevelType::SingletonNu) &&
- isSingletonLT(LevelType::SingletonNo) &&
- isSingletonLT(LevelType::SingletonNuNo) &&
- !isSingletonLT(LevelType::LooseCompressed) &&
- !isSingletonLT(LevelType::LooseCompressedNu) &&
- !isSingletonLT(LevelType::LooseCompressedNo) &&
- !isSingletonLT(LevelType::LooseCompressedNuNo) &&
- !isSingletonLT(LevelType::NOutOfM)),
- "isSingletonLT definition is broken");
-
-static_assert((!isLooseCompressedLT(LevelType::Dense) &&
- !isLooseCompressedLT(LevelType::Compressed) &&
- !isLooseCompressedLT(LevelType::CompressedNu) &&
- !isLooseCompressedLT(LevelType::CompressedNo) &&
- !isLooseCompressedLT(LevelType::CompressedNuNo) &&
- !isLooseCompressedLT(LevelType::Singleton) &&
- !isLooseCompressedLT(LevelType::SingletonNu) &&
- !isLooseCompressedLT(LevelType::SingletonNo) &&
- !isLooseCompressedLT(LevelType::SingletonNuNo) &&
- isLooseCompressedLT(LevelType::LooseCompressed) &&
- ...
[truncated]
|
lgtm post facto |
No description provided.