[mlir][sparse] make sparse compiler more admissible. #90927
Merged
Conversation

No description provided.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-sparse @llvm/pr-subscribers-mlir-linalg

Author: Peiming Liu (PeimingLiu)

Full diff: https://github.com/llvm/llvm-project/pull/90927.diff

3 Files Affected:
- mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
- mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
- mlir/test/Dialect/SparseTensor/sparse_fusion.mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 89fb4944c0ca3c..ad313c2d5ce603 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -432,12 +432,6 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
Operation *producer = opOperand.get().getDefiningOp();
- // Do not fuse a sparse-in/dense-out operation, as the
- // result is too often not sparsifiable anymore.
- if (sparse_tensor::hasAnySparseOperand(producer) &&
- !sparse_tensor::hasAnySparseResult(producer))
- return failure();
-
// Find the producer of the operand.
FailureOr<ElementwiseOpFusionResult> fusionResult =
fuseElementwiseOps(rewriter, &opOperand);
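
With this bailout removed, a sparse-in/dense-out producer may now be fused with trailing dense elementwise ops and handed to the sparsifier as one kernel. A minimal sketch of the newly admitted shape (illustrative only, not part of this patch; names and shapes are hypothetical, mirroring the updated test further below):

#SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
#trait = {
  indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
  iterator_types = ["parallel"]
}
func.func @densify_then_exp(%a: tensor<8xf64, #SV>) -> tensor<8xf64> {
  %c1 = arith.constant 1.0 : f64
  %t0 = tensor.empty() : tensor<8xf64>
  // Densifying producer: sparse operand, dense result. The deleted
  // guard used to stop fusion right here.
  %p = linalg.generic #trait
      ins(%a : tensor<8xf64, #SV>) outs(%t0 : tensor<8xf64>) {
  ^bb0(%in: f64, %out: f64):
    %0 = arith.addf %in, %c1 : f64
    linalg.yield %0 : f64
  } -> tensor<8xf64>
  %t1 = tensor.empty() : tensor<8xf64>
  // Dense consumer: --linalg-fuse-elementwise-ops may now merge the two
  // generics into a single kernel for --sparsification to process.
  %r = linalg.generic #trait
      ins(%p : tensor<8xf64>) outs(%t1 : tensor<8xf64>) {
  ^bb0(%in: f64, %out: f64):
    %0 = math.exp %in : f64
    linalg.yield %0 : f64
  } -> tensor<8xf64>
  return %r : tensor<8xf64>
}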
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 9c0aed3c18eff2..308fbd965259db 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -1356,50 +1356,54 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
// See buildLattices() for an explanation of rejecting certain
// division and shift operations.
if (def->getNumOperands() == 2) {
- const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
- const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
- bool hasSpDep = xDepSp || yDepSp;
+ const auto [x, xSpVals] = buildTensorExp(op, def->getOperand(0));
+ const auto [y, ySpVals] = buildTensorExp(op, def->getOperand(1));
+ // For a conjunctive operation, it yields a "sparse" result if any operand
+ // is sparse. For a disjunctive operation, it yields a "sparse" result if
+ // all operands are sparse.
+ bool conjSpVals = xSpVals || ySpVals;
+ bool disjSpVals = xSpVals && ySpVals;
if (x.has_value() && y.has_value()) {
const ExprId e0 = *x;
const ExprId e1 = *y;
if (isa<arith::MulFOp>(def))
- return {addExp(TensorExp::Kind::kMulF, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kMulF, e0, e1), conjSpVals};
if (isa<complex::MulOp>(def))
- return {addExp(TensorExp::Kind::kMulC, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kMulC, e0, e1), conjSpVals};
if (isa<arith::MulIOp>(def))
- return {addExp(TensorExp::Kind::kMulI, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kMulI, e0, e1), conjSpVals};
if (isa<arith::DivFOp>(def) && !maybeZero(e1))
- return {addExp(TensorExp::Kind::kDivF, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kDivF, e0, e1), conjSpVals};
if (isa<complex::DivOp>(def) && !maybeZero(e1))
- return {addExp(TensorExp::Kind::kDivC, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kDivC, e0, e1), conjSpVals};
if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
- return {addExp(TensorExp::Kind::kDivS, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kDivS, e0, e1), conjSpVals};
if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
- return {addExp(TensorExp::Kind::kDivU, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kDivU, e0, e1), conjSpVals};
if (isa<arith::AddFOp>(def))
- return {addExp(TensorExp::Kind::kAddF, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kAddF, e0, e1), disjSpVals};
if (isa<complex::AddOp>(def))
- return {addExp(TensorExp::Kind::kAddC, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kAddC, e0, e1), disjSpVals};
if (isa<arith::AddIOp>(def))
- return {addExp(TensorExp::Kind::kAddI, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kAddI, e0, e1), disjSpVals};
if (isa<arith::SubFOp>(def))
- return {addExp(TensorExp::Kind::kSubF, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kSubF, e0, e1), disjSpVals};
if (isa<complex::SubOp>(def))
- return {addExp(TensorExp::Kind::kSubC, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kSubC, e0, e1), disjSpVals};
if (isa<arith::SubIOp>(def))
- return {addExp(TensorExp::Kind::kSubI, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kSubI, e0, e1), disjSpVals};
if (isa<arith::AndIOp>(def))
- return {addExp(TensorExp::Kind::kAndI, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kAndI, e0, e1), conjSpVals};
if (isa<arith::OrIOp>(def))
- return {addExp(TensorExp::Kind::kOrI, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kOrI, e0, e1), disjSpVals};
if (isa<arith::XOrIOp>(def))
- return {addExp(TensorExp::Kind::kXorI, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kXorI, e0, e1), disjSpVals};
if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
- return {addExp(TensorExp::Kind::kShrS, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kShrS, e0, e1), conjSpVals};
if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
- return {addExp(TensorExp::Kind::kShrU, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kShrU, e0, e1), conjSpVals};
if (isa<arith::ShLIOp>(def) && isInvariant(e1))
- return {addExp(TensorExp::Kind::kShlI, e0, e1), hasSpDep};
+ return {addExp(TensorExp::Kind::kShlI, e0, e1), conjSpVals};
if (auto ci = dyn_cast<arith::CmpIOp>(def)) {
if (ci.getPredicate() == arith::CmpIPredicate::eq &&
ci.getPredicate() == arith::CmpIPredicate::sle &&
@@ -1413,7 +1417,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
auto e = addExp(TensorExp::Kind::kCmpI, e0, e1, nullptr,
ci.getPredicateAttr());
- return {e, hasSpDep};
+ return {e, conjSpVals};
}
if (auto cf = dyn_cast<arith::CmpFOp>(def)) {
if (cf.getPredicate() == arith::CmpFPredicate::OEQ &&
@@ -1431,7 +1435,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
}
auto e = addExp(TensorExp::Kind::kCmpF, e0, e1, nullptr,
cf.getPredicateAttr());
- return {e, hasSpDep};
+ return {e, conjSpVals};
}
if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
if (isAdmissibleBranch(binop, binop.getOverlapRegion()) &&
@@ -1439,7 +1443,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
isAdmissibleBranch(binop, binop.getLeftRegion())) &&
(binop.getRightIdentity() ||
isAdmissibleBranch(binop, binop.getRightRegion())))
- return {addExp(TensorExp::Kind::kBinary, e0, e1, def), hasSpDep};
+ return {addExp(TensorExp::Kind::kBinary, e0, e1, def), conjSpVals};
}
}
}
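
The effect of the conjunctive/disjunctive split above can be seen on a mixed sparse/dense kernel. A sketch (illustrative, not from the patch): multiplication keeps a sparse value stream when either operand is sparse, while addition stays dense unless all operands are sparse.

#SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
#bin = {
  indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>,
                   affine_map<(i) -> (i)>],
  iterator_types = ["parallel"]
}
// x(i) = a(i) * b(i): conjunctive. Zeros of the sparse %a annihilate
// the product, so the expression carries "sparse" values (conjSpVals)
// even though %b is dense.
func.func @conj(%a: tensor<8xf64, #SV>, %b: tensor<8xf64>) -> tensor<8xf64> {
  %t = tensor.empty() : tensor<8xf64>
  %x = linalg.generic #bin
      ins(%a, %b : tensor<8xf64, #SV>, tensor<8xf64>)
      outs(%t : tensor<8xf64>) {
  ^bb0(%va: f64, %vb: f64, %out: f64):
    %0 = arith.mulf %va, %vb : f64
    linalg.yield %0 : f64
  } -> tensor<8xf64>
  return %x : tensor<8xf64>
}
// Replacing arith.mulf with arith.addf makes the kernel disjunctive:
// the dense %b contributes at every index, so the result is only
// treated as sparse (disjSpVals) when all operands are sparse.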
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fusion.mlir b/mlir/test/Dialect/SparseTensor/sparse_fusion.mlir
index 8780baac199e16..2cc64434a1d8f2 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fusion.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fusion.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --linalg-fuse-elementwise-ops | FileCheck %s
+// RUN: mlir-opt %s --linalg-fuse-elementwise-ops --sparse-reinterpret-map --sparsification | FileCheck %s
#SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
@@ -11,22 +11,59 @@
doc = "B(i) = OP A(i)"
}
-// CHECK-LABEL: func @sparse_fusion
-// CHECK: linalg.generic
-// CHECK: arith.addf
-// CHECK: linalg.generic
-// CHECK: math.exp
-// CHECK: arith.maximumf
-// CHECK-NOT: linalg.generic
-// CHECK: return
+
+// CHECK-LABEL: func.func @sparse_fusion(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<100xf64, #sparse>) -> tensor<100xf64> {
+// CHECK-DAG: %[[VAL_1:.*]] = arith.constant true
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 100 : index
+// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1.000000e+00 : f64
+// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 1.000000e+02 : f64
+// CHECK-DAG: %[[VAL_8:.*]] = tensor.empty() : tensor<100xf64>
+// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<100xf64, #sparse> to memref<?xindex>
+// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<100xf64, #sparse> to memref<?xindex>
+// CHECK-DAG: %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<100xf64, #sparse> to memref<?xf64>
+// CHECK-DAG: %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_8]] : memref<100xf64>
+// CHECK: linalg.fill ins(%[[VAL_4]] : f64) outs(%[[VAL_12]] : memref<100xf64>)
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: %[[VAL_15:.*]]:2 = scf.while (%[[VAL_16:.*]] = %[[VAL_13]], %[[VAL_17:.*]] = %[[VAL_3]]) : (index, index) -> (index, index) {
+// CHECK: %[[VAL_18:.*]] = arith.cmpi ult, %[[VAL_16]], %[[VAL_14]] : index
+// CHECK: scf.condition(%[[VAL_18]]) %[[VAL_16]], %[[VAL_17]] : index, index
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_19:.*]]: index, %[[VAL_20:.*]]: index):
+// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<?xindex>
+// CHECK: %[[VAL_22:.*]] = arith.cmpi eq, %[[VAL_21]], %[[VAL_20]] : index
+// CHECK: scf.if %[[VAL_22]] {
+// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<?xf64>
+// CHECK: %[[VAL_24:.*]] = arith.addf %[[VAL_23]], %[[VAL_6]] : f64
+// CHECK: %[[VAL_25:.*]] = math.exp %[[VAL_24]] : f64
+// CHECK: %[[VAL_26:.*]] = arith.maximumf %[[VAL_25]], %[[VAL_7]] : f64
+// CHECK: memref.store %[[VAL_26]], %[[VAL_12]]{{\[}}%[[VAL_20]]] : memref<100xf64>
+// CHECK: } else {
+// CHECK: scf.if %[[VAL_1]] {
+// CHECK: memref.store %[[VAL_7]], %[[VAL_12]]{{\[}}%[[VAL_20]]] : memref<100xf64>
+// CHECK: } else {
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_27:.*]] = arith.cmpi eq, %[[VAL_21]], %[[VAL_20]] : index
+// CHECK: %[[VAL_28:.*]] = arith.addi %[[VAL_19]], %[[VAL_2]] : index
+// CHECK: %[[VAL_29:.*]] = arith.select %[[VAL_27]], %[[VAL_28]], %[[VAL_19]] : index
+// CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_20]], %[[VAL_2]] : index
+// CHECK: scf.yield %[[VAL_29]], %[[VAL_30]] : index, index
+// CHECK: }
+// CHECK: scf.for %[[VAL_31:.*]] = %[[VAL_32:.*]]#1 to %[[VAL_5]] step %[[VAL_2]] {
+// CHECK: memref.store %[[VAL_7]], %[[VAL_12]]{{\[}}%[[VAL_31]]] : memref<100xf64>
+// CHECK: }
+// CHECK: %[[VAL_33:.*]] = bufferization.to_tensor %[[VAL_12]] : memref<100xf64>
+// CHECK: return %[[VAL_33]] : tensor<100xf64>
+// CHECK: }
func.func @sparse_fusion(%argA: tensor<100xf64, #SV>) -> tensor<100xf64> {
%c1 = arith.constant 1.0 : f64
%c100 = arith.constant 100.0 : f64
- //
- // Densifying op.
- // Should not be fused with subsequent dense ops.
- //
%t0 = tensor.empty() : tensor<100xf64>
%l0 = linalg.generic #trait
ins(%argA: tensor<100xf64, #SV>) outs(%t0: tensor<100xf64>) {
@@ -34,12 +71,6 @@ func.func @sparse_fusion(%argA: tensor<100xf64, #SV>) -> tensor<100xf64> {
%b0 = arith.addf %in0, %c1 : f64
linalg.yield %b0 : f64
} -> tensor<100xf64>
-
-
- //
- // Two following dense ops.
- // Should be fused, but not with above.
- //
%t1 = tensor.empty() : tensor<100xf64>
%l1 = linalg.generic #trait
ins(%l0: tensor<100xf64>) outs(%t1: tensor<100xf64>) {
aartbik approved these changes on May 3, 2024.