Skip to content

Commit 5c03c05

Browse files
committed
[mlir][sparse] enhance element-wise fusion heuristics
We prevent merging a sparse-in/dense-out with dense-in kernels because the result is usually not sparsifiable. Dense kernels and sparse kernels are still fused, obviously. Reviewed By: Peiming Differential Revision: https://reviews.llvm.org/D153077
1 parent 9167dd4 commit 5c03c05

File tree

3 files changed

+81
-5
lines changed

3 files changed

+81
-5
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -140,15 +140,23 @@ RankedTensorType getCOOFromTypeWithOrdering(RankedTensorType src,
140140

141141
RankedTensorType getCOOFromType(RankedTensorType src, bool ordered);
142142

143-
/// Returns true iff MLIR operand has any sparse operand or result.
144-
inline bool hasAnySparseOperandOrResult(Operation *op) {
145-
bool anySparseIn = llvm::any_of(op->getOperands().getTypes(), [](Type t) {
143+
/// Returns true iff MLIR operand has any sparse operand.
144+
inline bool hasAnySparseOperand(Operation *op) {
145+
return llvm::any_of(op->getOperands().getTypes(), [](Type t) {
146146
return getSparseTensorEncoding(t) != nullptr;
147147
});
148-
bool anySparseOut = llvm::any_of(op->getResults().getTypes(), [](Type t) {
148+
}
149+
150+
/// Returns true iff MLIR operand has any sparse result.
151+
inline bool hasAnySparseResult(Operation *op) {
152+
return llvm::any_of(op->getResults().getTypes(), [](Type t) {
149153
return getSparseTensorEncoding(t) != nullptr;
150154
});
151-
return anySparseIn || anySparseOut;
155+
}
156+
157+
/// Returns true iff MLIR operand has any sparse operand or result.
158+
inline bool hasAnySparseOperandOrResult(Operation *op) {
159+
return hasAnySparseOperand(op) || hasAnySparseResult(op);
152160
}
153161

154162
//

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,11 +422,20 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
422422
if (!controlFn(&opOperand))
423423
continue;
424424

425+
// Find the producer of the operand.
425426
FailureOr<ElementwiseOpFusionResult> fusionResult =
426427
fuseElementwiseOps(rewriter, &opOperand);
427428
if (failed(fusionResult))
428429
return rewriter.notifyMatchFailure(genericOp, "fusion failed");
429430
Operation *producer = opOperand.get().getDefiningOp();
431+
432+
// Do not fuse a sparse-in/dense-out operation, as the
433+
// result is too often not sparsifiable anymore.
434+
if (sparse_tensor::hasAnySparseOperand(producer) &&
435+
!sparse_tensor::hasAnySparseResult(producer))
436+
return failure();
437+
438+
// Perform the fusion.
430439
for (auto [origVal, replacement] : fusionResult->replacements) {
431440
rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
432441
// Only replace consumer uses.
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
// RUN: mlir-opt %s --linalg-fuse-elementwise-ops | FileCheck %s
2+
3+
#SV = #sparse_tensor.encoding<{ lvlTypes = ["compressed"] }>
4+
5+
#trait = {
6+
indexing_maps = [
7+
affine_map<(i) -> (i)>, // A
8+
affine_map<(i) -> (i)> // B (out)
9+
],
10+
iterator_types = ["parallel"],
11+
doc = "B(i) = OP A(i)"
12+
}
13+
14+
// CHECK-LABEL: func @sparse_fusion
15+
// CHECK: linalg.generic
16+
// CHECK: arith.addf
17+
// CHECK: linalg.generic
18+
// CHECK: math.exp
19+
// CHECK: arith.maxf
20+
// CHECK-NOT: linalg.generic
21+
// CHECK: return
22+
func.func @sparse_fusion(%argA: tensor<100xf64, #SV>) -> tensor<100xf64> {
23+
%c1 = arith.constant 1.0 : f64
24+
%c100 = arith.constant 100.0 : f64
25+
26+
//
27+
// Densifying op.
28+
// Should not be fused with subsequent dense ops.
29+
//
30+
%t0 = tensor.empty() : tensor<100xf64>
31+
%l0 = linalg.generic #trait
32+
ins(%argA: tensor<100xf64, #SV>) outs(%t0: tensor<100xf64>) {
33+
^bb0(%in0: f64, %out0: f64):
34+
%b0 = arith.addf %in0, %c1 : f64
35+
linalg.yield %b0 : f64
36+
} -> tensor<100xf64>
37+
38+
39+
//
40+
// Two following dense ops.
41+
// Should be fused, but not with above.
42+
//
43+
%t1 = tensor.empty() : tensor<100xf64>
44+
%l1 = linalg.generic #trait
45+
ins(%l0: tensor<100xf64>) outs(%t1: tensor<100xf64>) {
46+
^bb0(%in1: f64, %out1: f64):
47+
%b1 = math.exp %in1 : f64
48+
linalg.yield %b1 : f64
49+
} -> tensor<100xf64>
50+
%t2 = tensor.empty() : tensor<100xf64>
51+
%l2 = linalg.generic #trait
52+
ins(%l1: tensor<100xf64>) outs(%t2: tensor<100xf64>) {
53+
^bb0(%in2: f64, %out2: f64):
54+
%b2 = arith.maxf %in2, %c100 : f64
55+
linalg.yield %b2 : f64
56+
} -> tensor<100xf64>
57+
58+
return %l2 : tensor<100xf64>
59+
}

0 commit comments

Comments
 (0)