
[mlir][sparse] make sparse compiler more admissible. #90927

Merged
1 commit merged into llvm:main on May 3, 2024

Conversation

PeimingLiu
Member

No description provided.

@llvmbot
Member

llvmbot commented May 3, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sparse

Author: Peiming Liu (PeimingLiu)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/90927.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (-6)
  • (modified) mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp (+29-25)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_fusion.mlir (+50-19)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 89fb4944c0ca3c..ad313c2d5ce603 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -432,12 +432,6 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
 
       Operation *producer = opOperand.get().getDefiningOp();
 
-      // Do not fuse a sparse-in/dense-out operation, as the
-      // result is too often not sparsifiable anymore.
-      if (sparse_tensor::hasAnySparseOperand(producer) &&
-          !sparse_tensor::hasAnySparseResult(producer))
-        return failure();
-
       // Find the producer of the operand.
       FailureOr<ElementwiseOpFusionResult> fusionResult =
           fuseElementwiseOps(rewriter, &opOperand);
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 9c0aed3c18eff2..308fbd965259db 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -1356,50 +1356,54 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
   // See buildLattices() for an explanation of rejecting certain
   // division and shift operations.
   if (def->getNumOperands() == 2) {
-    const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
-    const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
-    bool hasSpDep = xDepSp || yDepSp;
+    const auto [x, xSpVals] = buildTensorExp(op, def->getOperand(0));
+    const auto [y, ySpVals] = buildTensorExp(op, def->getOperand(1));
+    // For a conjunctive operation, it yields a "sparse" result if any operand
+    // is sparse. For a disjunctive operation, it yields a "sparse" result if
+    // all operands are sparse.
+    bool conjSpVals = xSpVals || ySpVals;
+    bool disjSpVals = xSpVals && ySpVals;
     if (x.has_value() && y.has_value()) {
       const ExprId e0 = *x;
       const ExprId e1 = *y;
       if (isa<arith::MulFOp>(def))
-        return {addExp(TensorExp::Kind::kMulF, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kMulF, e0, e1), conjSpVals};
       if (isa<complex::MulOp>(def))
-        return {addExp(TensorExp::Kind::kMulC, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kMulC, e0, e1), conjSpVals};
       if (isa<arith::MulIOp>(def))
-        return {addExp(TensorExp::Kind::kMulI, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kMulI, e0, e1), conjSpVals};
       if (isa<arith::DivFOp>(def) && !maybeZero(e1))
-        return {addExp(TensorExp::Kind::kDivF, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kDivF, e0, e1), conjSpVals};
       if (isa<complex::DivOp>(def) && !maybeZero(e1))
-        return {addExp(TensorExp::Kind::kDivC, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kDivC, e0, e1), conjSpVals};
       if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
-        return {addExp(TensorExp::Kind::kDivS, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kDivS, e0, e1), conjSpVals};
       if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
-        return {addExp(TensorExp::Kind::kDivU, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kDivU, e0, e1), conjSpVals};
       if (isa<arith::AddFOp>(def))
-        return {addExp(TensorExp::Kind::kAddF, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kAddF, e0, e1), disjSpVals};
       if (isa<complex::AddOp>(def))
-        return {addExp(TensorExp::Kind::kAddC, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kAddC, e0, e1), disjSpVals};
       if (isa<arith::AddIOp>(def))
-        return {addExp(TensorExp::Kind::kAddI, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kAddI, e0, e1), disjSpVals};
       if (isa<arith::SubFOp>(def))
-        return {addExp(TensorExp::Kind::kSubF, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kSubF, e0, e1), disjSpVals};
       if (isa<complex::SubOp>(def))
-        return {addExp(TensorExp::Kind::kSubC, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kSubC, e0, e1), disjSpVals};
       if (isa<arith::SubIOp>(def))
-        return {addExp(TensorExp::Kind::kSubI, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kSubI, e0, e1), disjSpVals};
       if (isa<arith::AndIOp>(def))
-        return {addExp(TensorExp::Kind::kAndI, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kAndI, e0, e1), conjSpVals};
       if (isa<arith::OrIOp>(def))
-        return {addExp(TensorExp::Kind::kOrI, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kOrI, e0, e1), disjSpVals};
       if (isa<arith::XOrIOp>(def))
-        return {addExp(TensorExp::Kind::kXorI, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kXorI, e0, e1), disjSpVals};
       if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
-        return {addExp(TensorExp::Kind::kShrS, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kShrS, e0, e1), conjSpVals};
       if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
-        return {addExp(TensorExp::Kind::kShrU, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kShrU, e0, e1), conjSpVals};
       if (isa<arith::ShLIOp>(def) && isInvariant(e1))
-        return {addExp(TensorExp::Kind::kShlI, e0, e1), hasSpDep};
+        return {addExp(TensorExp::Kind::kShlI, e0, e1), conjSpVals};
       if (auto ci = dyn_cast<arith::CmpIOp>(def)) {
         if (ci.getPredicate() == arith::CmpIPredicate::eq &&
             ci.getPredicate() == arith::CmpIPredicate::sle &&
@@ -1413,7 +1417,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
 
         auto e = addExp(TensorExp::Kind::kCmpI, e0, e1, nullptr,
                         ci.getPredicateAttr());
-        return {e, hasSpDep};
+        return {e, conjSpVals};
       }
       if (auto cf = dyn_cast<arith::CmpFOp>(def)) {
         if (cf.getPredicate() == arith::CmpFPredicate::OEQ &&
@@ -1431,7 +1435,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
         }
         auto e = addExp(TensorExp::Kind::kCmpF, e0, e1, nullptr,
                         cf.getPredicateAttr());
-        return {e, hasSpDep};
+        return {e, conjSpVals};
       }
       if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
         if (isAdmissibleBranch(binop, binop.getOverlapRegion()) &&
@@ -1439,7 +1443,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
              isAdmissibleBranch(binop, binop.getLeftRegion())) &&
             (binop.getRightIdentity() ||
              isAdmissibleBranch(binop, binop.getRightRegion())))
-          return {addExp(TensorExp::Kind::kBinary, e0, e1, def), hasSpDep};
+          return {addExp(TensorExp::Kind::kBinary, e0, e1, def), conjSpVals};
       }
     }
   }
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fusion.mlir b/mlir/test/Dialect/SparseTensor/sparse_fusion.mlir
index 8780baac199e16..2cc64434a1d8f2 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fusion.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fusion.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --linalg-fuse-elementwise-ops | FileCheck %s
+// RUN: mlir-opt %s --linalg-fuse-elementwise-ops --sparse-reinterpret-map --sparsification | FileCheck %s
 
 #SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
 
@@ -11,22 +11,59 @@
   doc = "B(i) = OP A(i)"
 }
 
-// CHECK-LABEL: func @sparse_fusion
-// CHECK:     linalg.generic
-// CHECK:       arith.addf
-// CHECK:     linalg.generic
-// CHECK:       math.exp
-// CHECK:       arith.maximumf
-// CHECK-NOT: linalg.generic
-// CHECK:     return
+
+// CHECK-LABEL:   func.func @sparse_fusion(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<100xf64, #sparse>) -> tensor<100xf64> {
+// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant true
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 100 : index
+// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 1.000000e+00 : f64
+// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 1.000000e+02 : f64
+// CHECK-DAG:       %[[VAL_8:.*]] = tensor.empty() : tensor<100xf64>
+// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<100xf64, #sparse> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<100xf64, #sparse> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<100xf64, #sparse> to memref<?xf64>
+// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_8]] : memref<100xf64>
+// CHECK:           linalg.fill ins(%[[VAL_4]] : f64) outs(%[[VAL_12]] : memref<100xf64>)
+// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK:           %[[VAL_15:.*]]:2 = scf.while (%[[VAL_16:.*]] = %[[VAL_13]], %[[VAL_17:.*]] = %[[VAL_3]]) : (index, index) -> (index, index) {
+// CHECK:             %[[VAL_18:.*]] = arith.cmpi ult, %[[VAL_16]], %[[VAL_14]] : index
+// CHECK:             scf.condition(%[[VAL_18]]) %[[VAL_16]], %[[VAL_17]] : index, index
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_19:.*]]: index, %[[VAL_20:.*]]: index):
+// CHECK:             %[[VAL_21:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<?xindex>
+// CHECK:             %[[VAL_22:.*]] = arith.cmpi eq, %[[VAL_21]], %[[VAL_20]] : index
+// CHECK:             scf.if %[[VAL_22]] {
+// CHECK:               %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<?xf64>
+// CHECK:               %[[VAL_24:.*]] = arith.addf %[[VAL_23]], %[[VAL_6]] : f64
+// CHECK:               %[[VAL_25:.*]] = math.exp %[[VAL_24]] : f64
+// CHECK:               %[[VAL_26:.*]] = arith.maximumf %[[VAL_25]], %[[VAL_7]] : f64
+// CHECK:               memref.store %[[VAL_26]], %[[VAL_12]]{{\[}}%[[VAL_20]]] : memref<100xf64>
+// CHECK:             } else {
+// CHECK:               scf.if %[[VAL_1]] {
+// CHECK:                 memref.store %[[VAL_7]], %[[VAL_12]]{{\[}}%[[VAL_20]]] : memref<100xf64>
+// CHECK:               } else {
+// CHECK:               }
+// CHECK:             }
+// CHECK:             %[[VAL_27:.*]] = arith.cmpi eq, %[[VAL_21]], %[[VAL_20]] : index
+// CHECK:             %[[VAL_28:.*]] = arith.addi %[[VAL_19]], %[[VAL_2]] : index
+// CHECK:             %[[VAL_29:.*]] = arith.select %[[VAL_27]], %[[VAL_28]], %[[VAL_19]] : index
+// CHECK:             %[[VAL_30:.*]] = arith.addi %[[VAL_20]], %[[VAL_2]] : index
+// CHECK:             scf.yield %[[VAL_29]], %[[VAL_30]] : index, index
+// CHECK:           }
+// CHECK:           scf.for %[[VAL_31:.*]] = %[[VAL_32:.*]]#1 to %[[VAL_5]] step %[[VAL_2]] {
+// CHECK:             memref.store %[[VAL_7]], %[[VAL_12]]{{\[}}%[[VAL_31]]] : memref<100xf64>
+// CHECK:           }
+// CHECK:           %[[VAL_33:.*]] = bufferization.to_tensor %[[VAL_12]] : memref<100xf64>
+// CHECK:           return %[[VAL_33]] : tensor<100xf64>
+// CHECK:         }
 func.func @sparse_fusion(%argA: tensor<100xf64, #SV>) -> tensor<100xf64> {
   %c1 = arith.constant 1.0 : f64
   %c100 = arith.constant 100.0 : f64
 
-  //
-  // Densifying op.
-  // Should not be fused with subsequent dense ops.
-  //
   %t0 = tensor.empty() : tensor<100xf64>
   %l0 = linalg.generic #trait
       ins(%argA: tensor<100xf64, #SV>) outs(%t0: tensor<100xf64>) {
@@ -34,12 +71,6 @@ func.func @sparse_fusion(%argA: tensor<100xf64, #SV>) -> tensor<100xf64> {
       %b0 = arith.addf %in0, %c1 : f64
       linalg.yield %b0 : f64
   } -> tensor<100xf64>
-
-
-  //
-  // Two following dense ops.
-  // Should be fused, but not with above.
-  //
   %t1 = tensor.empty() : tensor<100xf64>
   %l1 = linalg.generic #trait
       ins(%l0: tensor<100xf64>) outs(%t1: tensor<100xf64>) {
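
For readers skimming the Merger.cpp hunk above: the boolean threaded through buildTensorExp no longer records only "has any sparse dependence". It now distinguishes conjunctive operations (multiplication, bitwise and, shifts, compares, admissible sparse_tensor.binary), whose result is treated as sparse if any operand is sparse, from disjunctive operations (addition, subtraction, or, xor), whose result is treated as sparse only if all operands are sparse. A minimal standalone C++ sketch of that rule follows; the OpKind enum and helper names are hypothetical and not the actual Merger API.

// Hypothetical sketch of the conjunctive/disjunctive sparsity rule above;
// none of these names exist in MLIR.
#include <cassert>

enum class OpKind { Mul, And, Shl, Add, Sub, Or, Xor };

// Conjunctive ops annihilate on zero (e.g. x * 0 == 0), so the result is
// zero wherever either operand is zero: one sparse operand suffices.
static bool isConjunctive(OpKind k) {
  return k == OpKind::Mul || k == OpKind::And || k == OpKind::Shl;
}

// Whether the result of `lhs op rhs` is still derived from sparse values.
static bool isSparseResult(OpKind k, bool lhsSparse, bool rhsSparse) {
  return isConjunctive(k) ? (lhsSparse || rhsSparse)   // any operand sparse
                          : (lhsSparse && rhsSparse);  // all operands sparse
}

int main() {
  // sparse * dense stays sparse; sparse + dense densifies.
  assert(isSparseResult(OpKind::Mul, /*lhsSparse=*/true, /*rhsSparse=*/false));
  assert(!isSparseResult(OpKind::Add, /*lhsSparse=*/true, /*rhsSparse=*/false));
  assert(isSparseResult(OpKind::Add, /*lhsSparse=*/true, /*rhsSparse=*/true));
  return 0;
}

This refinement is also why the fusion bailout removed from ElementwiseOpFusion.cpp is no longer needed: the updated sparse_fusion.mlir test above shows the fused sparse-in/dense-out kernel still sparsifying, into an scf.while loop over the compressed level plus a tail scf.for for the remaining dense iterations.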

@llvmbot
Member

llvmbot commented May 3, 2024

@llvm/pr-subscribers-mlir-linalg

PeimingLiu merged commit fc83eda into llvm:main on May 3, 2024
PeimingLiu deleted the handle-pad branch on May 3, 2024 at 01:53
Labels
mlir, mlir:linalg, mlir:sparse (Sparse compiler in MLIR)
3 participants