
Commit fc83eda

Author: Peiming Liu
[mlir][sparse] make sparse compiler more admissible. (#90927)
Parent: fd3e7e3

File tree: 3 files changed (+79, -50 lines)


mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 0 additions & 6 deletions
@@ -432,12 +432,6 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
 
       Operation *producer = opOperand.get().getDefiningOp();
 
-      // Do not fuse a sparse-in/dense-out operation, as the
-      // result is too often not sparsifiable anymore.
-      if (sparse_tensor::hasAnySparseOperand(producer) &&
-          !sparse_tensor::hasAnySparseResult(producer))
-        return failure();
-
       // Find the producer of the operand.
       FailureOr<ElementwiseOpFusionResult> fusionResult =
          fuseElementwiseOps(rewriter, &opOperand);
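The rationale for this deletion: instead of categorically refusing to fuse a sparse-in/dense-out producer, fusion is now always attempted and the sparsifier itself decides admissibility (see the Merger.cpp changes below). A minimal sketch of the dropped policy, with a hypothetical struct standing in for the real sparse_tensor queries:

// Sketch only: the two booleans stand in for the results of
// sparse_tensor::hasAnySparseOperand and sparse_tensor::hasAnySparseResult
// on the producer op; this is not the MLIR API.
struct ProducerSparsity {
  bool hasAnySparseOperand;
  bool hasAnySparseResult;
};

// The guard this commit removes: reject fusion for sparse-in/dense-out
// producers. After the removal, fusion proceeds unconditionally and the
// sparsifier handles (or rejects) the fused kernel.
bool oldFusionGuardRejects(ProducerSparsity p) {
  return p.hasAnySparseOperand && !p.hasAnySparseResult;
}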

mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

Lines changed: 29 additions & 25 deletions
@@ -1356,50 +1356,54 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
   // See buildLattices() for an explanation of rejecting certain
   // division and shift operations.
   if (def->getNumOperands() == 2) {
-    const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
-    const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
-    bool hasSpDep = xDepSp || yDepSp;
+    const auto [x, xSpVals] = buildTensorExp(op, def->getOperand(0));
+    const auto [y, ySpVals] = buildTensorExp(op, def->getOperand(1));
+    // For a conjunctive operation, it yields a "sparse" result if any operand
+    // is sparse. For a disjunctive operation, it yields a "sparse" result if
+    // all operands are sparse.
+    bool conjSpVals = xSpVals || ySpVals;
+    bool disjSpVals = xSpVals && ySpVals;
     if (x.has_value() && y.has_value()) {
       const ExprId e0 = *x;
       const ExprId e1 = *y;
       if (isa<arith::MulFOp>(def))
-        return {addExp(TensorExp::Kind::kMulF, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kMulF, e0, e1), conjSpVals};
       if (isa<complex::MulOp>(def))
-        return {addExp(TensorExp::Kind::kMulC, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kMulC, e0, e1), conjSpVals};
       if (isa<arith::MulIOp>(def))
-        return {addExp(TensorExp::Kind::kMulI, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kMulI, e0, e1), conjSpVals};
       if (isa<arith::DivFOp>(def) && !maybeZero(e1))
-        return {addExp(TensorExp::Kind::kDivF, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kDivF, e0, e1), conjSpVals};
       if (isa<complex::DivOp>(def) && !maybeZero(e1))
-        return {addExp(TensorExp::Kind::kDivC, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kDivC, e0, e1), conjSpVals};
       if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
-        return {addExp(TensorExp::Kind::kDivS, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kDivS, e0, e1), conjSpVals};
       if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
-        return {addExp(TensorExp::Kind::kDivU, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kDivU, e0, e1), conjSpVals};
       if (isa<arith::AddFOp>(def))
-        return {addExp(TensorExp::Kind::kAddF, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kAddF, e0, e1), disjSpVals};
       if (isa<complex::AddOp>(def))
-        return {addExp(TensorExp::Kind::kAddC, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kAddC, e0, e1), disjSpVals};
       if (isa<arith::AddIOp>(def))
-        return {addExp(TensorExp::Kind::kAddI, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kAddI, e0, e1), disjSpVals};
       if (isa<arith::SubFOp>(def))
-        return {addExp(TensorExp::Kind::kSubF, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kSubF, e0, e1), disjSpVals};
       if (isa<complex::SubOp>(def))
-        return {addExp(TensorExp::Kind::kSubC, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kSubC, e0, e1), disjSpVals};
       if (isa<arith::SubIOp>(def))
-        return {addExp(TensorExp::Kind::kSubI, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kSubI, e0, e1), disjSpVals};
       if (isa<arith::AndIOp>(def))
-        return {addExp(TensorExp::Kind::kAndI, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kAndI, e0, e1), conjSpVals};
       if (isa<arith::OrIOp>(def))
-        return {addExp(TensorExp::Kind::kOrI, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kOrI, e0, e1), disjSpVals};
       if (isa<arith::XOrIOp>(def))
-        return {addExp(TensorExp::Kind::kXorI, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kXorI, e0, e1), disjSpVals};
       if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
-        return {addExp(TensorExp::Kind::kShrS, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kShrS, e0, e1), conjSpVals};
       if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
-        return {addExp(TensorExp::Kind::kShrU, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kShrU, e0, e1), conjSpVals};
       if (isa<arith::ShLIOp>(def) && isInvariant(e1))
-        return {addExp(TensorExp::Kind::kShlI, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kShlI, e0, e1), conjSpVals};
       if (auto ci = dyn_cast<arith::CmpIOp>(def)) {
         if (ci.getPredicate() == arith::CmpIPredicate::eq &&
             ci.getPredicate() == arith::CmpIPredicate::sle &&
@@ -1413,7 +1417,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
 
         auto e = addExp(TensorExp::Kind::kCmpI, e0, e1, nullptr,
                         ci.getPredicateAttr());
-        return {e, hasSpDep};
+        return {e, conjSpVals};
       }
       if (auto cf = dyn_cast<arith::CmpFOp>(def)) {
         if (cf.getPredicate() == arith::CmpFPredicate::OEQ &&
@@ -1431,15 +1435,15 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
         }
         auto e = addExp(TensorExp::Kind::kCmpF, e0, e1, nullptr,
                         cf.getPredicateAttr());
-        return {e, hasSpDep};
+        return {e, conjSpVals};
       }
       if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
         if (isAdmissibleBranch(binop, binop.getOverlapRegion()) &&
             (binop.getLeftIdentity() ||
              isAdmissibleBranch(binop, binop.getLeftRegion())) &&
             (binop.getRightIdentity() ||
              isAdmissibleBranch(binop, binop.getRightRegion())))
-          return {addExp(TensorExp::Kind::kBinary, e0, e1, def), hasSpDep};
+          return {addExp(TensorExp::Kind::kBinary, e0, e1, def), conjSpVals};
       }
     }
   }
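The new comment encodes a zero-preservation argument: a conjunctive operator (multiplication, division, bitwise and, shifts, comparisons) produces zero wherever any operand is zero, so sparsity survives if any operand yields sparse values; a disjunctive operator (addition, subtraction, or, xor) can be nonzero wherever either operand is nonzero, so sparsity survives only if all operands do. A self-contained sketch of that propagation rule (plain C++, not the MLIR API):

#include <cassert>

// The boolean tracks whether a subexpression still yields "sparse"
// (mostly zero) values, mirroring the flag in Merger::buildTensorExp.
bool conjSpVals(bool xSpVals, bool ySpVals) {
  // x * y == 0 whenever x == 0 or y == 0: ANY sparse operand suffices.
  return xSpVals || ySpVals;
}

bool disjSpVals(bool xSpVals, bool ySpVals) {
  // x + y may be nonzero wherever either operand is nonzero: ALL operands
  // must be sparse for the result to stay sparse.
  return xSpVals && ySpVals;
}

int main() {
  assert(conjSpVals(true, false));  // sparse * dense -> sparse
  assert(!disjSpVals(true, false)); // sparse + dense -> dense
  assert(disjSpVals(true, true));   // sparse + sparse -> sparse
  return 0;
}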

mlir/test/Dialect/SparseTensor/sparse_fusion.mlir

Lines changed: 50 additions & 19 deletions
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --linalg-fuse-elementwise-ops | FileCheck %s
+// RUN: mlir-opt %s --linalg-fuse-elementwise-ops --sparse-reinterpret-map --sparsification | FileCheck %s
 
 #SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
 
@@ -11,35 +11,66 @@
   doc = "B(i) = OP A(i)"
 }
 
-// CHECK-LABEL: func @sparse_fusion
-// CHECK: linalg.generic
-// CHECK: arith.addf
-// CHECK: linalg.generic
-// CHECK: math.exp
-// CHECK: arith.maximumf
-// CHECK-NOT: linalg.generic
-// CHECK: return
+
+// CHECK-LABEL:   func.func @sparse_fusion(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<100xf64, #sparse>) -> tensor<100xf64> {
+// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant true
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 100 : index
+// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 1.000000e+00 : f64
+// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 1.000000e+02 : f64
+// CHECK-DAG:       %[[VAL_8:.*]] = tensor.empty() : tensor<100xf64>
+// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<100xf64, #sparse> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<100xf64, #sparse> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<100xf64, #sparse> to memref<?xf64>
+// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_8]] : memref<100xf64>
+// CHECK:           linalg.fill ins(%[[VAL_4]] : f64) outs(%[[VAL_12]] : memref<100xf64>)
+// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK:           %[[VAL_15:.*]]:2 = scf.while (%[[VAL_16:.*]] = %[[VAL_13]], %[[VAL_17:.*]] = %[[VAL_3]]) : (index, index) -> (index, index) {
+// CHECK:             %[[VAL_18:.*]] = arith.cmpi ult, %[[VAL_16]], %[[VAL_14]] : index
+// CHECK:             scf.condition(%[[VAL_18]]) %[[VAL_16]], %[[VAL_17]] : index, index
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_19:.*]]: index, %[[VAL_20:.*]]: index):
+// CHECK:             %[[VAL_21:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<?xindex>
+// CHECK:             %[[VAL_22:.*]] = arith.cmpi eq, %[[VAL_21]], %[[VAL_20]] : index
+// CHECK:             scf.if %[[VAL_22]] {
+// CHECK:               %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<?xf64>
+// CHECK:               %[[VAL_24:.*]] = arith.addf %[[VAL_23]], %[[VAL_6]] : f64
+// CHECK:               %[[VAL_25:.*]] = math.exp %[[VAL_24]] : f64
+// CHECK:               %[[VAL_26:.*]] = arith.maximumf %[[VAL_25]], %[[VAL_7]] : f64
+// CHECK:               memref.store %[[VAL_26]], %[[VAL_12]]{{\[}}%[[VAL_20]]] : memref<100xf64>
+// CHECK:             } else {
+// CHECK:               scf.if %[[VAL_1]] {
+// CHECK:                 memref.store %[[VAL_7]], %[[VAL_12]]{{\[}}%[[VAL_20]]] : memref<100xf64>
+// CHECK:               } else {
+// CHECK:               }
+// CHECK:             }
+// CHECK:             %[[VAL_27:.*]] = arith.cmpi eq, %[[VAL_21]], %[[VAL_20]] : index
+// CHECK:             %[[VAL_28:.*]] = arith.addi %[[VAL_19]], %[[VAL_2]] : index
+// CHECK:             %[[VAL_29:.*]] = arith.select %[[VAL_27]], %[[VAL_28]], %[[VAL_19]] : index
+// CHECK:             %[[VAL_30:.*]] = arith.addi %[[VAL_20]], %[[VAL_2]] : index
+// CHECK:             scf.yield %[[VAL_29]], %[[VAL_30]] : index, index
+// CHECK:           }
+// CHECK:           scf.for %[[VAL_31:.*]] = %[[VAL_32:.*]]#1 to %[[VAL_5]] step %[[VAL_2]] {
+// CHECK:             memref.store %[[VAL_7]], %[[VAL_12]]{{\[}}%[[VAL_31]]] : memref<100xf64>
+// CHECK:           }
+// CHECK:           %[[VAL_33:.*]] = bufferization.to_tensor %[[VAL_12]] : memref<100xf64>
+// CHECK:           return %[[VAL_33]] : tensor<100xf64>
+// CHECK:         }
 func.func @sparse_fusion(%argA: tensor<100xf64, #SV>) -> tensor<100xf64> {
   %c1 = arith.constant 1.0 : f64
   %c100 = arith.constant 100.0 : f64
 
-  //
-  // Densifying op.
-  // Should not be fused with subsequent dense ops.
-  //
   %t0 = tensor.empty() : tensor<100xf64>
   %l0 = linalg.generic #trait
       ins(%argA: tensor<100xf64, #SV>) outs(%t0: tensor<100xf64>) {
     ^bb0(%in0: f64, %out0: f64):
       %b0 = arith.addf %in0, %c1 : f64
       linalg.yield %b0 : f64
   } -> tensor<100xf64>
-
-
-  //
-  // Two following dense ops.
-  // Should be fused, but not with above.
-  //
   %t1 = tensor.empty() : tensor<100xf64>
   %l1 = linalg.generic #trait
       ins(%l0: tensor<100xf64>) outs(%t1: tensor<100xf64>) {
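For reference, the fused kernel the updated test now checks end-to-end computes B(i) = max(exp(A(i) + 1.0), 100.0) over the compressed vector A: stored entries flow through the addf/exp/maximumf chain inside the scf.while loop, while unstored entries (implicit zeros) all evaluate to max(exp(0.0 + 1.0), 100.0) = 100.0, exactly the constant stored by the else branch and the trailing scf.for. A standalone C++ sketch of the same computation, with hypothetical names for the coordinate/value arrays of the #SV storage:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// out[i] = max(exp(a[i] + 1.0), 100.0) for a 1-D compressed vector `a` of
// length n, given as parallel coordinate/value arrays of stored entries.
std::vector<double> sparseFusion(const std::vector<int64_t> &coords,
                                 const std::vector<double> &values,
                                 int64_t n) {
  // Unstored entries are 0.0 and all map to max(exp(1.0), 100.0) == 100.0,
  // so start from a dense output filled with that constant.
  std::vector<double> out(n, 100.0);
  // Stored entries go through the fused add/exp/max chain.
  for (size_t k = 0; k < coords.size(); ++k)
    out[coords[k]] = std::max(std::exp(values[k] + 1.0), 100.0);
  return out;
}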
