Skip to content

Commit 2a6b521

Browse files
[mlir][sparse] Add more tests and verification for n:m (#81186)
1. Add python test for n out of m. 2. Add more methods for python binding. 3. Add verification for n:m and invalid encoding tests. 4. Add e2e test for n:m. Previous PRs for n:m: #80501, #79935.
1 parent 5948d4d commit 2a6b521

File tree

9 files changed

+331
-16
lines changed

9 files changed

+331
-16
lines changed

mlir/include/mlir-c/Dialect/SparseTensor.h

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -84,6 +84,16 @@ mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr);
8484
MLIR_CAPI_EXPORTED int
8585
mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr);
8686

87+
MLIR_CAPI_EXPORTED unsigned
88+
mlirSparseTensorEncodingAttrGetStructuredN(MlirSparseTensorLevelType lvlType);
89+
90+
MLIR_CAPI_EXPORTED unsigned
91+
mlirSparseTensorEncodingAttrGetStructuredM(MlirSparseTensorLevelType lvlType);
92+
93+
MLIR_CAPI_EXPORTED MlirSparseTensorLevelType
94+
mlirSparseTensorEncodingAttrBuildLvlType(
95+
enum MlirBaseSparseTensorLevelType lvlType, unsigned n, unsigned m);
96+
8797
#ifdef __cplusplus
8898
}
8999
#endif

mlir/lib/Bindings/Python/DialectSparseTensor.cpp

Lines changed: 37 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -60,6 +60,15 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
6060
py::arg("lvl_to_dim"), py::arg("pos_width"), py::arg("crd_width"),
6161
py::arg("context") = py::none(),
6262
"Gets a sparse_tensor.encoding from parameters.")
63+
.def_classmethod(
64+
"build_level_type",
65+
[](py::object cls, MlirBaseSparseTensorLevelType lvlType, unsigned n,
66+
unsigned m) {
67+
return mlirSparseTensorEncodingAttrBuildLvlType(lvlType, n, m);
68+
},
69+
py::arg("cls"), py::arg("lvl_type"), py::arg("n") = 0,
70+
py::arg("m") = 0,
71+
"Builds a sparse_tensor.encoding.level_type from parameters.")
6372
.def_property_readonly(
6473
"lvl_types",
6574
[](MlirAttribute self) {
@@ -89,7 +98,34 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
8998
.def_property_readonly("pos_width",
9099
mlirSparseTensorEncodingAttrGetPosWidth)
91100
.def_property_readonly("crd_width",
92-
mlirSparseTensorEncodingAttrGetCrdWidth);
101+
mlirSparseTensorEncodingAttrGetCrdWidth)
102+
.def_property_readonly(
103+
"structured_n",
104+
[](MlirAttribute self) -> unsigned {
105+
const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
106+
return mlirSparseTensorEncodingAttrGetStructuredN(
107+
mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1));
108+
})
109+
.def_property_readonly(
110+
"structured_m",
111+
[](MlirAttribute self) -> unsigned {
112+
const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
113+
return mlirSparseTensorEncodingAttrGetStructuredM(
114+
mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1));
115+
})
116+
.def_property_readonly("lvl_types_enum", [](MlirAttribute self) {
117+
const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
118+
std::vector<MlirBaseSparseTensorLevelType> ret;
119+
ret.reserve(lvlRank);
120+
for (int l = 0; l < lvlRank; l++) {
121+
// Convert level type to 32 bits to ignore n and m for n_out_of_m
122+
// format.
123+
ret.push_back(
124+
static_cast<MlirBaseSparseTensorLevelType>(static_cast<uint32_t>(
125+
mlirSparseTensorEncodingAttrGetLvlType(self, l))));
126+
}
127+
return ret;
128+
});
93129
}
94130

95131
PYBIND11_MODULE(_mlirDialectsSparseTensor, m) {

mlir/lib/CAPI/Dialect/SparseTensor.cpp

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -94,3 +94,21 @@ int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) {
9494
int mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr) {
9595
return cast<SparseTensorEncodingAttr>(unwrap(attr)).getCrdWidth();
9696
}
97+
98+
MlirSparseTensorLevelType
99+
mlirSparseTensorEncodingAttrBuildLvlType(MlirBaseSparseTensorLevelType lvlType,
100+
unsigned n, unsigned m) {
101+
LevelType lt = static_cast<LevelType>(lvlType);
102+
return static_cast<MlirSparseTensorLevelType>(*buildLevelType(
103+
*getLevelFormat(lt), isOrderedLT(lt), isUniqueLT(lt), n, m));
104+
}
105+
106+
unsigned
107+
mlirSparseTensorEncodingAttrGetStructuredN(MlirSparseTensorLevelType lvlType) {
108+
return getN(static_cast<LevelType>(lvlType));
109+
}
110+
111+
unsigned
112+
mlirSparseTensorEncodingAttrGetStructuredM(MlirSparseTensorLevelType lvlType) {
113+
return getM(static_cast<LevelType>(lvlType));
114+
}

mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp

Lines changed: 21 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -35,14 +35,22 @@ FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
3535
ERROR_IF(failed(parser.parseOptionalKeyword(&base)),
3636
"expected valid level format (e.g. dense, compressed or singleton)")
3737
uint64_t properties = 0;
38-
SmallVector<unsigned> structure;
38+
SmallVector<unsigned> structured;
3939

4040
if (base.compare("structured") == 0) {
4141
ParseResult res = parser.parseCommaSeparatedList(
4242
mlir::OpAsmParser::Delimiter::OptionalSquare,
43-
[&]() -> ParseResult { return parseStructure(parser, &structure); },
44-
" in block n out of m");
43+
[&]() -> ParseResult { return parseStructured(parser, &structured); },
44+
" in structured n out of m");
4545
FAILURE_IF_FAILED(res)
46+
if (structured.size() != 2) {
47+
parser.emitError(loc, "expected exactly 2 structured sizes");
48+
return failure();
49+
}
50+
if (structured[0] > structured[1]) {
51+
parser.emitError(loc, "expected n <= m in n_out_of_m");
52+
return failure();
53+
}
4654
}
4755

4856
ParseResult res = parser.parseCommaSeparatedList(
@@ -57,12 +65,8 @@ FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
5765
} else if (base.compare("compressed") == 0) {
5866
properties |= static_cast<uint64_t>(LevelFormat::Compressed);
5967
} else if (base.compare("structured") == 0) {
60-
if (structure.size() != 2) {
61-
parser.emitError(loc, "expected exactly 2 structure sizes");
62-
return failure();
63-
}
6468
properties |= static_cast<uint64_t>(LevelFormat::NOutOfM);
65-
properties |= nToBits(structure[0]) | mToBits(structure[1]);
69+
properties |= nToBits(structured[0]) | mToBits(structured[1]);
6670
} else if (base.compare("loose_compressed") == 0) {
6771
properties |= static_cast<uint64_t>(LevelFormat::LooseCompressed);
6872
} else if (base.compare("singleton") == 0) {
@@ -95,20 +99,24 @@ ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
9599
}
96100

97101
ParseResult
98-
LvlTypeParser::parseStructure(AsmParser &parser,
99-
SmallVector<unsigned> *structure) const {
102+
LvlTypeParser::parseStructured(AsmParser &parser,
103+
SmallVector<unsigned> *structured) const {
100104
int intVal;
101105
auto loc = parser.getCurrentLocation();
102106
OptionalParseResult intValParseResult = parser.parseOptionalInteger(intVal);
103107
if (intValParseResult.has_value()) {
104108
if (failed(*intValParseResult)) {
105-
parser.emitError(loc, "failed to parse block size");
109+
parser.emitError(loc, "failed to parse structured size");
110+
return failure();
111+
}
112+
if (intVal < 0) {
113+
parser.emitError(loc, "expected structured size to be >= 0");
106114
return failure();
107115
}
108-
structure->push_back(intVal);
116+
structured->push_back(intVal);
109117
return success();
110118
}
111-
parser.emitError(loc, "expected valid integer for block size");
119+
parser.emitError(loc, "expected valid integer for structured size");
112120
return failure();
113121
}
114122

mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -22,8 +22,8 @@ class LvlTypeParser {
2222

2323
private:
2424
ParseResult parseProperty(AsmParser &parser, uint64_t *properties) const;
25-
ParseResult parseStructure(AsmParser &parser,
26-
SmallVector<unsigned> *structure) const;
25+
ParseResult parseStructured(AsmParser &parser,
26+
SmallVector<unsigned> *structured) const;
2727
};
2828

2929
} // namespace ir_detail

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -657,6 +657,37 @@ LogicalResult SparseTensorEncodingAttr::verify(
657657
return emitError() << "expected all singleton lvlTypes "
658658
"following a singleton level";
659659
}
660+
// TODO: audit formats that actually are supported by backend.
661+
if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isNOutOfMLT);
662+
it != std::end(lvlTypes)) {
663+
if (it != lvlTypes.end() - 1)
664+
return emitError() << "expected n_out_of_m to be the last level type";
665+
if (!std::all_of(lvlTypes.begin(), it,
666+
[](LevelType i) { return isDenseLT(i); }))
667+
return emitError() << "expected all dense lvlTypes "
668+
"before a n_out_of_m level";
669+
if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) {
670+
if (!isBlockSparsity(dimToLvl)) {
671+
return emitError()
672+
<< "expected 1xm block structure for n_out_of_m level";
673+
}
674+
auto sizes = getBlockSize(dimToLvl);
675+
unsigned coefficient = 0;
676+
for (const auto &elem : sizes) {
677+
if (elem != 0) {
678+
if (elem != coefficient && coefficient != 0) {
679+
return emitError() << "expected only one blocked level "
680+
"with the same coefficients";
681+
}
682+
coefficient = elem;
683+
}
684+
}
685+
if (coefficient != getM(*it)) {
686+
return emitError() << "expected coefficients of Affine expressions "
687+
"to be equal to m of n_out_of_m level";
688+
}
689+
}
690+
}
660691
// Before we can check that the level-rank is consistent/coherent
661692
// across all fields, we need to define it. The source-of-truth for
662693
// the `getLvlRank` method is the length of the level-types array,

mlir/test/Dialect/SparseTensor/invalid_encoding.mlir

Lines changed: 106 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -315,3 +315,109 @@ func.func private @BSR(%arg0: tensor<?x?xf64, #BSR>) {
315315
func.func private @BSR_explicit(%arg0: tensor<?x?xf64, #BSR_explicit>) {
316316
return
317317
}
318+
319+
// -----
320+
321+
// expected-error@+6 {{expected structured size to be >= 0}}
322+
#NOutOfM = #sparse_tensor.encoding<{
323+
map = ( i, j, k ) ->
324+
( i : dense,
325+
k floordiv 4 : dense,
326+
j : dense,
327+
k mod 4 : structured[-2, 4]
328+
)
329+
}>
330+
func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
331+
return
332+
}
333+
334+
// -----
335+
336+
// expected-error@+6 {{expected n <= m in n_out_of_m}}
337+
#NOutOfM = #sparse_tensor.encoding<{
338+
map = ( i, j, k ) ->
339+
( i : dense,
340+
k floordiv 4 : dense,
341+
j : dense,
342+
k mod 4 : structured[5, 4]
343+
)
344+
}>
345+
func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
346+
return
347+
}
348+
349+
// -----
350+
351+
// expected-error@+1 {{expected all dense lvlTypes before a n_out_of_m level}}
352+
#NOutOfM = #sparse_tensor.encoding<{
353+
map = ( i, j, k ) ->
354+
( i : dense,
355+
k floordiv 4 : compressed,
356+
j : dense,
357+
k mod 4 : structured[2, 4]
358+
)
359+
}>
360+
func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
361+
return
362+
}
363+
364+
// -----
365+
366+
// expected-error@+1 {{expected n_out_of_m to be the last level type}}
367+
#NOutOfM = #sparse_tensor.encoding<{
368+
map = ( i, j, k ) ->
369+
( i : dense,
370+
k floordiv 4 : structured[2, 4],
371+
j : dense,
372+
k mod 4 : compressed
373+
)
374+
}>
375+
func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
376+
return
377+
}
378+
379+
// -----
380+
381+
// expected-error@+1 {{expected 1xm block structure for n_out_of_m level}}
382+
#NOutOfM = #sparse_tensor.encoding<{
383+
map = ( i, j, k ) ->
384+
( i : dense,
385+
k floordiv 2 : dense,
386+
j : dense,
387+
k mod 4 : structured[2, 4]
388+
)
389+
}>
390+
func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
391+
return
392+
}
393+
394+
// -----
395+
396+
// expected-error@+1 {{expected coefficients of Affine expressions to be equal to m of n_out_of_m level}}
397+
#NOutOfM = #sparse_tensor.encoding<{
398+
map = ( i, j, k ) ->
399+
( i : dense,
400+
k floordiv 2 : dense,
401+
j : dense,
402+
k mod 2 : structured[2, 4]
403+
)
404+
}>
405+
func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
406+
return
407+
}
408+
409+
// -----
410+
411+
// expected-error@+1 {{expected only one blocked level with the same coefficients}}
412+
#NOutOfM = #sparse_tensor.encoding<{
413+
map = ( i, j, k ) ->
414+
( i floordiv 2 : dense,
415+
i mod 2 : dense,
416+
j : dense,
417+
k floordiv 4 : dense,
418+
k mod 4 : structured[2, 4]
419+
)
420+
}>
421+
func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
422+
return
423+
}

mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,13 @@
4545
crdWidth = 8
4646
}>
4747

48+
#NV_58 = #sparse_tensor.encoding<{
49+
map = ( i, j ) -> ( i : dense,
50+
j floordiv 8 : dense,
51+
j mod 8 : structured[5, 8]),
52+
crdWidth = 8
53+
}>
54+
4855
module {
4956

5057
func.func private @getTensorFilename(index) -> (!Filename)
@@ -65,6 +72,7 @@ module {
6572
%A1 = sparse_tensor.new %fileName : !Filename to tensor<?x?xf64, #CSR>
6673
%A2 = sparse_tensor.new %fileName : !Filename to tensor<?x?xf64, #CSR_hi>
6774
%A3 = sparse_tensor.new %fileName : !Filename to tensor<?x?xf64, #NV_24>
75+
%A4 = sparse_tensor.new %fileName : !Filename to tensor<?x?xf64, #NV_58>
6876

6977
//
7078
// CSR:
@@ -113,10 +121,24 @@ module {
113121
%vecv3 = vector.transfer_read %val3[%c0], %f0 : memref<?xf64>, vector<12xf64>
114122
vector.print %vecv3 : vector<12xf64>
115123

124+
//
125+
// NV_58
126+
//
127+
// CHECK-NEXT: ( 2, 3, 5, 7, 1, 2, 4, 7, 0, 2, 4, 5 )
128+
// CHECK-NEXT: ( 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 )
129+
//
130+
%crd4 = sparse_tensor.coordinates %A4 {level = 2 : index } : tensor<?x?xf64, #NV_58> to memref<?xi8>
131+
%vecc4 = vector.transfer_read %crd4[%c0], %u0 : memref<?xi8>, vector<12xi8>
132+
vector.print %vecc4 : vector<12xi8>
133+
%val4 = sparse_tensor.values %A4 : tensor<?x?xf64, #NV_58> to memref<?xf64>
134+
%vecv4 = vector.transfer_read %val4[%c0], %f0 : memref<?xf64>, vector<12xf64>
135+
vector.print %vecv4 : vector<12xf64>
136+
116137
// Release the resources.
117138
bufferization.dealloc_tensor %A1: tensor<?x?xf64, #CSR>
118139
bufferization.dealloc_tensor %A2: tensor<?x?xf64, #CSR_hi>
119140
bufferization.dealloc_tensor %A3: tensor<?x?xf64, #NV_24>
141+
bufferization.dealloc_tensor %A4: tensor<?x?xf64, #NV_58>
120142

121143
return
122144
}

0 commit comments

Comments
 (0)