Skip to content

Commit b52eb7c

Browse files
authored
[mlir][sparse] add a csr x bsr matmul test case (llvm#73012)
1 parent 1caaec1 commit b52eb7c

File tree

2 files changed: +24 −3 lines changed

mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1497,8 +1497,7 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
       levelReducedDep[tid][lvl]--;
       if (!resolved) {
         // TODO: support coiterating multiple slices
-        assert(loopInfo.trivialTidLvls.empty() &&
-               loopInfo.sliceDrivenInfo.size() == 1);
+        assert(loopInfo.sliceDrivenInfo.size() == 1);
         auto [nxNonEmpty, nxMinCrd, nxAbsOffset] =
             genSliceNextInduction(builder, loc, tid, lvl);
         // Update while loop induction operands.

mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@
   doc = "X(i,j) *= A(i,j) * B(j,i)"
 }
 
+#CSR = #sparse_tensor.encoding<{
+  map = ( i, j ) -> (i : dense, j : compressed)
+}>
+
 #BSR = #sparse_tensor.encoding<{
   map = ( i, j ) ->
@@ -89,6 +93,20 @@ func.func @mul_24(%arg0: tensor<4x8xf64>,
   return %0 : tensor<4x4xf64>
 }
 
+func.func @mul_csr_bsr(%arg0: tensor<4x8xf64, #CSR>,
+                       %arg1: tensor<4x8xf64, #BSR>) -> tensor<4x4xf64> {
+  %out = arith.constant dense<0.0> : tensor<4x4xf64>
+  %0 = linalg.generic #trait_mul
+    ins(%arg0, %arg1: tensor<4x8xf64, #CSR>, tensor<4x8xf64, #BSR>)
+    outs(%out: tensor<4x4xf64>) {
+      ^bb(%x: f64, %y : f64, %z : f64):
+        %1 = arith.mulf %x, %y : f64
+        %2 = arith.addf %1, %z : f64
+        linalg.yield %2 : f64
+  } -> tensor<4x4xf64>
+  return %0 : tensor<4x4xf64>
+}
+
 func.func @mul_dense(%arg0: tensor<4x8xf64>,
                      %arg1: tensor<4x8xf64>) -> tensor<4x4xf64> {
   %out = arith.constant dense<0.0> : tensor<4x4xf64>
@@ -132,18 +150,22 @@ func.func @mul_dense(%arg0: tensor<4x8xf64>,
 
   %2 = sparse_tensor.convert %td : tensor<4x8xf64> to tensor<4x8xf64, #BSR>
   %3 = sparse_tensor.convert %td : tensor<4x8xf64> to tensor<4x8xf64, #NV_24>
+  %4 = sparse_tensor.convert %td : tensor<4x8xf64> to tensor<4x8xf64, #CSR>
 
   %d = call @mul_dense(%td, %td)
       : (tensor<4x8xf64>, tensor<4x8xf64>) -> tensor<4x4xf64>
   %s = call @mul(%td, %2)
      : (tensor<4x8xf64>, tensor<4x8xf64, #BSR>) -> tensor<4x4xf64>
   %s24 = call @mul_24(%td, %3)
      : (tensor<4x8xf64>, tensor<4x8xf64, #NV_24>) -> tensor<4x4xf64>
+  %scsr = call @mul_csr_bsr(%4, %2)
+     : (tensor<4x8xf64, #CSR>, tensor<4x8xf64, #BSR>) -> tensor<4x4xf64>
 
-  // CHECK-COUNT-3: ( ( 46, 115, 0, 0 ), ( 115, 306, 0, 0 ), ( 0, 0, 858, 1206 ), ( 0, 0, 1206, 1698 ) )
+  // CHECK-COUNT-4: ( ( 46, 115, 0, 0 ), ( 115, 306, 0, 0 ), ( 0, 0, 858, 1206 ), ( 0, 0, 1206, 1698 ) )
   call @dumpf64(%d) : (tensor<4x4xf64>) -> ()
   call @dumpf64(%s) : (tensor<4x4xf64>) -> ()
   call @dumpf64(%s24) : (tensor<4x4xf64>) -> ()
+  call @dumpf64(%scsr) : (tensor<4x4xf64>) -> ()
 
   return
 }

0 commit comments

Comments (0)