Skip to content

Commit 56d5829

Browse files
authored
[mlir][sparse] Introduce batch level format. (#83082)
1 parent 371e6d0 commit 56d5829

File tree

11 files changed

+62
-18
lines changed

11 files changed

+62
-18
lines changed

mlir/include/mlir-c/Dialect/SparseTensor.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,11 @@ typedef uint64_t MlirSparseTensorLevelType;
2929

3030
enum MlirSparseTensorLevelFormat {
3131
MLIR_SPARSE_TENSOR_LEVEL_DENSE = 0x000000010000,
32-
MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 0x000000020000,
33-
MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 0x000000040000,
34-
MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 0x000000080000,
35-
MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 0x000000100000,
32+
MLIR_SPARSE_TENSOR_LEVEL_BATCH = 0x000000020000,
33+
MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 0x000000040000,
34+
MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 0x000000080000,
35+
MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 0x000000100000,
36+
MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 0x000000200000,
3637
};
3738

3839
enum MlirSparseTensorLevelPropertyNondefault {

mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -154,12 +154,26 @@ enum class Action : uint32_t {
154154
enum class LevelFormat : uint64_t {
155155
Undef = 0x00000000,
156156
Dense = 0x00010000,
157-
Compressed = 0x00020000,
158-
Singleton = 0x00040000,
159-
LooseCompressed = 0x00080000,
160-
NOutOfM = 0x00100000,
157+
Batch = 0x00020000,
158+
Compressed = 0x00040000,
159+
Singleton = 0x00080000,
160+
LooseCompressed = 0x00100000,
161+
NOutOfM = 0x00200000,
161162
};
162163

164+
/// Returns true iff the encoding of `fmt` has exactly one bit set, i.e. it is
/// a nonzero power of two.
///
/// An encoding of zero (the value of `LevelFormat::Undef`) has no bit set and
/// is rejected; the bare `(enc & (enc - 1)) == 0` test alone would accept it.
constexpr bool encPowOfTwo(LevelFormat fmt) {
  const auto enc = static_cast<std::underlying_type_t<LevelFormat>>(fmt);
  // Clearing the lowest set bit of a single-bit value leaves zero.
  return enc != 0 && (enc & (enc - 1)) == 0;
}
168+
169+
// All LevelFormats must have only one bit set (power of two).
170+
static_assert(encPowOfTwo(LevelFormat::Dense) &&
171+
encPowOfTwo(LevelFormat::Batch) &&
172+
encPowOfTwo(LevelFormat::Compressed) &&
173+
encPowOfTwo(LevelFormat::Singleton) &&
174+
encPowOfTwo(LevelFormat::LooseCompressed) &&
175+
encPowOfTwo(LevelFormat::NOutOfM));
176+
163177
template <LevelFormat... targets>
164178
constexpr bool isAnyOfFmt(LevelFormat fmt) {
165179
return (... || (targets == fmt));
@@ -172,6 +186,8 @@ constexpr const char *toFormatString(LevelFormat lvlFmt) {
172186
return "undef";
173187
case LevelFormat::Dense:
174188
return "dense";
189+
case LevelFormat::Batch:
190+
return "batch";
175191
case LevelFormat::Compressed:
176192
return "compressed";
177193
case LevelFormat::Singleton:
@@ -225,10 +241,10 @@ struct LevelType {
225241
static constexpr bool isValidLvlBits(uint64_t lvlBits) {
226242
auto fmt = static_cast<LevelFormat>(lvlBits & 0xffff0000);
227243
const uint64_t propertyBits = lvlBits & 0xffff;
228-
// If undefined/dense/NOutOfM, then must be unique and ordered.
244+
// If undefined/dense/batch/NOutOfM, then the level must be unique and
// ordered, i.e. no property bits may be set.
229245
// Otherwise, the format must be one of the known ones.
230246
return (isAnyOfFmt<LevelFormat::Undef, LevelFormat::Dense,
231-
LevelFormat::NOutOfM>(fmt))
247+
LevelFormat::Batch, LevelFormat::NOutOfM>(fmt))
232248
? (propertyBits == 0)
233249
: (isAnyOfFmt<LevelFormat::Compressed, LevelFormat::Singleton,
234250
LevelFormat::LooseCompressed>(fmt));
@@ -375,6 +391,7 @@ inline std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
375391
}
376392
/// Tests whether `lt` is of the `Undef` level format, as reported by
/// `LevelType::isa`.
inline bool isUndefLT(LevelType lt) {
  return lt.isa<LevelFormat::Undef>();
}

/// Tests whether `lt` is of the `Dense` level format, as reported by
/// `LevelType::isa`.
inline bool isDenseLT(LevelType lt) {
  return lt.isa<LevelFormat::Dense>();
}

/// Tests whether `lt` is of the `Batch` level format, as reported by
/// `LevelType::isa`.
inline bool isBatchLT(LevelType lt) {
  return lt.isa<LevelFormat::Batch>();
}

/// Tests whether `lt` is of the `Compressed` level format, as reported by
/// `LevelType::isa`.
inline bool isCompressedLT(LevelType lt) {
  return lt.isa<LevelFormat::Compressed>();
}

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,8 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
141141

142142
The supported level-formats are the following:
143143

144-
- **dense** : all entries along this level are stored
144+
- **dense** : all entries along this level are stored and linearized.
145+
- **batch** : all entries along this level are stored but not linearized.
145146
- **compressed** : only nonzeros along this level are stored
146147
- **loose_compressed** : as compressed, but allows for free space between regions
147148
- **singleton** : a variant of the compressed format, where coordinates have no siblings

mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
6262
// Set the base bit for properties.
6363
if (base.compare("dense") == 0) {
6464
properties |= static_cast<uint64_t>(LevelFormat::Dense);
65+
} else if (base.compare("batch") == 0) {
66+
properties |= static_cast<uint64_t>(LevelFormat::Batch);
6567
} else if (base.compare("compressed") == 0) {
6668
properties |= static_cast<uint64_t>(LevelFormat::Compressed);
6769
} else if (base.compare("structured") == 0) {

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,10 @@ LogicalResult SparseTensorEncodingAttr::verify(
690690
}
691691
}
692692

693+
auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
694+
if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT))
695+
return emitError() << "Batch lvlType can only be leading levels.";
696+
693697
// SoA property can only be applied on singleton level.
694698
auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) {
695699
return lt.isa<LevelPropNonDefault::SoA>();

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1278,6 +1278,8 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
12781278
switch (lt.getLvlFmt()) {
12791279
case LevelFormat::Dense:
12801280
return std::make_unique<DenseLevel>(tid, lvl, sz, stt.hasEncoding());
1281+
case LevelFormat::Batch:
1282+
llvm_unreachable("not implemented");
12811283
case LevelFormat::Compressed: {
12821284
Value pos = genToPositions(b, l, t, lvl);
12831285
Value crd = genToCoordinates(b, l, t, lvl);

mlir/test/CAPI/sparse_tensor.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ static int testRoundtripEncoding(MlirContext ctx) {
3939
// CHECK: (d0, d1)[s0] -> (s0, d0, d1)
4040
mlirAffineMapDump(dimToLvl);
4141
// CHECK: level_type: 65536
42-
// CHECK: level_type: 131072
43-
// CHECK: level_type: 131072
42+
// CHECK: level_type: 262144
43+
// CHECK: level_type: 262144
4444
MlirAffineMap lvlToDim =
4545
mlirSparseTensorEncodingAttrGetLvlToDim(originalAttr);
4646
int lvlRank = mlirSparseTensorEncodingGetLvlRank(originalAttr);

mlir/test/Dialect/SparseTensor/invalid_encoding.mlir

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,12 @@ func.func private @tensor_dimlevel_size_mismatch(%arg0: tensor<8xi32, #a>) -> ()
5454

5555
// -----
5656

57+
// expected-error@+1 {{Batch lvlType can only be leading levels}}
58+
#a = #sparse_tensor.encoding<{map = (d0, d1, d2) -> (d0 : batch, d1 : compressed, d2: batch)}>
59+
// Declaration that uses the invalid encoding #a on a rank-3 dynamic tensor.
func.func private @non_leading_batch(%arg0: tensor<?x?x?xi32, #a>) -> ()
60+
61+
// -----
62+
5763
// expected-error@+1 {{use of undeclared identifier}}
5864
#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : dense, d1 : compressed)}>
5965
func.func private @tensor_sizes_mismatch(%arg0: tensor<8xi32, #a>) -> ()

mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,17 @@ func.func private @sparse_csr(tensor<?x?xf32, #CSR>)
2222

2323
// -----
2424

25+
#BCSR = #sparse_tensor.encoding<{
26+
map = (d0, d1, d2) -> (d0 : batch, d1: dense, d2 : compressed),
27+
}>
28+
29+
// CHECK: #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed) }>
30+
// CHECK-LABEL: func private @sparse_bcsr(
31+
// CHECK-SAME: tensor<?x?x?xf32, #[[$BCSR]]>)
32+
func.func private @sparse_bcsr(tensor<?x?x?xf32, #BCSR>)
33+
34+
// -----
35+
2536
#CSR_explicit = #sparse_tensor.encoding<{
2637
map = {l0, l1} (d0 = l0, d1 = l1) -> (l0 = d0 : dense, l1 = d1 : compressed)
2738
}>

mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
// CHECK-DAG: %[[VAL_8:.*]] = arith.constant true
1515
// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 100 : index
1616
// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 300 : index
17-
// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 131072 : i64
17+
// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 262144 : i64
1818
// CHECK: %[[VAL_12:.*]] = memref.alloca() : memref<2xi64>
1919
// CHECK: %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi64> to memref<?xi64>
2020
// CHECK: memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi64>

mlir/test/python/dialects/sparse_tensor/dialect.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def testEncodingAttr1D():
2828
# CHECK: equal: True
2929
print(f"equal: {casted == parsed}")
3030

31-
# CHECK: lvl_types: [131072]
31+
# CHECK: lvl_types: [262144]
3232
print(f"lvl_types: {casted.lvl_types}")
3333
# CHECK: dim_to_lvl: (d0) -> (d0)
3434
print(f"dim_to_lvl: {casted.dim_to_lvl}")
@@ -71,9 +71,9 @@ def testEncodingAttrStructure():
7171
# CHECK: equal: True
7272
print(f"equal: {casted == parsed}")
7373

74-
# CHECK: lvl_types: [65536, 65536, 4406637494272]
74+
# CHECK: lvl_types: [65536, 65536, 4406638542848]
7575
print(f"lvl_types: {casted.lvl_types}")
76-
# CHECK: lvl_formats_enum: [<LevelFormat.dense: 65536>, <LevelFormat.dense: 65536>, <LevelFormat.n_out_of_m: 1048576>]
76+
# CHECK: lvl_formats_enum: [<LevelFormat.dense: 65536>, <LevelFormat.dense: 65536>, <LevelFormat.n_out_of_m: 2097152>]
7777
print(f"lvl_formats_enum: {casted.lvl_formats_enum}")
7878
# CHECK: structured_n: 2
7979
print(f"structured_n: {casted.structured_n}")
@@ -157,7 +157,7 @@ def testEncodingAttr2D():
157157
# CHECK: equal: True
158158
print(f"equal: {casted == parsed}")
159159

160-
# CHECK: lvl_types: [65536, 131072]
160+
# CHECK: lvl_types: [65536, 262144]
161161
print(f"lvl_types: {casted.lvl_types}")
162162
# CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
163163
print(f"dim_to_lvl: {casted.dim_to_lvl}")

0 commit comments

Comments
 (0)