Skip to content

Commit 898bf53

Browse files
[mlir][sparse] Surface syntax change in parsing
Example: `compressed(nonunique, nonordered)` or `compressed(nonordered, nonunique)` instead of `compressed_nu_no`.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D159366
1 parent 8347d7c commit 898bf53

File tree

6 files changed

+68
-70
lines changed

6 files changed

+68
-70
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,9 @@ enum class Action : uint32_t {
169169
///
170170
// TODO: We should generalize TwoOutOfFour to N out of M and use property to
171171
// encode the value of N and M.
172+
// TODO: Update DimLevelType to use lower 8 bits for storage formats and the
173+
// higher 4 bits to store level properties. Consider CompressedWithHi and
174+
// TwoOutOfFour as properties instead of formats.
172175
enum class DimLevelType : uint8_t {
173176
Undef = 0, // 0b00000_00
174177
Dense = 4, // 0b00001_00
@@ -197,6 +200,14 @@ enum class LevelFormat : uint8_t {
197200
TwoOutOfFour = 64, // 0b10000_00
198201
};
199202

203+
/// This enum defines all the nondefault properties for storage formats.
204+
enum class LevelNondefaultProperty : uint8_t {
205+
Nonunique = 1, // 0b00000_01
206+
Nonordered = 2, // 0b00000_10
207+
High = 32, // 0b01000_00
208+
Block2_4 = 64 // 0b10000_00
209+
};
210+
200211
/// Returns string representation of the given dimension level type.
201212
constexpr const char *toMLIRString(DimLevelType dlt) {
202213
switch (dlt) {

mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ void DimLvlMap::print(llvm::raw_ostream &os, bool wantElision) const {
391391
os << '{';
392392
llvm::interleaveComma(
393393
lvlSpecs, os, [&](LvlSpec const &spec) { os << spec.getBoundVar(); });
394-
os << '}';
394+
os << "} ";
395395
}
396396

397397
// Dimension specifiers.

mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ ParseResult DimLvlMapParser::parseLvlSpec(bool requireLvlVarBinding) {
354354
const auto type = lvlTypeParser.parseLvlType(parser);
355355
FAILURE_IF_FAILED(type)
356356

357-
lvlSpecs.emplace_back(var, expr, *type);
357+
lvlSpecs.emplace_back(var, expr, static_cast<DimLevelType>(*type));
358358
return success();
359359
}
360360

mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "LvlTypeParser.h"
10+
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
1011

1112
using namespace mlir;
1213
using namespace mlir::sparse_tensor;
@@ -46,34 +47,57 @@ using namespace mlir::sparse_tensor::ir_detail;
4647
// `LvlTypeParser` implementation.
4748
//===----------------------------------------------------------------------===//
4849

49-
std::optional<DimLevelType> LvlTypeParser::lookup(StringRef str) const {
50-
// NOTE: `StringMap::lookup` will return a default-constructed value if
51-
// the key isn't found; which for enums means zero, and therefore makes
52-
// it impossible to distinguish between actual zero-DimLevelType vs
53-
// not-found. Whereas `StringMap::at` asserts that the key is found,
54-
// which we don't want either.
55-
const auto it = map.find(str);
56-
return it == map.end() ? std::nullopt : std::make_optional(it->second);
57-
}
50+
FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
51+
StringRef base;
52+
FAILURE_IF_FAILED(parser.parseOptionalKeyword(&base));
53+
uint8_t properties = 0;
54+
const auto loc = parser.getCurrentLocation();
5855

59-
std::optional<DimLevelType> LvlTypeParser::lookup(StringAttr str) const {
60-
return str ? lookup(str.getValue()) : std::nullopt;
61-
}
56+
ParseResult res = parser.parseCommaSeparatedList(
57+
mlir::OpAsmParser::Delimiter::OptionalParen,
58+
[&]() -> ParseResult { return parseProperty(parser, &properties); },
59+
" in level property list");
60+
FAILURE_IF_FAILED(res)
6261

63-
FailureOr<DimLevelType> LvlTypeParser::parseLvlType(AsmParser &parser) const {
64-
DimLevelType out;
65-
FAILURE_IF_FAILED(parseLvlType(parser, out))
66-
return out;
62+
// Set the base bit for properties.
63+
if (base.compare("dense") == 0) {
64+
properties |= static_cast<uint8_t>(LevelFormat::Dense);
65+
} else if (base.compare("compressed") == 0) {
66+
// TODO: Remove this condition once DimLevelType enum is refactored. Current
67+
// enum treats High and TwoOutOfFour as formats instead of properties.
68+
if (!(properties & static_cast<uint8_t>(LevelNondefaultProperty::High) ||
69+
properties &
70+
static_cast<uint8_t>(LevelNondefaultProperty::Block2_4))) {
71+
properties |= static_cast<uint8_t>(LevelFormat::Compressed);
72+
}
73+
} else if (base.compare("singleton") == 0) {
74+
properties |= static_cast<uint8_t>(LevelFormat::Singleton);
75+
} else {
76+
parser.emitError(loc, "unknown level format");
77+
return failure();
78+
}
79+
80+
ERROR_IF(!isValidDLT(static_cast<DimLevelType>(properties)),
81+
"invalid level type");
82+
return properties;
6783
}
6884

69-
ParseResult LvlTypeParser::parseLvlType(AsmParser &parser,
70-
DimLevelType &out) const {
71-
const auto loc = parser.getCurrentLocation();
85+
ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
86+
uint8_t *properties) const {
7287
StringRef strVal;
7388
FAILURE_IF_FAILED(parser.parseOptionalKeyword(&strVal));
74-
const auto lvlType = lookup(strVal);
75-
ERROR_IF(!lvlType, "unknown level-type '" + strVal + "'")
76-
out = *lvlType;
89+
if (strVal.compare("nonunique") == 0) {
90+
*properties |= static_cast<uint8_t>(LevelNondefaultProperty::Nonunique);
91+
} else if (strVal.compare("nonordered") == 0) {
92+
*properties |= static_cast<uint8_t>(LevelNondefaultProperty::Nonordered);
93+
} else if (strVal.compare("high") == 0) {
94+
*properties |= static_cast<uint8_t>(LevelNondefaultProperty::High);
95+
} else if (strVal.compare("block2_4") == 0) {
96+
*properties |= static_cast<uint8_t>(LevelNondefaultProperty::Block2_4);
97+
} else {
98+
parser.emitError(parser.getCurrentLocation(), "unknown level property");
99+
return failure();
100+
}
77101
return success();
78102
}
79103

mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h

Lines changed: 4 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -9,56 +9,19 @@
99
#ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_LVLTYPEPARSER_H
1010
#define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_LVLTYPEPARSER_H
1111

12-
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
1312
#include "mlir/IR/OpImplementation.h"
14-
#include "llvm/ADT/StringMap.h"
1513

1614
namespace mlir {
1715
namespace sparse_tensor {
1816
namespace ir_detail {
1917

20-
//===----------------------------------------------------------------------===//
21-
// These macros are for generating a C++ expression of type
22-
// `std::initializer_list<std::pair<StringRef,DimLevelType>>` since there's
23-
// no way to construct an object of that type directly via C++ code.
24-
#define FOREVERY_LEVELTYPE(DO) \
25-
DO(DimLevelType::Dense) \
26-
DO(DimLevelType::Compressed) \
27-
DO(DimLevelType::CompressedNu) \
28-
DO(DimLevelType::CompressedNo) \
29-
DO(DimLevelType::CompressedNuNo) \
30-
DO(DimLevelType::Singleton) \
31-
DO(DimLevelType::SingletonNu) \
32-
DO(DimLevelType::SingletonNo) \
33-
DO(DimLevelType::SingletonNuNo) \
34-
DO(DimLevelType::CompressedWithHi) \
35-
DO(DimLevelType::CompressedWithHiNu) \
36-
DO(DimLevelType::CompressedWithHiNo) \
37-
DO(DimLevelType::CompressedWithHiNuNo) \
38-
DO(DimLevelType::TwoOutOfFour)
39-
#define LEVELTYPE_INITLIST_ELEMENT(lvlType) \
40-
std::make_pair(StringRef(toMLIRString(lvlType)), lvlType),
41-
#define LEVELTYPE_INITLIST \
42-
{ FOREVERY_LEVELTYPE(LEVELTYPE_INITLIST_ELEMENT) }
43-
44-
// TODO(wrengr): Since this parser is non-trivial to construct, is there
45-
// any way to hook into the parsing process so that we construct it only once
46-
// at the beginning of parsing and then destroy it once parsing has finished?
4718
class LvlTypeParser {
48-
const llvm::StringMap<DimLevelType> map;
49-
5019
public:
51-
explicit LvlTypeParser() : map(LEVELTYPE_INITLIST) {}
52-
#undef LEVELTYPE_INITLIST
53-
#undef LEVELTYPE_INITLIST_ELEMENT
54-
#undef FOREVERY_LEVELTYPE
20+
LvlTypeParser() = default;
21+
FailureOr<uint8_t> parseLvlType(AsmParser &parser) const;
5522

56-
std::optional<DimLevelType> lookup(StringRef str) const;
57-
std::optional<DimLevelType> lookup(StringAttr str) const;
58-
ParseResult parseLvlType(AsmParser &parser, DimLevelType &out) const;
59-
FailureOr<DimLevelType> parseLvlType(AsmParser &parser) const;
60-
// TODO(wrengr): `parseOptionalLvlType`?
61-
// TODO(wrengr): `parseLvlTypeList`?
23+
private:
24+
ParseResult parseProperty(AsmParser &parser, uint8_t *properties) const;
6225
};
6326

6427
} // namespace ir_detail

mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ func.func private @sparse_dcsc(tensor<?x?xf32, #DCSC>)
5555
// -----
5656

5757
#COO = #sparse_tensor.encoding<{
58-
lvlTypes = [ "compressed_nu_no", "singleton_no" ]
58+
map = (d0, d1) -> (d0 : compressed(nonunique, nonordered), d1 : singleton(nonordered))
5959
}>
6060

6161
// CHECK-LABEL: func private @sparse_coo(
@@ -65,7 +65,7 @@ func.func private @sparse_coo(tensor<?x?xf32, #COO>)
6565
// -----
6666

6767
#BCOO = #sparse_tensor.encoding<{
68-
lvlTypes = [ "dense", "compressed_hi_nu", "singleton" ]
68+
map = (d0, d1, d2) -> (d0 : dense, d1 : compressed(nonunique, high), d2 : singleton)
6969
}>
7070

7171
// CHECK-LABEL: func private @sparse_bcoo(
@@ -75,7 +75,7 @@ func.func private @sparse_bcoo(tensor<?x?x?xf32, #BCOO>)
7575
// -----
7676

7777
#SortedCOO = #sparse_tensor.encoding<{
78-
lvlTypes = [ "compressed_nu", "singleton" ]
78+
map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton)
7979
}>
8080

8181
// CHECK-LABEL: func private @sparse_sorted_coo(
@@ -144,7 +144,7 @@ func.func private @sparse_slice(tensor<?x?xf64, #CSR_SLICE>)
144144
// below) to encode a 2D matrix, but it would require dim2lvl mapping which is not ready yet.
145145
// So we take the simple path for now.
146146
#NV_24= #sparse_tensor.encoding<{
147-
lvlTypes = [ "dense", "compressed24" ],
147+
map = (d0, d1) -> (d0 : dense, d1 : compressed(block2_4))
148148
}>
149149

150150
// CHECK-LABEL: func private @sparse_2_out_of_4(
@@ -195,7 +195,7 @@ func.func private @BCSR_explicit(%arg0: tensor<?x?xf64, #BCSR_explicit>) {
195195
map = ( i, j ) ->
196196
( i : dense,
197197
j floordiv 4 : dense,
198-
j mod 4 : compressed24
198+
j mod 4 : compressed(block2_4)
199199
)
200200
}>
201201

0 commit comments

Comments
 (0)