Skip to content

Commit 22216d0

Browse files
more edits for n:m
1 parent 01bbe25 commit 22216d0

File tree

5 files changed

+26
-8
lines changed

5 files changed

+26
-8
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
145145
- **compressed** : only nonzeros along this level are stored
146146
- **loose_compressed** : as compressed, but allows for free space between regions
147147
- **singleton** : a variant of the compressed format, where coordinates have no siblings
148-
- **block[2, 4]** : the compression uses a 2:4 encoding per 1x4 block
148+
- **block[n, m]** : the compression uses a n:m encoding per 1xm block
149149

150150
For a compressed level, each position interval is represented in a compact
151151
way with a lowerbound `pos(i)` and an upperbound `pos(i+1) - 1`, which implies

mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ class SparseTensorStorageBase {
123123
/// Safely checks if the level uses singleton storage.
124124
bool isSingletonLvl(uint64_t l) const { return isSingletonLT(getLvlType(l)); }
125125

126-
/// Safely checks if the level uses 2 out of 4 storage.
126+
/// Safely checks if the level uses n out of m storage.
127127
bool isNOutOfMLvl(uint64_t l) const { return isNOutOfMLT(getLvlType(l)); }
128128

129129
/// Safely checks if the level is ordered.
@@ -792,7 +792,7 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
792792
} else if (isSingletonLvl(l)) {
793793
assert(0 && "general singleton not supported yet");
794794
} else if (isNOutOfMLvl(l)) {
795-
assert(0 && "2Out4 not supported yet");
795+
assert(0 && "n ouf of m not supported yet");
796796
} else {
797797
assert(isDenseLvl(l));
798798
}

mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
4040
if (base.compare("block") == 0) {
4141
ParseResult res = parser.parseCommaSeparatedList(
4242
mlir::OpAsmParser::Delimiter::OptionalSquare,
43-
[&]() -> ParseResult { return parseBlockSize(parser, &blockSizes); },
43+
[&]() -> ParseResult { return parseBlockSizes(parser, &blockSizes); },
4444
" in block n out of m");
4545
FAILURE_IF_FAILED(res)
4646
}
@@ -95,8 +95,8 @@ ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
9595
}
9696

9797
ParseResult
98-
LvlTypeParser::parseBlockSize(AsmParser &parser,
99-
SmallVector<unsigned> *blockSizes) const {
98+
LvlTypeParser::parseBlockSizes(AsmParser &parser,
99+
SmallVector<unsigned> *blockSizes) const {
100100
int intVal;
101101
auto loc = parser.getCurrentLocation();
102102
OptionalParseResult intValParseResult = parser.parseOptionalInteger(intVal);

mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ class LvlTypeParser {
2222

2323
private:
2424
ParseResult parseProperty(AsmParser &parser, uint64_t *properties) const;
25-
ParseResult parseBlockSize(AsmParser &parser,
26-
SmallVector<unsigned> *blockSizes) const;
25+
ParseResult parseBlockSizes(AsmParser &parser,
26+
SmallVector<unsigned> *blockSizes) const;
2727
};
2828

2929
} // namespace ir_detail

mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,3 +254,21 @@ func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
254254
func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
255255
return
256256
}
257+
258+
// -----
259+
260+
#NOutOfM = #sparse_tensor.encoding<{
261+
map = ( i, j, k ) ->
262+
( i : dense,
263+
k floordiv 8 : dense,
264+
j : dense,
265+
k mod 8 : block[5, 8]
266+
)
267+
}>
268+
269+
// CHECK-DAG: #[[$NOutOfM:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d2 floordiv 8 : dense, d1 : dense, d2 mod 8 : block[5, 8]) }>
270+
// CHECK-LABEL: func private @NOutOfM(
271+
// CHECK-SAME: tensor<?x?x?xf64, #[[$NOutOfM]]>
272+
func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
273+
return
274+
}

0 commit comments

Comments
 (0)