Skip to content

Commit 429919e

Browse files
authored
[mlir][sparse][pybind][CAPI] remove LevelType enum from CAPI, constru… (#81682)
…ct LevelType from LevelFormat and properties instead. **Rationale** We used to explicitly declare every possible combination between `LevelFormat` and `LevelProperties`, and it now becomes difficult to scale as more properties/level formats are going to be introduced.
1 parent 21630ef commit 429919e

File tree

9 files changed

+123
-151
lines changed

9 files changed

+123
-151
lines changed

mlir/include/mlir-c/Dialect/SparseTensor.h

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,23 +27,19 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor);
2727
/// file.
2828
typedef uint64_t MlirSparseTensorLevelType;
2929

30-
enum MlirBaseSparseTensorLevelType {
30+
enum MlirSparseTensorLevelFormat {
3131
MLIR_SPARSE_TENSOR_LEVEL_DENSE = 0x000000010000,
3232
MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 0x000000020000,
33-
MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 0x000000020001,
34-
MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 0x000000020002,
35-
MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 0x000000020003,
3633
MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 0x000000040000,
37-
MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 0x000000040001,
38-
MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 0x000000040002,
39-
MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO = 0x000000040003,
4034
MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 0x000000080000,
41-
MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 0x000000080001,
42-
MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 0x000000080002,
43-
MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 0x000000080003,
4435
MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 0x000000100000,
4536
};
4637

38+
enum MlirSparseTensorLevelPropertyNondefault {
39+
MLIR_SPARSE_PROPERTY_NON_UNIQUE = 0x0001,
40+
MLIR_SPARSE_PROPERTY_NON_ORDERED = 0x0002,
41+
};
42+
4743
//===----------------------------------------------------------------------===//
4844
// SparseTensorEncodingAttr
4945
//===----------------------------------------------------------------------===//
@@ -66,6 +62,10 @@ mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr);
6662
MLIR_CAPI_EXPORTED MlirSparseTensorLevelType
6763
mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl);
6864

65+
/// Returns a specified level-format of the `sparse_tensor.encoding` attribute.
66+
MLIR_CAPI_EXPORTED enum MlirSparseTensorLevelFormat
67+
mlirSparseTensorEncodingAttrGetLvlFmt(MlirAttribute attr, intptr_t lvl);
68+
6969
/// Returns the dimension-to-level mapping of the `sparse_tensor.encoding`
7070
/// attribute.
7171
MLIR_CAPI_EXPORTED MlirAffineMap
@@ -92,7 +92,9 @@ mlirSparseTensorEncodingAttrGetStructuredM(MlirSparseTensorLevelType lvlType);
9292

9393
MLIR_CAPI_EXPORTED MlirSparseTensorLevelType
9494
mlirSparseTensorEncodingAttrBuildLvlType(
95-
enum MlirBaseSparseTensorLevelType lvlType, unsigned n, unsigned m);
95+
enum MlirSparseTensorLevelFormat lvlFmt,
96+
const enum MlirSparseTensorLevelPropertyNondefault *properties,
97+
unsigned propSize, unsigned n, unsigned m);
9698

9799
#ifdef __cplusplus
98100
}

mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h

Lines changed: 21 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include <cinttypes>
3636
#include <complex>
3737
#include <optional>
38+
#include <vector>
3839

3940
namespace mlir {
4041
namespace sparse_tensor {
@@ -343,17 +344,31 @@ constexpr std::optional<LevelFormat> getLevelFormat(LevelType lt) {
343344
/// Convert a LevelFormat to its corresponding LevelType with the given
344345
/// properties. Returns std::nullopt when the properties are not applicable
345346
/// for the input level format.
346-
constexpr std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
347-
bool unique, uint64_t n = 0,
348-
uint64_t m = 0) {
347+
inline std::optional<LevelType>
348+
buildLevelType(LevelFormat lf,
349+
const std::vector<LevelPropertyNondefault> &properties,
350+
uint64_t n = 0, uint64_t m = 0) {
349351
uint64_t newN = n << 32;
350352
uint64_t newM = m << 40;
351-
auto lt =
352-
static_cast<LevelType>(static_cast<uint64_t>(lf) | (ordered ? 0 : 2) |
353-
(unique ? 0 : 1) | newN | newM);
353+
uint64_t ltInt = static_cast<uint64_t>(lf) | newN | newM;
354+
for (auto p : properties) {
355+
ltInt |= static_cast<uint64_t>(p);
356+
}
357+
auto lt = static_cast<LevelType>(ltInt);
354358
return isValidLT(lt) ? std::optional(lt) : std::nullopt;
355359
}
356360

361+
inline std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
362+
bool unique, uint64_t n = 0,
363+
uint64_t m = 0) {
364+
std::vector<LevelPropertyNondefault> properties;
365+
if (!ordered)
366+
properties.push_back(LevelPropertyNondefault::Nonordered);
367+
if (!unique)
368+
properties.push_back(LevelPropertyNondefault::Nonunique);
369+
return buildLevelType(lf, properties, n, m);
370+
}
371+
357372
//
358373
// Ensure the above methods work as intended.
359374
//
@@ -380,57 +395,6 @@ static_assert(
380395
*getLevelFormat(LevelType::NOutOfM) == LevelFormat::NOutOfM),
381396
"getLevelFormat conversion is broken");
382397

383-
static_assert(
384-
(buildLevelType(LevelFormat::Dense, false, true) == std::nullopt &&
385-
buildLevelType(LevelFormat::Dense, true, false) == std::nullopt &&
386-
buildLevelType(LevelFormat::Dense, false, false) == std::nullopt &&
387-
*buildLevelType(LevelFormat::Dense, true, true) == LevelType::Dense &&
388-
*buildLevelType(LevelFormat::Compressed, true, true) ==
389-
LevelType::Compressed &&
390-
*buildLevelType(LevelFormat::Compressed, true, false) ==
391-
LevelType::CompressedNu &&
392-
*buildLevelType(LevelFormat::Compressed, false, true) ==
393-
LevelType::CompressedNo &&
394-
*buildLevelType(LevelFormat::Compressed, false, false) ==
395-
LevelType::CompressedNuNo &&
396-
*buildLevelType(LevelFormat::Singleton, true, true) ==
397-
LevelType::Singleton &&
398-
*buildLevelType(LevelFormat::Singleton, true, false) ==
399-
LevelType::SingletonNu &&
400-
*buildLevelType(LevelFormat::Singleton, false, true) ==
401-
LevelType::SingletonNo &&
402-
*buildLevelType(LevelFormat::Singleton, false, false) ==
403-
LevelType::SingletonNuNo &&
404-
*buildLevelType(LevelFormat::LooseCompressed, true, true) ==
405-
LevelType::LooseCompressed &&
406-
*buildLevelType(LevelFormat::LooseCompressed, true, false) ==
407-
LevelType::LooseCompressedNu &&
408-
*buildLevelType(LevelFormat::LooseCompressed, false, true) ==
409-
LevelType::LooseCompressedNo &&
410-
*buildLevelType(LevelFormat::LooseCompressed, false, false) ==
411-
LevelType::LooseCompressedNuNo &&
412-
buildLevelType(LevelFormat::NOutOfM, false, true) == std::nullopt &&
413-
buildLevelType(LevelFormat::NOutOfM, true, false) == std::nullopt &&
414-
buildLevelType(LevelFormat::NOutOfM, false, false) == std::nullopt &&
415-
*buildLevelType(LevelFormat::NOutOfM, true, true) == LevelType::NOutOfM),
416-
"buildLevelType conversion is broken");
417-
418-
static_assert(
419-
(getN(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4)) == 2 &&
420-
getM(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4)) == 4 &&
421-
getN(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10)) == 8 &&
422-
getM(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10)) == 10),
423-
"getN/M conversion is broken");
424-
425-
static_assert(
426-
(isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4),
427-
2, 4) &&
428-
isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10),
429-
8, 10) &&
430-
!isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 3, 4),
431-
2, 4)),
432-
"isValidNOutOfMLT definition is broken");
433-
434398
static_assert(
435399
(isValidLT(LevelType::Undef) && isValidLT(LevelType::Dense) &&
436400
isValidLT(LevelType::Compressed) && isValidLT(LevelType::CompressedNu) &&

mlir/lib/Bindings/Python/DialectSparseTensor.cpp

Lines changed: 21 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -23,24 +23,17 @@ using namespace mlir;
2323
using namespace mlir::python::adaptors;
2424

2525
static void populateDialectSparseTensorSubmodule(const py::module &m) {
26-
py::enum_<MlirBaseSparseTensorLevelType>(m, "LevelType", py::module_local())
26+
py::enum_<MlirSparseTensorLevelFormat>(m, "LevelFormat", py::module_local())
2727
.value("dense", MLIR_SPARSE_TENSOR_LEVEL_DENSE)
2828
.value("n_out_of_m", MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M)
2929
.value("compressed", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED)
30-
.value("compressed_nu", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU)
31-
.value("compressed_no", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO)
32-
.value("compressed_nu_no", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO)
3330
.value("singleton", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON)
34-
.value("singleton_nu", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU)
35-
.value("singleton_no", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO)
36-
.value("singleton_nu_no", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO)
37-
.value("loose_compressed", MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED)
38-
.value("loose_compressed_nu",
39-
MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU)
40-
.value("loose_compressed_no",
41-
MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO)
42-
.value("loose_compressed_nu_no",
43-
MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO);
31+
.value("loose_compressed", MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED);
32+
33+
py::enum_<MlirSparseTensorLevelPropertyNondefault>(m, "LevelProperty",
34+
py::module_local())
35+
.value("non_ordered", MLIR_SPARSE_PROPERTY_NON_ORDERED)
36+
.value("non_unique", MLIR_SPARSE_PROPERTY_NON_UNIQUE);
4437

4538
mlir_attribute_subclass(m, "EncodingAttr",
4639
mlirAttributeIsASparseTensorEncodingAttr)
@@ -62,12 +55,17 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
6255
"Gets a sparse_tensor.encoding from parameters.")
6356
.def_classmethod(
6457
"build_level_type",
65-
[](py::object cls, MlirBaseSparseTensorLevelType lvlType, unsigned n,
66-
unsigned m) {
67-
return mlirSparseTensorEncodingAttrBuildLvlType(lvlType, n, m);
58+
[](py::object cls, MlirSparseTensorLevelFormat lvlFmt,
59+
const std::vector<MlirSparseTensorLevelPropertyNondefault>
60+
&properties,
61+
unsigned n, unsigned m) {
62+
return mlirSparseTensorEncodingAttrBuildLvlType(
63+
lvlFmt, properties.data(), properties.size(), n, m);
6864
},
69-
py::arg("cls"), py::arg("lvl_type"), py::arg("n") = 0,
70-
py::arg("m") = 0,
65+
py::arg("cls"), py::arg("lvl_fmt"),
66+
py::arg("properties") =
67+
std::vector<MlirSparseTensorLevelPropertyNondefault>(),
68+
py::arg("n") = 0, py::arg("m") = 0,
7169
"Builds a sparse_tensor.encoding.level_type from parameters.")
7270
.def_property_readonly(
7371
"lvl_types",
@@ -113,17 +111,12 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
113111
return mlirSparseTensorEncodingAttrGetStructuredM(
114112
mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1));
115113
})
116-
.def_property_readonly("lvl_types_enum", [](MlirAttribute self) {
114+
.def_property_readonly("lvl_formats_enum", [](MlirAttribute self) {
117115
const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
118-
std::vector<MlirBaseSparseTensorLevelType> ret;
116+
std::vector<MlirSparseTensorLevelFormat> ret;
119117
ret.reserve(lvlRank);
120-
for (int l = 0; l < lvlRank; l++) {
121-
// Convert level type to 32 bits to ignore n and m for n_out_of_m
122-
// format.
123-
ret.push_back(
124-
static_cast<MlirBaseSparseTensorLevelType>(static_cast<uint32_t>(
125-
mlirSparseTensorEncodingAttrGetLvlType(self, l))));
126-
}
118+
for (int l = 0; l < lvlRank; l++)
119+
ret.push_back(mlirSparseTensorEncodingAttrGetLvlFmt(self, l));
127120
return ret;
128121
});
129122
}

mlir/lib/CAPI/Dialect/SparseTensor.cpp

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -22,34 +22,23 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor,
2222
// Ensure the C-API enums are int-castable to C++ equivalents.
2323
static_assert(
2424
static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_DENSE) ==
25-
static_cast<int>(LevelType::Dense) &&
25+
static_cast<int>(LevelFormat::Dense) &&
2626
static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) ==
27-
static_cast<int>(LevelType::Compressed) &&
28-
static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU) ==
29-
static_cast<int>(LevelType::CompressedNu) &&
30-
static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO) ==
31-
static_cast<int>(LevelType::CompressedNo) &&
32-
static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO) ==
33-
static_cast<int>(LevelType::CompressedNuNo) &&
27+
static_cast<int>(LevelFormat::Compressed) &&
3428
static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) ==
35-
static_cast<int>(LevelType::Singleton) &&
36-
static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU) ==
37-
static_cast<int>(LevelType::SingletonNu) &&
38-
static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO) ==
39-
static_cast<int>(LevelType::SingletonNo) &&
40-
static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO) ==
41-
static_cast<int>(LevelType::SingletonNuNo) &&
29+
static_cast<int>(LevelFormat::Singleton) &&
4230
static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED) ==
43-
static_cast<int>(LevelType::LooseCompressed) &&
44-
static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU) ==
45-
static_cast<int>(LevelType::LooseCompressedNu) &&
46-
static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO) ==
47-
static_cast<int>(LevelType::LooseCompressedNo) &&
48-
static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO) ==
49-
static_cast<int>(LevelType::LooseCompressedNuNo) &&
31+
static_cast<int>(LevelFormat::LooseCompressed) &&
5032
static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M) ==
51-
static_cast<int>(LevelType::NOutOfM),
52-
"MlirSparseTensorLevelType (C-API) and LevelType (C++) mismatch");
33+
static_cast<int>(LevelFormat::NOutOfM),
34+
"MlirSparseTensorLevelFormat (C-API) and LevelFormat (C++) mismatch");
35+
36+
static_assert(static_cast<int>(MLIR_SPARSE_PROPERTY_NON_ORDERED) ==
37+
static_cast<int>(LevelPropertyNondefault::Nonordered) &&
38+
static_cast<int>(MLIR_SPARSE_PROPERTY_NON_UNIQUE) ==
39+
static_cast<int>(LevelPropertyNondefault::Nonunique),
40+
"MlirSparseTensorLevelProperty (C-API) and "
41+
"LevelPropertyNondefault (C++) mismatch");
5342

5443
bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
5544
return isa<SparseTensorEncodingAttr>(unwrap(attr));
@@ -87,6 +76,13 @@ mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl) {
8776
cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlType(lvl));
8877
}
8978

79+
enum MlirSparseTensorLevelFormat
80+
mlirSparseTensorEncodingAttrGetLvlFmt(MlirAttribute attr, intptr_t lvl) {
81+
LevelType lt =
82+
static_cast<LevelType>(mlirSparseTensorEncodingAttrGetLvlType(attr, lvl));
83+
return static_cast<MlirSparseTensorLevelFormat>(*getLevelFormat(lt));
84+
}
85+
9086
int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) {
9187
return cast<SparseTensorEncodingAttr>(unwrap(attr)).getPosWidth();
9288
}
@@ -95,12 +91,17 @@ int mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr) {
9591
return cast<SparseTensorEncodingAttr>(unwrap(attr)).getCrdWidth();
9692
}
9793

98-
MlirSparseTensorLevelType
99-
mlirSparseTensorEncodingAttrBuildLvlType(MlirBaseSparseTensorLevelType lvlType,
100-
unsigned n, unsigned m) {
101-
LevelType lt = static_cast<LevelType>(lvlType);
102-
return static_cast<MlirSparseTensorLevelType>(*buildLevelType(
103-
*getLevelFormat(lt), isOrderedLT(lt), isUniqueLT(lt), n, m));
94+
MlirSparseTensorLevelType mlirSparseTensorEncodingAttrBuildLvlType(
95+
enum MlirSparseTensorLevelFormat lvlFmt,
96+
const enum MlirSparseTensorLevelPropertyNondefault *properties,
97+
unsigned size, unsigned n, unsigned m) {
98+
99+
std::vector<LevelPropertyNondefault> props;
100+
for (unsigned i = 0; i < size; i++)
101+
props.push_back(static_cast<LevelPropertyNondefault>(properties[i]));
102+
103+
return static_cast<MlirSparseTensorLevelType>(
104+
*buildLevelType(static_cast<LevelFormat>(lvlFmt), props, n, m));
104105
}
105106

106107
unsigned

mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -139,12 +139,15 @@ def main():
139139
# search the full state space to reduce runtime of the test. It is
140140
# straightforward to adapt the code below to explore more combinations.
141141
# For these simple orderings, dim2lvl and lvl2dim are the same.
142+
builder = st.EncodingAttr.build_level_type
143+
fmt = st.LevelFormat
144+
prop = st.LevelProperty
142145
levels = [
143-
[st.LevelType.compressed_nu, st.LevelType.singleton],
144-
[st.LevelType.dense, st.LevelType.dense],
145-
[st.LevelType.dense, st.LevelType.compressed],
146-
[st.LevelType.compressed, st.LevelType.dense],
147-
[st.LevelType.compressed, st.LevelType.compressed],
146+
[builder(fmt.compressed, [prop.non_unique]), builder(fmt.singleton)],
147+
[builder(fmt.dense), builder(fmt.dense)],
148+
[builder(fmt.dense), builder(fmt.compressed)],
149+
[builder(fmt.compressed), builder(fmt.dense)],
150+
[builder(fmt.compressed), builder(fmt.compressed)],
148151
]
149152
orderings = [
150153
ir.AffineMap.get_permutation([0, 1]),

mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,12 +125,15 @@ def main():
125125
vl = 1
126126
e = False
127127
opt = f"parallelization-strategy=none"
128+
builder = st.EncodingAttr.build_level_type
129+
fmt = st.LevelFormat
130+
prop = st.LevelProperty
128131
levels = [
129-
[st.LevelType.compressed_nu, st.LevelType.singleton],
130-
[st.LevelType.dense, st.LevelType.dense],
131-
[st.LevelType.dense, st.LevelType.compressed],
132-
[st.LevelType.compressed, st.LevelType.dense],
133-
[st.LevelType.compressed, st.LevelType.compressed],
132+
[builder(fmt.compressed, [prop.non_unique]), builder(fmt.singleton)],
133+
[builder(fmt.dense), builder(fmt.dense)],
134+
[builder(fmt.dense), builder(fmt.compressed)],
135+
[builder(fmt.compressed), builder(fmt.dense)],
136+
[builder(fmt.compressed), builder(fmt.compressed)],
134137
]
135138
orderings = [
136139
ir.AffineMap.get_permutation([0, 1]),

mlir/test/Integration/Dialect/SparseTensor/python/test_output.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -124,11 +124,14 @@ def main():
124124
# Loop over various sparse types (COO, CSR, DCSR, CSC, DCSC) with
125125
# regular and loose compression and various metadata bitwidths.
126126
# For these simple orderings, dim2lvl and lvl2dim are the same.
127+
builder = st.EncodingAttr.build_level_type
128+
fmt = st.LevelFormat
129+
prop = st.LevelProperty
127130
levels = [
128-
[st.LevelType.compressed_nu, st.LevelType.singleton],
129-
[st.LevelType.dense, st.LevelType.compressed],
130-
[st.LevelType.dense, st.LevelType.loose_compressed],
131-
[st.LevelType.compressed, st.LevelType.compressed],
131+
[builder(fmt.compressed, [prop.non_unique]), builder(fmt.singleton)],
132+
[builder(fmt.dense), builder(fmt.compressed)],
133+
[builder(fmt.dense), builder(fmt.loose_compressed)],
134+
[builder(fmt.compressed), builder(fmt.compressed)],
132135
]
133136
orderings = [
134137
(ir.AffineMap.get_permutation([0, 1]), 0),
@@ -149,10 +152,10 @@ def main():
149152

150153
# Now do the same for BSR.
151154
level = [
152-
st.LevelType.dense,
153-
st.LevelType.compressed,
154-
st.LevelType.dense,
155-
st.LevelType.dense,
155+
builder(fmt.dense),
156+
builder(fmt.compressed),
157+
builder(fmt.dense),
158+
builder(fmt.dense),
156159
]
157160
d0 = ir.AffineDimExpr.get(0)
158161
d1 = ir.AffineDimExpr.get(1)

0 commit comments

Comments
 (0)