Skip to content

Commit bfe08c0

Browse files
authored
[mlir][sparse] support sparsifying 2:4 block sparsity (#71749)
1 parent de79314 commit bfe08c0

File tree

5 files changed

+62
-26
lines changed

5 files changed

+62
-26
lines changed

mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,8 @@ class Merger {
540540
bool isSparseLvlWithNonTrivialIdxExp(TensorLoopId b) const {
541541
if (isLvlWithNonTrivialIdxExp(b)) {
542542
auto dlt = getLoopDependentLevelType(b);
543-
return isCompressedDLT(dlt) || isSingletonDLT(dlt);
543+
return isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
544+
isLooseCompressedDLT(dlt) || is2OutOf4DLT(dlt);
544545
}
545546
return false;
546547
}

mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ void LoopEmitter::initializeLoopEmit(
448448
positionsBuffers[t][l] = genToPositions(builder, loc, tensor, l);
449449
coordinatesBuffers[t][l] =
450450
genToCoordinates(builder, loc, tensor, l, cooStart);
451-
} else if (isSingletonDLT(lvlTp)) {
451+
} else if (isSingletonDLT(lvlTp) || is2OutOf4DLT(lvlTp)) {
452452
// Singleton level, fetch coordinates.
453453
coordinatesBuffers[t][l] =
454454
genToCoordinates(builder, loc, tensor, l, cooStart);
@@ -540,7 +540,8 @@ void LoopEmitter::categorizeLoopCondition(
540540
auto lvlType = lvlTypes[t][l];
541541
// Must be a recognizable DLT.
542542
assert(isDenseDLT(lvlType) || isCompressedDLT(lvlType) ||
543-
isLooseCompressedDLT(lvlType) || isSingletonDLT(lvlType));
543+
isLooseCompressedDLT(lvlType) || isSingletonDLT(lvlType) ||
544+
is2OutOf4DLT(lvlType));
544545

545546
bool isSparse = !isDenseDLT(lvlType);
546547
bool isSlice = isSparseSlices[t];
@@ -637,6 +638,7 @@ std::pair<Operation *, Value> LoopEmitter::emitForLoopOverTensorAtLvl(
637638
Value hi, MutableArrayRef<Value> reduc, bool isParallel) {
638639
bool isSparseCond = isCompressedDLT(lvlTypes[tid][lvl]) ||
639640
isLooseCompressedDLT(lvlTypes[tid][lvl]) ||
641+
is2OutOf4DLT(lvlTypes[tid][lvl]) ||
640642
isSingletonDLT(lvlTypes[tid][lvl]);
641643
// TODO: support dynamic slices.
642644
// Uses the first dimension here to build the loop bound (which is also the
@@ -1240,6 +1242,7 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
12401242

12411243
const Value c0 = C_IDX(0);
12421244
const Value c1 = C_IDX(1);
1245+
const Value c2 = C_IDX(2);
12431246
// Either the first level, or the previous level has been set.
12441247
/// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
12451248
assert(lvl == 0 || posits[tid][lvl - 1]);
@@ -1248,7 +1251,7 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
12481251

12491252
Value pLo = lvl == 0 ? c0 : posits[tid][lvl - 1];
12501253
if (isLooseCompressedDLT(lvlTp))
1251-
pLo = builder.create<arith::MulIOp>(loc, pLo, C_IDX(2));
1254+
pLo = builder.create<arith::MulIOp>(loc, pLo, c2);
12521255
posits[tid][lvl] = genIndexLoad(builder, loc, mem, pLo);
12531256

12541257
const Value pHi = ADDI(pLo, c1);
@@ -1271,7 +1274,13 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
12711274
: ADDI(pLo, c1);
12721275
return;
12731276
}
1274-
1277+
if (is2OutOf4DLT(lvlTp)) {
1278+
const Value pLo = lvl == 0 ? c0 : posits[tid][lvl - 1];
1279+
// Each 2:4 block has exactly two specified elements.
1280+
posits[tid][lvl] = MULI(pLo, c2);
1281+
highs[tid][lvl] = ADDI(posits[tid][lvl], c2);
1282+
return;
1283+
}
12751284
llvm_unreachable("Unrecognized level-type!");
12761285
}
12771286

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -816,7 +816,7 @@ static bool computeIterationGraph(CodegenEnv &env, SortMask mask,
816816
for (LoopId i = 0; i < numLoops; i++) {
817817
const auto dltI = env.dlt(tid, i);
818818
if (isCompressedDLT(dltI) || isLooseCompressedDLT(dltI) ||
819-
isSingletonDLT(dltI)) {
819+
isSingletonDLT(dltI) || is2OutOf4DLT(dltI)) {
820820
for (LoopId j = 0; j < numLoops; j++)
821821
if (isUndefDLT(env.dlt(tid, j))) {
822822
addIterOrdering(i, j, adjM, inDegree);
@@ -1508,7 +1508,7 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
15081508
assert(ldx == env.merger().loop(b));
15091509
Value clause;
15101510
if (isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
1511-
isLooseCompressedDLT(dlt)) {
1511+
isLooseCompressedDLT(dlt) || is2OutOf4DLT(dlt)) {
15121512
assert(lvl.has_value());
15131513
const Value crd = env.emitter().getCoords()[tid][*lvl];
15141514
const Value lvar = env.getLoopVar(ldx);
@@ -1593,7 +1593,7 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
15931593
needsUniv = true;
15941594
}
15951595
if (isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
1596-
isLooseCompressedDLT(dlt) || isIdxReduc) {
1596+
isLooseCompressedDLT(dlt) || is2OutOf4DLT(dlt) || isIdxReduc) {
15971597
// Only when this is a index reduction loop, can the dlt be undefined.
15981598
assert(!isUndefDLT(dlt) || isIdxReduc);
15991599
// sparse/singleton levels, or a dense/sparse index reduction loop.

mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
490490
if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
491491
const auto dlt = getLvlType(b);
492492
if (!isCompressedDLT(dlt) && !isSingletonDLT(dlt) &&
493-
!isLooseCompressedDLT(dlt)) {
493+
!isLooseCompressedDLT(dlt) && !is2OutOf4DLT(dlt)) {
494494
if (reset)
495495
simple.reset(b);
496496
reset = true;
@@ -671,7 +671,7 @@ bool Merger::hasAnySparse(const BitVector &bits) const {
671671
for (TensorLoopId b : bits.set_bits()) {
672672
const auto dlt = getLvlType(b);
673673
if (isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
674-
isLooseCompressedDLT(dlt))
674+
isLooseCompressedDLT(dlt) || is2OutOf4DLT(dlt))
675675
return true;
676676
}
677677
return hasSparseIdxReduction(bits);

mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -44,19 +44,41 @@
4444
#BSR = #sparse_tensor.encoding<{
4545
map = ( i, j ) ->
4646
( i floordiv 2 : dense,
47-
j floordiv 3 : compressed,
47+
j floordiv 2 : compressed,
4848
i mod 2 : dense,
49-
j mod 3 : dense
49+
j mod 2 : dense
5050
)
5151
}>
5252

53+
#NV_24 = #sparse_tensor.encoding<{
54+
map = ( i, j ) ->
55+
( i : dense,
56+
j floordiv 4 : dense,
57+
j mod 4 : block2_4
58+
),
59+
}>
60+
5361
module {
5462

55-
func.func @mul(%arg0: tensor<4x6xf64>,
56-
%arg1: tensor<4x6xf64, #BSR>) -> tensor<4x4xf64> {
63+
func.func @mul(%arg0: tensor<4x8xf64>,
64+
%arg1: tensor<4x8xf64, #BSR>) -> tensor<4x4xf64> {
65+
%out = arith.constant dense<0.0> : tensor<4x4xf64>
66+
%0 = linalg.generic #trait_mul
67+
ins(%arg0, %arg1: tensor<4x8xf64>, tensor<4x8xf64, #BSR>)
68+
outs(%out: tensor<4x4xf64>) {
69+
^bb(%x: f64, %y : f64, %z : f64):
70+
%1 = arith.mulf %x, %y : f64
71+
%2 = arith.addf %1, %z : f64
72+
linalg.yield %2 : f64
73+
} -> tensor<4x4xf64>
74+
return %0 : tensor<4x4xf64>
75+
}
76+
77+
func.func @mul_24(%arg0: tensor<4x8xf64>,
78+
%arg1: tensor<4x8xf64, #NV_24>) -> tensor<4x4xf64> {
5779
%out = arith.constant dense<0.0> : tensor<4x4xf64>
5880
%0 = linalg.generic #trait_mul
59-
ins(%arg0, %arg1: tensor<4x6xf64>, tensor<4x6xf64, #BSR>)
81+
ins(%arg0, %arg1: tensor<4x8xf64>, tensor<4x8xf64, #NV_24>)
6082
outs(%out: tensor<4x4xf64>) {
6183
^bb(%x: f64, %y : f64, %z : f64):
6284
%1 = arith.mulf %x, %y : f64
@@ -66,11 +88,11 @@ func.func @mul(%arg0: tensor<4x6xf64>,
6688
return %0 : tensor<4x4xf64>
6789
}
6890

69-
func.func @mul_dense(%arg0: tensor<4x6xf64>,
70-
%arg1: tensor<4x6xf64>) -> tensor<4x4xf64> {
91+
func.func @mul_dense(%arg0: tensor<4x8xf64>,
92+
%arg1: tensor<4x8xf64>) -> tensor<4x4xf64> {
7193
%out = arith.constant dense<0.0> : tensor<4x4xf64>
7294
%0 = linalg.generic #trait_mul
73-
ins(%arg0, %arg1: tensor<4x6xf64>, tensor<4x6xf64>)
95+
ins(%arg0, %arg1: tensor<4x8xf64>, tensor<4x8xf64>)
7496
outs(%out: tensor<4x4xf64>) {
7597
^bb(%x: f64, %y : f64, %z : f64):
7698
%1 = arith.mulf %x, %y : f64
@@ -101,22 +123,26 @@ func.func @mul_dense(%arg0: tensor<4x6xf64>,
101123
%c2 = arith.constant 2 : index
102124

103125

104-
%td = arith.constant dense<[[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0],
105-
[ 6.0, 7.0, 8.0, 9.0, 10.0, 11.0],
106-
[12.0, 13.0, 14.0, 15.0, 16.0, 17.0],
107-
[18.0, 19.0, 20.0, 21.0, 22.0, 23.0]]> : tensor<4x6xf64>
126+
%td = arith.constant dense<[[ 1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 4.0, 5.0],
127+
[ 6.0, 7.0, 0.0, 0.0, 0.0, 0.0, 10.0, 11.0],
128+
[ 0.0, 0.0, 12.0, 13.0, 16.0, 17.0, 0.0, 0.0],
129+
[ 0.0, 0.0, 18.0, 19.0, 22.0, 23.0, 0.0, 0.0]]> : tensor<4x8xf64>
108130

109131

110-
%2 = sparse_tensor.convert %td : tensor<4x6xf64> to tensor<4x6xf64, #BSR>
132+
%2 = sparse_tensor.convert %td : tensor<4x8xf64> to tensor<4x8xf64, #BSR>
133+
%3 = sparse_tensor.convert %td : tensor<4x8xf64> to tensor<4x8xf64, #NV_24>
111134

112135
%d = call @mul_dense(%td, %td)
113-
: (tensor<4x6xf64>, tensor<4x6xf64>) -> tensor<4x4xf64>
136+
: (tensor<4x8xf64>, tensor<4x8xf64>) -> tensor<4x4xf64>
114137
%s = call @mul(%td, %2)
115-
: (tensor<4x6xf64>, tensor<4x6xf64, #BSR>) -> tensor<4x4xf64>
138+
: (tensor<4x8xf64>, tensor<4x8xf64, #BSR>) -> tensor<4x4xf64>
139+
%s24 = call @mul_24(%td, %3)
140+
: (tensor<4x8xf64>, tensor<4x8xf64, #NV_24>) -> tensor<4x4xf64>
116141

117-
// CHECK-COUNT-2: ( ( 55, 145, 235, 325 ), ( 145, 451, 757, 1063 ), ( 235, 757, 1279, 1801 ), ( 325, 1063, 1801, 2539 ) )
142+
// CHECK-COUNT-3: ( ( 46, 115, 0, 0 ), ( 115, 306, 0, 0 ), ( 0, 0, 858, 1206 ), ( 0, 0, 1206, 1698 ) )
118143
call @dumpf64(%d) : (tensor<4x4xf64>) -> ()
119144
call @dumpf64(%s) : (tensor<4x4xf64>) -> ()
145+
call @dumpf64(%s24) : (tensor<4x4xf64>) -> ()
120146

121147
return
122148
}

0 commit comments

Comments
 (0)