Commit e5924d6

[mlir][sparse] Implement parsing n out of m (llvm#79935)
1. Add parsing methods for block[n, m].
2. Encode n and m with the newly extended 64-bit LevelType enum.
3. Update 2:4 method names/comments to n:m.
1 parent 3b57b64 commit e5924d6

26 files changed: 302 additions, 174 deletions
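For a concrete picture of the surface syntax this commit enables, the following stand-alone C++ sketch (not part of the commit) parses an encoding whose last level is written with the new `structured[2, 4]` form, the generalized spelling of the former `block2_4`. The 2:4 dimension-to-level map, the include paths, and the use of mlir::parseAttribute are illustrative assumptions, not code from this change.

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::MLIRContext ctx;
  ctx.loadDialect<mlir::sparse_tensor::SparseTensorDialect>();
  // A 2-D matrix mapped to dense x dense x structured[2, 4] levels,
  // i.e. 2 out of every 4 consecutive elements in a 1x4 block are stored.
  mlir::Attribute enc = mlir::parseAttribute(
      "#sparse_tensor.encoding<{ map = (i, j) -> "
      "(i : dense, j floordiv 4 : dense, j mod 4 : structured[2, 4]) }>",
      &ctx);
  return enc ? 0 : 1; // parses successfully once this commit is in place
}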

mlir/include/mlir-c/Dialect/SparseTensor.h

Lines changed: 14 additions & 14 deletions
@@ -28,20 +28,20 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor);
 typedef uint64_t MlirSparseTensorLevelType;
 
 enum MlirBaseSparseTensorLevelType {
-  MLIR_SPARSE_TENSOR_LEVEL_DENSE = 4,                    // 0b00001_00
-  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 8,                // 0b00010_00
-  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 9,             // 0b00010_01
-  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 10,            // 0b00010_10
-  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 11,         // 0b00010_11
-  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 16,                // 0b00100_00
-  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 17,             // 0b00100_01
-  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 18,             // 0b00100_10
-  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO = 19,          // 0b00100_11
-  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 32,         // 0b01000_00
-  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 33,      // 0b01000_01
-  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 34,      // 0b01000_10
-  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 35,   // 0b01000_11
-  MLIR_SPARSE_TENSOR_LEVEL_TWO_OUT_OF_FOUR = 64,          // 0b10000_00
+  MLIR_SPARSE_TENSOR_LEVEL_DENSE = 0x000000010000,
+  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 0x000000020000,
+  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 0x000000020001,
+  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 0x000000020002,
+  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 0x000000020003,
+  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 0x000000040000,
+  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 0x000000040001,
+  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 0x000000040002,
+  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO = 0x000000040003,
+  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 0x000000080000,
+  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 0x000000080001,
+  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 0x000000080002,
+  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 0x000000080003,
+  MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 0x000000100000,
 };
 
 //===----------------------------------------------------------------------===//
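The constants above replace the old 8-bit scheme (format bit shifted up by two, two property bits at the bottom) with a 64-bit layout: the format bit now starts at bit 16 while the nonunique/nonordered flags keep the low bits, leaving room elsewhere in the word to encode the n and m of an n:m level (the packing itself lives in the Enums.h diff, which this page does not render). A minimal sketch, relying only on the constants shown here; the masks are read off those values rather than taken from any named API:

#include <cassert>
#include <cstdint>

#include "mlir-c/Dialect/SparseTensor.h"

int main() {
  // The low bits still carry the two level properties ...
  uint64_t lt = MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO;
  assert((lt & 0x1) != 0); // nonunique: the _NU variants add 0x1
  assert((lt & 0x2) != 0); // nonordered: the _NO variants add 0x2
  // ... while the format bit moved from bit 2 up to bit 16.
  assert((MLIR_SPARSE_TENSOR_LEVEL_DENSE >> 16) == 0x1);
  assert((MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M >> 16) == 0x10);
  return 0;
}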

mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h

Lines changed: 128 additions & 77 deletions
Large diffs are not rendered by default.

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

Lines changed: 3 additions & 2 deletions
@@ -145,7 +145,8 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     - **compressed** : only nonzeros along this level are stored
     - **loose_compressed** : as compressed, but allows for free space between regions
     - **singleton** : a variant of the compressed format, where coordinates have no siblings
-    - **block2_4** : the compression uses a 2:4 encoding per 1x4 block
+    - **structured[n, m]** : the compression uses a n:m encoding
+      (viz. n out of m consecutive elements are nonzero)
 
     For a compressed level, each position interval is represented in a compact
     way with a lowerbound `pos(i)` and an upperbound `pos(i+1) - 1`, which implies
@@ -374,7 +375,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     bool isCompressedLvl(::mlir::sparse_tensor::Level l) const { return isCompressedLT(getLvlType(l)); }
     bool isSingletonLvl(::mlir::sparse_tensor::Level l) const { return isSingletonLT(getLvlType(l)); }
     bool isLooseCompressedLvl(::mlir::sparse_tensor::Level l) const { return isLooseCompressedLT(getLvlType(l)); }
-    bool isTwoOutOfFourLvl(::mlir::sparse_tensor::Level l) const { return is2OutOf4LT(getLvlType(l)); }
+    bool isNOutOfMLvl(::mlir::sparse_tensor::Level l) const { return isNOutOfMLT(getLvlType(l)); }
     bool isOrderedLvl(::mlir::sparse_tensor::Level l) const { return isOrderedLT(getLvlType(l)); }
     bool isUniqueLvl(::mlir::sparse_tensor::Level l) const { return isUniqueLT(getLvlType(l)); }
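On the C++ side, code that previously asked isTwoOutOfFourLvl now asks isNOutOfMLvl, and the concrete n and m can be recovered from the level type. A sketch (not from the commit), assuming an already-constructed encoding attribute such as the structured[2, 4] example above; getN/getM are the free helpers this diff uses later in the printer, presumably declared in Enums.h:

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

void inspectLastLevel(mlir::sparse_tensor::SparseTensorEncodingAttr enc) {
  using namespace mlir::sparse_tensor;
  Level l = enc.getLvlRank() - 1;
  if (enc.isNOutOfMLvl(l)) {  // was: isTwoOutOfFourLvl(l)
    LevelType lt = enc.getLvlType(l);
    unsigned n = getN(lt);    // 2 for a structured[2, 4] level
    unsigned m = getM(lt);    // 4 for a structured[2, 4] level
    (void)n;
    (void)m;
  }
}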

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h

Lines changed: 1 addition & 1 deletion
@@ -291,7 +291,7 @@ class SparseTensorType {
     return isLooseCompressedLT(getLvlType(l));
   }
   bool isSingletonLvl(Level l) const { return isSingletonLT(getLvlType(l)); }
-  bool is2OutOf4Lvl(Level l) const { return is2OutOf4LT(getLvlType(l)); }
+  bool isNOutOfMLvl(Level l) const { return isNOutOfMLT(getLvlType(l)); }
   bool isOrderedLvl(Level l) const { return isOrderedLT(getLvlType(l)); }
   bool isUniqueLvl(Level l) const { return isUniqueLT(getLvlType(l)); }
   bool isWithPos(Level l) const { return isWithPosLT(getLvlType(l)); }

mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h

Lines changed: 1 addition & 1 deletion
@@ -510,7 +510,7 @@ class Merger {
     if (isLvlWithNonTrivialIdxExp(b)) {
       auto lt = getLoopDependentLevelType(b);
       return isCompressedLT(lt) || isSingletonLT(lt) ||
-             isLooseCompressedLT(lt) || is2OutOf4LT(lt);
+             isLooseCompressedLT(lt) || isNOutOfMLT(lt);
     }
     return false;
   }

mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h

Lines changed: 10 additions & 10 deletions
@@ -123,8 +123,8 @@ class SparseTensorStorageBase {
   /// Safely checks if the level uses singleton storage.
   bool isSingletonLvl(uint64_t l) const { return isSingletonLT(getLvlType(l)); }
 
-  /// Safely checks if the level uses 2 out of 4 storage.
-  bool is2OutOf4Lvl(uint64_t l) const { return is2OutOf4LT(getLvlType(l)); }
+  /// Safely checks if the level uses n out of m storage.
+  bool isNOutOfMLvl(uint64_t l) const { return isNOutOfMLT(getLvlType(l)); }
 
   /// Safely checks if the level is ordered.
   bool isOrderedLvl(uint64_t l) const { return isOrderedLT(getLvlType(l)); }
@@ -450,7 +450,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   void appendCrd(uint64_t lvl, uint64_t full, uint64_t crd) {
     if (!isDenseLvl(lvl)) {
       assert(isCompressedLvl(lvl) || isLooseCompressedLvl(lvl) ||
-             isSingletonLvl(lvl) || is2OutOf4Lvl(lvl));
+             isSingletonLvl(lvl) || isNOutOfMLvl(lvl));
       coordinates[lvl].push_back(detail::checkOverflowCast<C>(crd));
     } else { // Dense level.
       assert(crd >= full && "Coordinate was already filled");
@@ -473,7 +473,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
       return positions[l][parentSz];
     if (isLooseCompressedLvl(l))
       return positions[l][2 * parentSz - 1];
-    if (isSingletonLvl(l) || is2OutOf4Lvl(l))
+    if (isSingletonLvl(l) || isNOutOfMLvl(l))
       return parentSz; // new size same as the parent
     assert(isDenseLvl(l));
     return parentSz * getLvlSize(l);
@@ -527,7 +527,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
       uint64_t pos = coordinates[l].size();
       positions[l].insert(positions[l].end(), 2 * count,
                           detail::checkOverflowCast<P>(pos));
-    } else if (isSingletonLvl(l) || is2OutOf4Lvl(l)) {
+    } else if (isSingletonLvl(l) || isNOutOfMLvl(l)) {
      return; // Nothing to finalize.
    } else { // Dense dimension.
      assert(isDenseLvl(l));
@@ -624,7 +624,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
        lvlCursor[l] = static_cast<uint64_t>(coordinatesL[pos]);
        toCOO(pos, l + 1, dimCoords);
      }
-    } else if (isSingletonLvl(l) || is2OutOf4Lvl(l)) {
+    } else if (isSingletonLvl(l) || isNOutOfMLvl(l)) {
      assert(parentPos < coordinates[l].size());
      lvlCursor[l] = static_cast<uint64_t>(coordinates[l][parentPos]);
      toCOO(parentPos, l + 1, dimCoords);
@@ -721,8 +721,8 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
    } else if (isSingletonLvl(l)) {
      coordinates[l].reserve(sz);
      sz = 1;
-    } else if (is2OutOf4Lvl(l)) {
-      assert(l == lvlRank - 1 && "unexpected 2:4 usage");
+    } else if (isNOutOfMLvl(l)) {
+      assert(l == lvlRank - 1 && "unexpected n:m usage");
      sz = detail::checkedMul(sz, lvlSizes[l]) / 2;
      coordinates[l].reserve(sz);
      values.reserve(sz);
@@ -791,8 +791,8 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
      }
    } else if (isSingletonLvl(l)) {
      assert(0 && "general singleton not supported yet");
-    } else if (is2OutOf4Lvl(l)) {
-      assert(0 && "2Out4 not supported yet");
+    } else if (isNOutOfMLvl(l)) {
+      assert(0 && "n ouf of m not supported yet");
    } else {
      assert(isDenseLvl(l));
    }
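Note that the n:m sizing branch above (the @@ -721 hunk) still reserves half of the level's slots, matching the 2:4 ratio that is the only configuration exercised so far. A worked example of that reservation under stated assumptions: a 16x16 matrix with the usual 2:4 map (level sizes 16, 4, 4), and dense levels compounding the running size, which this hunk does not show.

#include <cassert>
#include <cstdint>

int main() {
  uint64_t sz = 1;
  sz *= 16;          // level 0: dense, size 16 (rows)
  sz *= 4;           // level 1: dense, size 4 (j floordiv 4)
  sz = sz * 4 / 2;   // level 2: n:m, size 4 (j mod 4), halved as in the code above
  assert(sz == 128); // coordinates/values reserved: half of the 16x16 elements
  return 0;
}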

mlir/lib/Bindings/Python/DialectSparseTensor.cpp

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ using namespace mlir::python::adaptors;
 static void populateDialectSparseTensorSubmodule(const py::module &m) {
   py::enum_<MlirBaseSparseTensorLevelType>(m, "LevelType", py::module_local())
       .value("dense", MLIR_SPARSE_TENSOR_LEVEL_DENSE)
-      .value("compressed24", MLIR_SPARSE_TENSOR_LEVEL_TWO_OUT_OF_FOUR)
+      .value("n_out_of_m", MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M)
       .value("compressed", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED)
       .value("compressed_nu", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU)
       .value("compressed_no", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO)

mlir/lib/CAPI/Dialect/SparseTensor.cpp

Lines changed: 30 additions & 19 deletions
@@ -20,25 +20,36 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor,
                                       mlir::sparse_tensor::SparseTensorDialect)
 
 // Ensure the C-API enums are int-castable to C++ equivalents.
-static_assert(static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_DENSE) ==
-                  static_cast<int>(LevelType::Dense) &&
-              static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) ==
-                  static_cast<int>(LevelType::Compressed) &&
-              static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU) ==
-                  static_cast<int>(LevelType::CompressedNu) &&
-              static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO) ==
-                  static_cast<int>(LevelType::CompressedNo) &&
-              static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO) ==
-                  static_cast<int>(LevelType::CompressedNuNo) &&
-              static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) ==
-                  static_cast<int>(LevelType::Singleton) &&
-              static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU) ==
-                  static_cast<int>(LevelType::SingletonNu) &&
-              static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO) ==
-                  static_cast<int>(LevelType::SingletonNo) &&
-              static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO) ==
-                  static_cast<int>(LevelType::SingletonNuNo),
-              "MlirSparseTensorLevelType (C-API) and LevelType (C++) mismatch");
+static_assert(
+    static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_DENSE) ==
+            static_cast<int>(LevelType::Dense) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) ==
+            static_cast<int>(LevelType::Compressed) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU) ==
+            static_cast<int>(LevelType::CompressedNu) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO) ==
+            static_cast<int>(LevelType::CompressedNo) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO) ==
+            static_cast<int>(LevelType::CompressedNuNo) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) ==
+            static_cast<int>(LevelType::Singleton) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU) ==
+            static_cast<int>(LevelType::SingletonNu) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO) ==
+            static_cast<int>(LevelType::SingletonNo) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO) ==
+            static_cast<int>(LevelType::SingletonNuNo) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED) ==
+            static_cast<int>(LevelType::LooseCompressed) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU) ==
+            static_cast<int>(LevelType::LooseCompressedNu) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO) ==
+            static_cast<int>(LevelType::LooseCompressedNo) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO) ==
+            static_cast<int>(LevelType::LooseCompressedNuNo) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M) ==
+            static_cast<int>(LevelType::NOutOfM),
+    "MlirSparseTensorLevelType (C-API) and LevelType (C++) mismatch");
 
 bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
   return isa<SparseTensorEncodingAttr>(unwrap(attr));

mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp

Lines changed: 43 additions & 11 deletions
@@ -29,12 +29,21 @@ using namespace mlir::sparse_tensor::ir_detail;
 // `LvlTypeParser` implementation.
 //===----------------------------------------------------------------------===//
 
-FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
+FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
   StringRef base;
   const auto loc = parser.getCurrentLocation();
   ERROR_IF(failed(parser.parseOptionalKeyword(&base)),
            "expected valid level format (e.g. dense, compressed or singleton)")
-  uint8_t properties = 0;
+  uint64_t properties = 0;
+  SmallVector<unsigned> structure;
+
+  if (base.compare("structured") == 0) {
+    ParseResult res = parser.parseCommaSeparatedList(
+        mlir::OpAsmParser::Delimiter::OptionalSquare,
+        [&]() -> ParseResult { return parseStructure(parser, &structure); },
+        " in block n out of m");
+    FAILURE_IF_FAILED(res)
+  }
 
   ParseResult res = parser.parseCommaSeparatedList(
       mlir::OpAsmParser::Delimiter::OptionalParen,
@@ -44,15 +53,20 @@ FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
 
   // Set the base bit for properties.
   if (base.compare("dense") == 0) {
-    properties |= static_cast<uint8_t>(LevelFormat::Dense);
+    properties |= static_cast<uint64_t>(LevelFormat::Dense);
   } else if (base.compare("compressed") == 0) {
-    properties |= static_cast<uint8_t>(LevelFormat::Compressed);
-  } else if (base.compare("block2_4") == 0) {
-    properties |= static_cast<uint8_t>(LevelFormat::TwoOutOfFour);
+    properties |= static_cast<uint64_t>(LevelFormat::Compressed);
+  } else if (base.compare("structured") == 0) {
+    if (structure.size() != 2) {
+      parser.emitError(loc, "expected exactly 2 structure sizes");
+      return failure();
+    }
+    properties |= static_cast<uint64_t>(LevelFormat::NOutOfM);
+    properties |= nToBits(structure[0]) | mToBits(structure[1]);
   } else if (base.compare("loose_compressed") == 0) {
-    properties |= static_cast<uint8_t>(LevelFormat::LooseCompressed);
+    properties |= static_cast<uint64_t>(LevelFormat::LooseCompressed);
   } else if (base.compare("singleton") == 0) {
-    properties |= static_cast<uint8_t>(LevelFormat::Singleton);
+    properties |= static_cast<uint64_t>(LevelFormat::Singleton);
   } else {
     parser.emitError(loc, "unknown level format: ") << base;
     return failure();
@@ -64,20 +78,38 @@ FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
 }
 
 ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
-                                         uint8_t *properties) const {
+                                         uint64_t *properties) const {
   StringRef strVal;
   auto loc = parser.getCurrentLocation();
   ERROR_IF(failed(parser.parseOptionalKeyword(&strVal)),
            "expected valid level property (e.g. nonordered, nonunique or high)")
   if (strVal.compare("nonunique") == 0) {
-    *properties |= static_cast<uint8_t>(LevelPropertyNondefault::Nonunique);
+    *properties |= static_cast<uint64_t>(LevelPropertyNondefault::Nonunique);
   } else if (strVal.compare("nonordered") == 0) {
-    *properties |= static_cast<uint8_t>(LevelPropertyNondefault::Nonordered);
+    *properties |= static_cast<uint64_t>(LevelPropertyNondefault::Nonordered);
   } else {
     parser.emitError(loc, "unknown level property: ") << strVal;
     return failure();
   }
   return success();
 }
 
+ParseResult
+LvlTypeParser::parseStructure(AsmParser &parser,
+                              SmallVector<unsigned> *structure) const {
+  int intVal;
+  auto loc = parser.getCurrentLocation();
+  OptionalParseResult intValParseResult = parser.parseOptionalInteger(intVal);
+  if (intValParseResult.has_value()) {
+    if (failed(*intValParseResult)) {
+      parser.emitError(loc, "failed to parse block size");
+      return failure();
+    }
+    structure->push_back(intVal);
+    return success();
+  }
+  parser.emitError(loc, "expected valid integer for block size");
+  return failure();
+}
+
 //===----------------------------------------------------------------------===//
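For `structured[2, 4]`, the structured branch above ORs the NOutOfM format bit together with the packed block sizes. A short sketch of that assembly (assumptions: nToBits/mToBits are the packing helpers the parser calls and are declared in Enums.h, whose large diff is not rendered on this page):

#include <cstdint>

#include "mlir/Dialect/SparseTensor/IR/Enums.h"

uint64_t buildStructured24() {
  using namespace mlir::sparse_tensor;
  uint64_t properties = 0;
  properties |= static_cast<uint64_t>(LevelFormat::NOutOfM); // base format bit
  properties |= nToBits(2) | mToBits(4);                     // n = 2, m = 4
  // A trailing property list such as (nonunique) would then OR in the
  // LevelPropertyNondefault bits via parseProperty, exactly as before,
  // only into a 64-bit word now.
  return properties;
}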

mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h

Lines changed: 4 additions & 2 deletions
@@ -18,10 +18,12 @@ namespace ir_detail {
 class LvlTypeParser {
 public:
   LvlTypeParser() = default;
-  FailureOr<uint8_t> parseLvlType(AsmParser &parser) const;
+  FailureOr<uint64_t> parseLvlType(AsmParser &parser) const;
 
 private:
-  ParseResult parseProperty(AsmParser &parser, uint8_t *properties) const;
+  ParseResult parseProperty(AsmParser &parser, uint64_t *properties) const;
+  ParseResult parseStructure(AsmParser &parser,
+                             SmallVector<unsigned> *structure) const;
 };
 
 } // namespace ir_detail

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 14 additions & 2 deletions
@@ -613,16 +613,28 @@ void SparseTensorEncodingAttr::printDimensions(
   }
 }
 
+std::string getNOutOfMString(LevelType lt) {
+  if (isNOutOfMLT(lt)) {
+    unsigned n = getN(lt);
+    unsigned m = getM(lt);
+    auto output = "[" + std::to_string(n) + ", " + std::to_string(m) + "]";
+    return output;
+  }
+  return "";
+}
+
 void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
                                            ArrayRef<LevelType> lvlTypes) const {
   for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
     map.getResult(i).print(printer.getStream());
-    printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
+    printer << " : " << toMLIRString(lvlTypes[i])
+            << getNOutOfMString(lvlTypes[i]) << ", ";
   }
   if (map.getNumResults() >= 1) {
     auto lastIndex = map.getNumResults() - 1;
     map.getResult(lastIndex).print(printer.getStream());
-    printer << " : " << toMLIRString(lvlTypes[lastIndex]);
+    printer << " : " << toMLIRString(lvlTypes[lastIndex])
+            << getNOutOfMString(lvlTypes[lastIndex]);
   }
 }

mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp

Lines changed: 1 addition & 1 deletion
@@ -451,7 +451,7 @@ static bool isAdmissibleBSR(SparseTensorType &aTp) {
 /// Test for 2:4 matrix with suitable metadata.
 static bool isAdmissible24(SparseTensorType &aTp) {
   return aTp.getDimRank() == 2 && aTp.getLvlRank() == 3 && aTp.isDenseLvl(0) &&
-         aTp.isDenseLvl(1) && aTp.is2OutOf4Lvl(2) && isAdmissibleMetaData(aTp);
+         aTp.isDenseLvl(1) && aTp.isNOutOfMLvl(2) && isAdmissibleMetaData(aTp);
 }
 
 /// Test for conversion into 2:4 matrix.

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

Lines changed: 3 additions & 3 deletions
@@ -130,7 +130,7 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
       createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
                      /*value=*/posZero, /*repeat=*/linear);
       return;
-    } else if (isSingletonLT(lt) || is2OutOf4LT(lt)) {
+    } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
       return; // nothing to do
     }
     // Keep compounding the size, but nothing needs to be initialized
@@ -409,7 +409,7 @@ static void genEndInsert(OpBuilder &builder, Location loc,
       }
     } else {
       assert(isDenseLT(lt) || isLooseCompressedLT(lt) || isSingletonLT(lt) ||
-             is2OutOf4LT(lt));
+             isNOutOfMLT(lt));
     }
   }
 }
@@ -488,7 +488,7 @@ class SparseInsertGenerator
       }
       parentPos =
           genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
-    } else if (isSingletonLT(lt) || is2OutOf4LT(lt)) {
+    } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
       // Create:
       //   coordinates[lvl].push_back(coords[lvl])
       //   positions[lvl] = positions[lvl-1]

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 1 addition & 1 deletion
@@ -891,7 +891,7 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
       assert(curr == env.merger().loop(b));
       Value clause;
       if (isCompressedLT(lt) || isSingletonLT(lt) ||
-          isLooseCompressedLT(lt) || is2OutOf4LT(lt)) {
+          isLooseCompressedLT(lt) || isNOutOfMLT(lt)) {
         assert(lvl.has_value());
         const Value crd = env.emitter().getCoord(tid, *lvl);
         const Value lvar = env.getLoopVar(curr);
