Commit 2a28861

[mlir][sparse] improved testing and codegen for semi-ring operations

The semi-ring blocks were simply "inlined" by the sparse compiler, but without
any filtering or patching. This revision improves the analysis (rejecting blocks
that use non-invariant computations from outside their blocks, except for
linalg.index) and also improves the codegen by properly patching up index
computations (the previous version crashed). With a regression test. Also
updated the documentation now that the example code is properly working.

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D128000
1 parent 21f557e commit 2a28861
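
For context, here is a hedged sketch (a hypothetical kernel body, not code from
this commit) of the distinction the improved analysis draws: block arguments,
values defined inside a semi-ring branch, loop-invariant values from outside the
loop nest, and linalg.index are all admissible, while any other value computed
in the surrounding linalg body is rejected:

```mlir
^bb0(%a: f64, %b: f64, %c: f64):
  %i = linalg.index 0 : index          // admissible anywhere; relinked by codegen
  %t = arith.mulf %a, %b : f64         // non-invariant, defined outside the branch
  %r = sparse_tensor.binary %a, %b : f64, f64 to f64
    overlap={
      ^bb0(%x: f64, %y: f64):
        %u = arith.addf %x, %t : f64   // rejected: uses %t from the loop body
        sparse_tensor.yield %u : f64
    }
    left=identity
    right=identity
  linalg.yield %r : f64
```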

File tree

4 files changed (+179, -13 lines)


mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 10 additions & 4 deletions
````diff
@@ -188,6 +188,8 @@ def SparseTensor_LexInsertOp : SparseTensor_Op<"lex_insert", []>,
     is solely defined by side-effects and not SSA values. The semantics
     may be refined over time as our sparse abstractions evolve.
 
+    Example:
+
     ```mlir
     sparse_tensor.lex_insert %tensor, %indices, %val
       : tensor<1024x1024xf64, #CSR>, memref<?xindex>, f64
@@ -385,7 +387,8 @@ def SparseTensor_BinaryOp : SparseTensor_Op<"binary", [NoSideEffect]>,
     would be equivalent to a union operation where non-overlapping values
     in the inputs are copied to the output unchanged.
 
-    Example of isEqual applied to intersecting elements only.
+    Example of isEqual applied to intersecting elements only:
+
     ```mlir
     %C = bufferization.alloc_tensor...
     %0 = linalg.generic #trait
@@ -405,8 +408,8 @@ def SparseTensor_BinaryOp : SparseTensor_Op<"binary", [NoSideEffect]>,
     } -> tensor<?xi8, #SparseVec>
     ```
 
-    Example of A+B in upper triangle, A-B in lower triangle
-    (not working yet, but construct will be available soon).
+    Example of A+B in upper triangle, A-B in lower triangle:
+
     ```mlir
     %C = bufferization.alloc_tensor...
     %1 = linalg.generic #trait
@@ -438,7 +441,8 @@ def SparseTensor_BinaryOp : SparseTensor_Op<"binary", [NoSideEffect]>,
 
     Example of set difference. Returns a copy of A where its sparse structure
     is *not* overlapped by B. The element type of B can be different than A
-    because we never use its values, only its sparse structure.
+    because we never use its values, only its sparse structure:
+
     ```mlir
     %C = bufferization.alloc_tensor...
     %2 = linalg.generic #trait
@@ -486,6 +490,7 @@ def SparseTensor_UnaryOp : SparseTensor_Op<"unary", [NoSideEffect]>,
     region does not contribute to the output.
 
     Example of A+1, restricted to existing elements:
+
     ```mlir
     %C = bufferization.alloc_tensor...
     %0 = linalg.generic #trait
@@ -546,6 +551,7 @@ def SparseTensor_YieldOp : SparseTensor_Op<"yield", [NoSideEffect, Terminator]>,
     Yields a value from within a `binary` or `unary` block.
 
     Example:
+
     ```
     %0 = sparse_tensor.unary %a : i64 to i64 {
       ^bb0(%arg0: i64):
````
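
The diff context cuts each documentation example off after its first lines. For
reference, a hedged sketch of how the "A+1, restricted to existing elements"
unary example plausibly continues (the #trait and #SparseVector names are
assumptions, mirroring the doc's other fragments):

```mlir
%C = bufferization.alloc_tensor(%d) : tensor<?xf64, #SparseVector>
%0 = linalg.generic #trait
  ins(%A: tensor<?xf64, #SparseVector>)
  outs(%C: tensor<?xf64, #SparseVector>) {
    ^bb0(%a: f64, %c: f64):
      %result = sparse_tensor.unary %a : f64 to f64
        present={
          ^bb0(%x: f64):
            %cf1 = arith.constant 1.0 : f64
            %ret = arith.addf %x, %cf1 : f64  // x + 1 for each stored element
            sparse_tensor.yield %ret : f64
        }
        absent={}                             // absent elements stay absent
      linalg.yield %result : f64
  } -> tensor<?xf64, #SparseVector>
```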

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 28 additions & 5 deletions
```diff
@@ -881,9 +881,8 @@ static Value genAddress(CodeGen &codegen, OpBuilder &builder, Location loc,
 }
 
 /// Generates an index value.
-static Value genIndexValue(Merger &merger, CodeGen &codegen, OpBuilder &builder,
-                           unsigned exp, unsigned ldx) {
-  unsigned idx = merger.exp(exp).index;
+static Value genIndexValue(CodeGen &codegen, OpBuilder &builder, unsigned idx,
+                           unsigned ldx) {
   Value ival = codegen.loops[idx];
   Type itype = ival.getType();
   // During vectorization, we either encounter:
@@ -913,6 +912,25 @@ static Value genIndexValue(Merger &merger, CodeGen &codegen, OpBuilder &builder,
   return ival;
 }
 
+/// Semi-ring branches are simply inlined by the sparse compiler. Prior
+/// analysis has verified that all computations are "local" to the inlined
+/// branch or otherwise invariantly defined outside the loop nest, with the
+/// exception of index computations, which need to be relinked to actual
+/// inlined cloned code.
+static Value relinkBranch(CodeGen &codegen, RewriterBase &rewriter,
+                          Block *block, Value e, unsigned ldx) {
+  if (Operation *def = e.getDefiningOp()) {
+    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
+      return genIndexValue(codegen, rewriter, indexOp.dim(), ldx);
+    if (def->getBlock() == block) {
+      for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
+        def->setOperand(
+            i, relinkBranch(codegen, rewriter, block, def->getOperand(i), ldx));
+    }
+  }
+  return e;
+}
+
 /// Recursively generates tensor expression.
 static Value genExp(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
                     linalg::GenericOp op, unsigned exp, unsigned ldx) {
@@ -924,12 +942,17 @@ static Value genExp(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
   if (merger.exp(exp).kind == Kind::kInvariant)
     return genInvariantValue(merger, codegen, rewriter, exp);
   if (merger.exp(exp).kind == Kind::kIndex)
-    return genIndexValue(merger, codegen, rewriter, exp, ldx);
+    return genIndexValue(codegen, rewriter, merger.exp(exp).index, ldx);
   Value v0 =
       genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0, ldx);
   Value v1 =
       genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1, ldx);
-  return merger.buildExp(rewriter, loc, exp, v0, v1);
+  Value ee = merger.buildExp(rewriter, loc, exp, v0, v1);
+  if (ee && (merger.exp(exp).kind == Kind::kUnary ||
+             merger.exp(exp).kind == Kind::kBinary ||
+             merger.exp(exp).kind == Kind::kBinaryBranch))
+    ee = relinkBranch(codegen, rewriter, ee.getParentBlock(), ee, ldx);
+  return ee;
 }
 
 /// Determines if affine expression is invariant.
```
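
To illustrate what the index patching fixes, a heavily hedged sketch (the
generated loop structure is simplified and the names %indices, %lo, %hi are
made up): once a branch is inlined, its cloned linalg.index operations would
dangle, so relinkBranch walks the cloned expression and substitutes the loop
indices that genIndexValue produces.

```mlir
// In the source region of the kernel:
//   %row = linalg.index 0 : index
//   %col = linalg.index 1 : index
//   %cmp = arith.cmpi "uge", %col, %row : index
//
// After inlining into the generated loop nest (conceptually):
scf.for %i = %c0 to %c4 step %c1 {           // dense loop over rows
  scf.for %jj = %lo to %hi step %c1 {        // loop over stored entries of row %i
    %j = memref.load %indices[%jj] : memref<?xindex>
    %cmp = arith.cmpi "uge", %j, %i : index  // linalg.index relinked to %j and %i
    // ... inlined overlap computation guarded by %cmp ...
  }
}
```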

mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

Lines changed: 46 additions & 4 deletions
```diff
@@ -798,7 +798,9 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
 }
 
 Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
+  // Build the linalg semantics backward from yield.
   Operation *yield = op.region().front().getTerminator();
+  assert(isa<linalg::YieldOp>(yield));
   return buildTensorExp(op, yield->getOperand(0));
 }
 
@@ -832,6 +834,37 @@ Type Merger::inferType(unsigned e, Value src) {
   return dtp;
 }
 
+/// Ensures that sparse compiler can generate code for expression.
+static bool isAdmissableBranchExp(Operation *op, Block *block, Value v) {
+  // Arguments are always admissable.
+  if (auto arg = v.dyn_cast<BlockArgument>())
+    return true;
+  // Accept index anywhere.
+  Operation *def = v.getDefiningOp();
+  if (isa<linalg::IndexOp>(def))
+    return true;
+  // Operation defined outside branch.
+  if (def->getBlock() != block) {
+    return def->getBlock() != op->getBlock(); // invariant?
+  }
+  // Operation defined within branch. Anything is accepted,
+  // as long as all subexpressions are admissable.
+  for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
+    if (!isAdmissableBranchExp(op, block, def->getOperand(i)))
+      return false;
+  return true;
+}
+
+/// Ensures that sparse compiler can generate code for branch.
+static bool isAdmissableBranch(Operation *op, Region &region) {
+  if (region.empty())
+    return true;
+  // Build the semi-ring branch semantics backward from yield.
+  Operation *yield = region.front().getTerminator();
+  assert(isa<YieldOp>(yield));
+  return isAdmissableBranchExp(op, &region.front(), yield->getOperand(0));
+}
+
 Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
   if (auto arg = v.dyn_cast<BlockArgument>()) {
     unsigned argN = arg.getArgNumber();
@@ -920,8 +953,11 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
       return addExp(kCRe, e);
     if (isa<arith::BitcastOp>(def))
       return addExp(kBitCast, e, v);
-    if (isa<sparse_tensor::UnaryOp>(def))
-      return addExp(kUnary, e, Value(), def);
+    if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
+      if (isAdmissableBranch(unop, unop.presentRegion()) &&
+          isAdmissableBranch(unop, unop.absentRegion()))
+        return addExp(kUnary, e, Value(), def);
+    }
   }
 }
 // Construct binary operations if subexpressions can be built.
@@ -971,8 +1007,14 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
      return addExp(kShrU, e0, e1);
     if (isa<arith::ShLIOp>(def) && isInvariant(e1))
       return addExp(kShlI, e0, e1);
-    if (isa<sparse_tensor::BinaryOp>(def))
-      return addExp(kBinary, e0, e1, Value(), def);
+    if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
+      if (isAdmissableBranch(binop, binop.overlapRegion()) &&
+          (binop.left_identity() ||
+           isAdmissableBranch(binop, binop.leftRegion())) &&
+          (binop.right_identity() ||
+           isAdmissableBranch(binop, binop.rightRegion())))
+        return addExp(kBinary, e0, e1, Value(), def);
+    }
   }
 }
 // Cannot build.
```
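
A hedged sketch (hypothetical fragment, consistent with the accessors in the
diff above) of how these checks play out for a binary op: an empty region is
trivially admissible, and left=identity short-circuits via left_identity(), so
only the overlap and right branches below are actually analyzed:

```mlir
%r = sparse_tensor.binary %a, %b : f64, f64 to f64
  overlap={
    ^bb0(%x: f64, %y: f64):
      %s = arith.addf %x, %y : f64    // local computation: admissible
      sparse_tensor.yield %s : f64
  }
  left=identity                       // no left region to verify
  right={
    ^bb0(%y: f64):
      %n = arith.negf %y : f64        // local computation: admissible
      sparse_tensor.yield %n : f64
  }
```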
Lines changed: 95 additions & 0 deletions
```mlir
// RUN: mlir-opt %s --sparse-compiler | \
// RUN: mlir-cpu-runner \
// RUN:  -e entry -entry-point-result=void \
// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s

#SparseMatrix = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>

#trait_op = {
  indexing_maps = [
    affine_map<(i,j) -> (i,j)>, // A
    affine_map<(i,j) -> (i,j)>, // B
    affine_map<(i,j) -> (i,j)>  // X (out)
  ],
  iterator_types = ["parallel", "parallel"],
  doc = "X(i,j) = A(i,j) OP B(i,j)"
}

module {
  // Performs triangular add/sub operation (using semi-ring binary op).
  func.func @triangular(%A: tensor<4x4xf64, #SparseMatrix>,
                        %B: tensor<4x4xf64, #SparseMatrix>) -> tensor<4x4xf64, #SparseMatrix> {
    %C = bufferization.alloc_tensor() : tensor<4x4xf64, #SparseMatrix>
    %0 = linalg.generic #trait_op
      ins(%A, %B: tensor<4x4xf64, #SparseMatrix>,
                  tensor<4x4xf64, #SparseMatrix>)
      outs(%C: tensor<4x4xf64, #SparseMatrix>) {
        ^bb0(%a: f64, %b: f64, %c: f64) :
          %row = linalg.index 0 : index
          %col = linalg.index 1 : index
          %result = sparse_tensor.binary %a, %b : f64, f64 to f64
            overlap={
              ^bb0(%x: f64, %y: f64):
                %cmp = arith.cmpi "uge", %col, %row : index
                %upperTriangleResult = arith.addf %x, %y : f64
                %lowerTriangleResult = arith.subf %x, %y : f64
                %ret = arith.select %cmp, %upperTriangleResult, %lowerTriangleResult : f64
                sparse_tensor.yield %ret : f64
            }
            left=identity
            right={
              ^bb0(%y: f64):
                %cmp = arith.cmpi "uge", %col, %row : index
                %lowerTriangleResult = arith.negf %y : f64
                %ret = arith.select %cmp, %y, %lowerTriangleResult : f64
                sparse_tensor.yield %ret : f64
            }
          linalg.yield %result : f64
      } -> tensor<4x4xf64, #SparseMatrix>
    return %0 : tensor<4x4xf64, #SparseMatrix>
  }

  // Driver method to call and verify triangular kernel.
  func.func @entry() {
    %c0 = arith.constant 0 : index
    %du = arith.constant -1.0 : f64

    %am = arith.constant dense<
      [ [ 1.0, 0.0, 3.0, 0.0],
        [ 0.0, 2.0, 0.0, 0.0],
        [ 0.0, 0.0, 0.0, 4.0],
        [ 3.0, 4.0, 0.0, 0.0] ]> : tensor<4x4xf64>
    %bm = arith.constant dense<
      [ [ 1.0, 0.0, 1.0, 1.0],
        [ 0.0, 0.5, 0.0, 0.0],
        [ 1.0, 5.0, 2.0, 0.0],
        [ 2.0, 0.0, 0.0, 0.0] ]> : tensor<4x4xf64>

    %a = sparse_tensor.convert %am : tensor<4x4xf64> to tensor<4x4xf64, #SparseMatrix>
    %b = sparse_tensor.convert %bm : tensor<4x4xf64> to tensor<4x4xf64, #SparseMatrix>
    %0 = call @triangular(%a, %b) : (tensor<4x4xf64, #SparseMatrix>,
                                     tensor<4x4xf64, #SparseMatrix>) -> tensor<4x4xf64, #SparseMatrix>

    //
    // Verify the results.
    //
    // CHECK:      ( ( 2, 0, 4, 1 ), ( 0, 2.5, 0, 0 ), ( -1, -5, 2, 4 ), ( 1, 4, 0, 0 ) )
    // CHECK-NEXT: ( 2, 4, 1, 2.5, -1, -5, 2, 4, 1, 4, -1, -1, -1, -1, -1, -1 )
    //
    %c = sparse_tensor.convert %0 : tensor<4x4xf64, #SparseMatrix> to tensor<4x4xf64>
    %m = bufferization.to_memref %c : memref<4x4xf64>
    %v = vector.transfer_read %m[%c0, %c0], %du: memref<4x4xf64>, vector<4x4xf64>
    vector.print %v : vector<4x4xf64>
    %1 = sparse_tensor.values %0 : tensor<4x4xf64, #SparseMatrix> to memref<?xf64>
    %2 = vector.transfer_read %1[%c0], %du: memref<?xf64>, vector<16xf64>
    vector.print %2 : vector<16xf64>

    // Release the resources.
    memref.dealloc %m : memref<4x4xf64>
    sparse_tensor.release %a : tensor<4x4xf64, #SparseMatrix>
    sparse_tensor.release %b : tensor<4x4xf64, #SparseMatrix>
    sparse_tensor.release %0 : tensor<4x4xf64, #SparseMatrix>
    return
  }
}
```
