Skip to content

Commit ee3ee13

Browse files
authored
[mlir][sparse] cleanup of enums header (#71090)
Some DLT related methods leaked into sparse_tensor.h, and this moves it back to the right header. Also, the asserts were incomplete and some DLT methods duplicated.
1 parent 06145dc commit ee3ee13

File tree

4 files changed

+122
-68
lines changed

4 files changed

+122
-68
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h

Lines changed: 119 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -153,24 +153,18 @@ enum class Action : uint32_t {
153153
};
154154

155155
/// This enum defines all the sparse representations supportable by
156-
/// the SparseTensor dialect. We use a lightweight encoding to encode
157-
/// both the "format" per se (dense, compressed, singleton) as well as
158-
/// the "properties" (ordered, unique). The encoding is chosen for
159-
/// performance of the runtime library, and thus may change in future
160-
/// versions; consequently, client code should use the predicate functions
161-
/// defined below, rather than relying on knowledge about the particular
162-
/// binary encoding.
156+
/// the SparseTensor dialect. We use a lightweight encoding to encode
157+
/// both the "format" per se (dense, compressed, singleton, loose_compressed,
158+
/// two-out-of-four) as well as the "properties" (ordered, unique). The
159+
/// encoding is chosen for performance of the runtime library, and thus may
160+
/// change in future versions; consequently, client code should use the
161+
/// predicate functions defined below, rather than relying on knowledge
162+
/// about the particular binary encoding.
163163
///
164164
/// The `Undef` "format" is a special value used internally for cases
165165
/// where we need to store an undefined or indeterminate `DimLevelType`.
166166
/// It should not be used externally, since it does not indicate an
167167
/// actual/representable format.
168-
///
169-
// TODO: We should generalize TwoOutOfFour to N out of M and use property to
170-
// encode the value of N and M.
171-
// TODO: Update DimLevelType to use lower 8 bits for storage formats and the
172-
// higher 4 bits to store level properties. Consider LooseCompressed and
173-
// TwoOutOfFour as properties instead of formats.
174168
enum class DimLevelType : uint8_t {
175169
Undef = 0, // 0b00000_00
176170
Dense = 4, // 0b00001_00
@@ -257,44 +251,47 @@ constexpr bool isUndefDLT(DimLevelType dlt) {
257251
return dlt == DimLevelType::Undef;
258252
}
259253

260-
/// Check if the `DimLevelType` is dense.
254+
/// Check if the `DimLevelType` is dense (regardless of properties).
261255
constexpr bool isDenseDLT(DimLevelType dlt) {
262-
return dlt == DimLevelType::Dense;
263-
}
264-
265-
/// Check if the `DimLevelType` is 2:4
266-
constexpr bool isTwoOutOfFourDLT(DimLevelType dlt) {
267-
return dlt == DimLevelType::TwoOutOfFour;
256+
return (static_cast<uint8_t>(dlt) & ~3) ==
257+
static_cast<uint8_t>(DimLevelType::Dense);
268258
}
269259

270-
// We use the idiom `(dlt & ~3) == format` in order to only return true
271-
// for valid DLTs. Whereas the `dlt & format` idiom is a bit faster but
272-
// can return false-positives on invalid DLTs.
273-
274260
/// Check if the `DimLevelType` is compressed (regardless of properties).
275261
constexpr bool isCompressedDLT(DimLevelType dlt) {
276262
return (static_cast<uint8_t>(dlt) & ~3) ==
277263
static_cast<uint8_t>(DimLevelType::Compressed);
278264
}
279265

280-
/// Check if the `DimLevelType` is loose compressed (regardless of properties).
281-
constexpr bool isLooseCompressedDLT(DimLevelType dlt) {
282-
return (static_cast<uint8_t>(dlt) & ~3) ==
283-
static_cast<uint8_t>(DimLevelType::LooseCompressed);
284-
}
285-
286266
/// Check if the `DimLevelType` is singleton (regardless of properties).
287267
constexpr bool isSingletonDLT(DimLevelType dlt) {
288268
return (static_cast<uint8_t>(dlt) & ~3) ==
289269
static_cast<uint8_t>(DimLevelType::Singleton);
290270
}
291271

272+
/// Check if the `DimLevelType` is loose compressed (regardless of properties).
273+
constexpr bool isLooseCompressedDLT(DimLevelType dlt) {
274+
return (static_cast<uint8_t>(dlt) & ~3) ==
275+
static_cast<uint8_t>(DimLevelType::LooseCompressed);
276+
}
277+
292278
/// Check if the `DimLevelType` is 2OutOf4 (regardless of properties).
293279
constexpr bool is2OutOf4DLT(DimLevelType dlt) {
294280
return (static_cast<uint8_t>(dlt) & ~3) ==
295281
static_cast<uint8_t>(DimLevelType::TwoOutOfFour);
296282
}
297283

284+
/// Check if the `DimLevelType` needs positions array.
285+
constexpr bool isDLTWithPos(DimLevelType dlt) {
286+
return isCompressedDLT(dlt) || isLooseCompressedDLT(dlt);
287+
}
288+
289+
/// Check if the `DimLevelType` needs coordinates array.
290+
constexpr bool isDLTWithCrd(DimLevelType dlt) {
291+
return isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
292+
isLooseCompressedDLT(dlt) || is2OutOf4DLT(dlt);
293+
}
294+
298295
/// Check if the `DimLevelType` is ordered (regardless of storage format).
299296
constexpr bool isOrderedDLT(DimLevelType dlt) {
300297
return !(static_cast<uint8_t>(dlt) & 2);
@@ -325,7 +322,10 @@ buildLevelType(LevelFormat lf, bool ordered, bool unique) {
325322
return isValidDLT(dlt) ? std::optional(dlt) : std::nullopt;
326323
}
327324

328-
/// Ensure the above conversion works as intended.
325+
//
326+
// Ensure the above methods work as indended.
327+
//
328+
329329
static_assert(
330330
(getLevelFormat(DimLevelType::Undef) == std::nullopt &&
331331
*getLevelFormat(DimLevelType::Dense) == LevelFormat::Dense &&
@@ -336,19 +336,23 @@ static_assert(
336336
*getLevelFormat(DimLevelType::Singleton) == LevelFormat::Singleton &&
337337
*getLevelFormat(DimLevelType::SingletonNu) == LevelFormat::Singleton &&
338338
*getLevelFormat(DimLevelType::SingletonNo) == LevelFormat::Singleton &&
339-
*getLevelFormat(DimLevelType::SingletonNuNo) == LevelFormat::Singleton),
339+
*getLevelFormat(DimLevelType::SingletonNuNo) == LevelFormat::Singleton &&
340+
*getLevelFormat(DimLevelType::LooseCompressed) ==
341+
LevelFormat::LooseCompressed &&
342+
*getLevelFormat(DimLevelType::LooseCompressedNu) ==
343+
LevelFormat::LooseCompressed &&
344+
*getLevelFormat(DimLevelType::LooseCompressedNo) ==
345+
LevelFormat::LooseCompressed &&
346+
*getLevelFormat(DimLevelType::LooseCompressedNuNo) ==
347+
LevelFormat::LooseCompressed &&
348+
*getLevelFormat(DimLevelType::TwoOutOfFour) == LevelFormat::TwoOutOfFour),
340349
"getLevelFormat conversion is broken");
341350

342351
static_assert(
343352
(buildLevelType(LevelFormat::Dense, false, true) == std::nullopt &&
344353
buildLevelType(LevelFormat::Dense, true, false) == std::nullopt &&
345354
buildLevelType(LevelFormat::Dense, false, false) == std::nullopt &&
346355
*buildLevelType(LevelFormat::Dense, true, true) == DimLevelType::Dense &&
347-
buildLevelType(LevelFormat::TwoOutOfFour, false, true) == std::nullopt &&
348-
buildLevelType(LevelFormat::TwoOutOfFour, true, false) == std::nullopt &&
349-
buildLevelType(LevelFormat::TwoOutOfFour, false, false) == std::nullopt &&
350-
*buildLevelType(LevelFormat::TwoOutOfFour, true, true) ==
351-
DimLevelType::TwoOutOfFour &&
352356
*buildLevelType(LevelFormat::Compressed, true, true) ==
353357
DimLevelType::Compressed &&
354358
*buildLevelType(LevelFormat::Compressed, true, false) ==
@@ -364,10 +368,22 @@ static_assert(
364368
*buildLevelType(LevelFormat::Singleton, false, true) ==
365369
DimLevelType::SingletonNo &&
366370
*buildLevelType(LevelFormat::Singleton, false, false) ==
367-
DimLevelType::SingletonNuNo),
371+
DimLevelType::SingletonNuNo &&
372+
*buildLevelType(LevelFormat::LooseCompressed, true, true) ==
373+
DimLevelType::LooseCompressed &&
374+
*buildLevelType(LevelFormat::LooseCompressed, true, false) ==
375+
DimLevelType::LooseCompressedNu &&
376+
*buildLevelType(LevelFormat::LooseCompressed, false, true) ==
377+
DimLevelType::LooseCompressedNo &&
378+
*buildLevelType(LevelFormat::LooseCompressed, false, false) ==
379+
DimLevelType::LooseCompressedNuNo &&
380+
buildLevelType(LevelFormat::TwoOutOfFour, false, true) == std::nullopt &&
381+
buildLevelType(LevelFormat::TwoOutOfFour, true, false) == std::nullopt &&
382+
buildLevelType(LevelFormat::TwoOutOfFour, false, false) == std::nullopt &&
383+
*buildLevelType(LevelFormat::TwoOutOfFour, true, true) ==
384+
DimLevelType::TwoOutOfFour),
368385
"buildLevelType conversion is broken");
369386

370-
// Ensure the above predicates work as intended.
371387
static_assert((isValidDLT(DimLevelType::Undef) &&
372388
isValidDLT(DimLevelType::Dense) &&
373389
isValidDLT(DimLevelType::Compressed) &&
@@ -385,6 +401,22 @@ static_assert((isValidDLT(DimLevelType::Undef) &&
385401
isValidDLT(DimLevelType::TwoOutOfFour)),
386402
"isValidDLT definition is broken");
387403

404+
static_assert((isDenseDLT(DimLevelType::Dense) &&
405+
!isDenseDLT(DimLevelType::Compressed) &&
406+
!isDenseDLT(DimLevelType::CompressedNu) &&
407+
!isDenseDLT(DimLevelType::CompressedNo) &&
408+
!isDenseDLT(DimLevelType::CompressedNuNo) &&
409+
!isDenseDLT(DimLevelType::Singleton) &&
410+
!isDenseDLT(DimLevelType::SingletonNu) &&
411+
!isDenseDLT(DimLevelType::SingletonNo) &&
412+
!isDenseDLT(DimLevelType::SingletonNuNo) &&
413+
!isDenseDLT(DimLevelType::LooseCompressed) &&
414+
!isDenseDLT(DimLevelType::LooseCompressedNu) &&
415+
!isDenseDLT(DimLevelType::LooseCompressedNo) &&
416+
!isDenseDLT(DimLevelType::LooseCompressedNuNo) &&
417+
!isDenseDLT(DimLevelType::TwoOutOfFour)),
418+
"isDenseDLT definition is broken");
419+
388420
static_assert((!isCompressedDLT(DimLevelType::Dense) &&
389421
isCompressedDLT(DimLevelType::Compressed) &&
390422
isCompressedDLT(DimLevelType::CompressedNu) &&
@@ -393,20 +425,14 @@ static_assert((!isCompressedDLT(DimLevelType::Dense) &&
393425
!isCompressedDLT(DimLevelType::Singleton) &&
394426
!isCompressedDLT(DimLevelType::SingletonNu) &&
395427
!isCompressedDLT(DimLevelType::SingletonNo) &&
396-
!isCompressedDLT(DimLevelType::SingletonNuNo)),
428+
!isCompressedDLT(DimLevelType::SingletonNuNo) &&
429+
!isCompressedDLT(DimLevelType::LooseCompressed) &&
430+
!isCompressedDLT(DimLevelType::LooseCompressedNu) &&
431+
!isCompressedDLT(DimLevelType::LooseCompressedNo) &&
432+
!isCompressedDLT(DimLevelType::LooseCompressedNuNo) &&
433+
!isCompressedDLT(DimLevelType::TwoOutOfFour)),
397434
"isCompressedDLT definition is broken");
398435

399-
static_assert((!isLooseCompressedDLT(DimLevelType::Dense) &&
400-
isLooseCompressedDLT(DimLevelType::LooseCompressed) &&
401-
isLooseCompressedDLT(DimLevelType::LooseCompressedNu) &&
402-
isLooseCompressedDLT(DimLevelType::LooseCompressedNo) &&
403-
isLooseCompressedDLT(DimLevelType::LooseCompressedNuNo) &&
404-
!isLooseCompressedDLT(DimLevelType::Singleton) &&
405-
!isLooseCompressedDLT(DimLevelType::SingletonNu) &&
406-
!isLooseCompressedDLT(DimLevelType::SingletonNo) &&
407-
!isLooseCompressedDLT(DimLevelType::SingletonNuNo)),
408-
"isLooseCompressedDLT definition is broken");
409-
410436
static_assert((!isSingletonDLT(DimLevelType::Dense) &&
411437
!isSingletonDLT(DimLevelType::Compressed) &&
412438
!isSingletonDLT(DimLevelType::CompressedNu) &&
@@ -415,11 +441,47 @@ static_assert((!isSingletonDLT(DimLevelType::Dense) &&
415441
isSingletonDLT(DimLevelType::Singleton) &&
416442
isSingletonDLT(DimLevelType::SingletonNu) &&
417443
isSingletonDLT(DimLevelType::SingletonNo) &&
418-
isSingletonDLT(DimLevelType::SingletonNuNo)),
444+
isSingletonDLT(DimLevelType::SingletonNuNo) &&
445+
!isSingletonDLT(DimLevelType::LooseCompressed) &&
446+
!isSingletonDLT(DimLevelType::LooseCompressedNu) &&
447+
!isSingletonDLT(DimLevelType::LooseCompressedNo) &&
448+
!isSingletonDLT(DimLevelType::LooseCompressedNuNo) &&
449+
!isSingletonDLT(DimLevelType::TwoOutOfFour)),
419450
"isSingletonDLT definition is broken");
420451

452+
static_assert((!isLooseCompressedDLT(DimLevelType::Dense) &&
453+
!isLooseCompressedDLT(DimLevelType::Compressed) &&
454+
!isLooseCompressedDLT(DimLevelType::CompressedNu) &&
455+
!isLooseCompressedDLT(DimLevelType::CompressedNo) &&
456+
!isLooseCompressedDLT(DimLevelType::CompressedNuNo) &&
457+
!isLooseCompressedDLT(DimLevelType::Singleton) &&
458+
!isLooseCompressedDLT(DimLevelType::SingletonNu) &&
459+
!isLooseCompressedDLT(DimLevelType::SingletonNo) &&
460+
!isLooseCompressedDLT(DimLevelType::SingletonNuNo) &&
461+
isLooseCompressedDLT(DimLevelType::LooseCompressed) &&
462+
isLooseCompressedDLT(DimLevelType::LooseCompressedNu) &&
463+
isLooseCompressedDLT(DimLevelType::LooseCompressedNo) &&
464+
isLooseCompressedDLT(DimLevelType::LooseCompressedNuNo) &&
465+
!isLooseCompressedDLT(DimLevelType::TwoOutOfFour)),
466+
"isLooseCompressedDLT definition is broken");
467+
468+
static_assert((!is2OutOf4DLT(DimLevelType::Dense) &&
469+
!is2OutOf4DLT(DimLevelType::Compressed) &&
470+
!is2OutOf4DLT(DimLevelType::CompressedNu) &&
471+
!is2OutOf4DLT(DimLevelType::CompressedNo) &&
472+
!is2OutOf4DLT(DimLevelType::CompressedNuNo) &&
473+
!is2OutOf4DLT(DimLevelType::Singleton) &&
474+
!is2OutOf4DLT(DimLevelType::SingletonNu) &&
475+
!is2OutOf4DLT(DimLevelType::SingletonNo) &&
476+
!is2OutOf4DLT(DimLevelType::SingletonNuNo) &&
477+
!is2OutOf4DLT(DimLevelType::LooseCompressed) &&
478+
!is2OutOf4DLT(DimLevelType::LooseCompressedNu) &&
479+
!is2OutOf4DLT(DimLevelType::LooseCompressedNo) &&
480+
!is2OutOf4DLT(DimLevelType::LooseCompressedNuNo) &&
481+
is2OutOf4DLT(DimLevelType::TwoOutOfFour)),
482+
"is2OutOf4DLT definition is broken");
483+
421484
static_assert((isOrderedDLT(DimLevelType::Dense) &&
422-
isOrderedDLT(DimLevelType::TwoOutOfFour) &&
423485
isOrderedDLT(DimLevelType::Compressed) &&
424486
isOrderedDLT(DimLevelType::CompressedNu) &&
425487
!isOrderedDLT(DimLevelType::CompressedNo) &&
@@ -431,11 +493,11 @@ static_assert((isOrderedDLT(DimLevelType::Dense) &&
431493
isOrderedDLT(DimLevelType::LooseCompressed) &&
432494
isOrderedDLT(DimLevelType::LooseCompressedNu) &&
433495
!isOrderedDLT(DimLevelType::LooseCompressedNo) &&
434-
!isOrderedDLT(DimLevelType::LooseCompressedNuNo)),
496+
!isOrderedDLT(DimLevelType::LooseCompressedNuNo) &&
497+
isOrderedDLT(DimLevelType::TwoOutOfFour)),
435498
"isOrderedDLT definition is broken");
436499

437500
static_assert((isUniqueDLT(DimLevelType::Dense) &&
438-
isUniqueDLT(DimLevelType::TwoOutOfFour) &&
439501
isUniqueDLT(DimLevelType::Compressed) &&
440502
!isUniqueDLT(DimLevelType::CompressedNu) &&
441503
isUniqueDLT(DimLevelType::CompressedNo) &&
@@ -447,7 +509,8 @@ static_assert((isUniqueDLT(DimLevelType::Dense) &&
447509
isUniqueDLT(DimLevelType::LooseCompressed) &&
448510
!isUniqueDLT(DimLevelType::LooseCompressedNu) &&
449511
isUniqueDLT(DimLevelType::LooseCompressedNo) &&
450-
!isUniqueDLT(DimLevelType::LooseCompressedNuNo)),
512+
!isUniqueDLT(DimLevelType::LooseCompressedNuNo) &&
513+
isUniqueDLT(DimLevelType::TwoOutOfFour)),
451514
"isUniqueDLT definition is broken");
452515

453516
/// Bit manipulations for affine encoding.

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -89,16 +89,6 @@ inline MemRefType getMemRefType(T &&t) {
8989
/// Returns null-attribute for any type without an encoding.
9090
SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
9191

92-
/// Convenience method to query whether a given DLT needs both position and
93-
/// coordinates array or only coordinates array.
94-
constexpr inline bool isDLTWithPos(DimLevelType dlt) {
95-
return isLooseCompressedDLT(dlt) || isCompressedDLT(dlt);
96-
}
97-
constexpr inline bool isDLTWithCrd(DimLevelType dlt) {
98-
return isSingletonDLT(dlt) || isLooseCompressedDLT(dlt) ||
99-
isCompressedDLT(dlt);
100-
}
101-
10292
/// Returns true iff the given sparse tensor encoding attribute has a trailing
10393
/// COO region starting at the given level.
10494
bool isCOOType(SparseTensorEncodingAttr enc, Level startLvl, bool isUnique);

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -371,10 +371,10 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
371371
::mlir::sparse_tensor::DimLevelType getLvlType(::mlir::sparse_tensor::Level l) const;
372372

373373
bool isDenseLvl(::mlir::sparse_tensor::Level l) const { return isDenseDLT(getLvlType(l)); }
374-
bool isTwoOutOfFourLvl(::mlir::sparse_tensor::Level l) const { return isTwoOutOfFourDLT(getLvlType(l)); }
375374
bool isCompressedLvl(::mlir::sparse_tensor::Level l) const { return isCompressedDLT(getLvlType(l)); }
376-
bool isLooseCompressedLvl(::mlir::sparse_tensor::Level l) const { return isLooseCompressedDLT(getLvlType(l)); }
377375
bool isSingletonLvl(::mlir::sparse_tensor::Level l) const { return isSingletonDLT(getLvlType(l)); }
376+
bool isLooseCompressedLvl(::mlir::sparse_tensor::Level l) const { return isLooseCompressedDLT(getLvlType(l)); }
377+
bool isTwoOutOfFourLvl(::mlir::sparse_tensor::Level l) const { return is2OutOf4DLT(getLvlType(l)); }
378378
bool isOrderedLvl(::mlir::sparse_tensor::Level l) const { return isOrderedDLT(getLvlType(l)); }
379379
bool isUniqueLvl(::mlir::sparse_tensor::Level l) const { return isUniqueDLT(getLvlType(l)); }
380380

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "Detail/DimLvlMapParser.h"
1212

13+
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
1314
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
1415
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
1516
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"

0 commit comments

Comments
 (0)