Skip to content

Commit 8e4f8d3

Browse files
committed
[mlir][sparse] merge ifs in new sparse rewriting rules
Reviewed By: bixia Differential Revision: https://reviews.llvm.org/D120500
1 parent 180c9f9 commit 8e4f8d3

File tree

1 file changed

+9
-20
lines changed

1 file changed

+9
-20
lines changed

mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,8 @@ static bool isSparseTensor(OpOperand *op) {
4545
// Helper method to find zero or empty initialization.
4646
static bool isEmptyInit(OpOperand *op) {
4747
Value val = op->get();
48-
if (matchPattern(val, m_Zero()))
49-
return true;
50-
if (matchPattern(val, m_AnyZeroFloat()))
51-
return true;
52-
if (val.getDefiningOp<InitTensorOp>())
53-
return true;
54-
if (val.getDefiningOp<InitOp>())
55-
return true;
56-
return false;
48+
return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat()) ||
49+
val.getDefiningOp<InitTensorOp>() || val.getDefiningOp<InitOp>();
5750
}
5851

5952
// Helper to detect sampling operation.
@@ -123,11 +116,9 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
123116
PatternRewriter &rewriter) const override {
124117
// Check consumer.
125118
if (!op.hasTensorSemantics() || op.getNumInputs() != 2 ||
126-
op.getNumResults() != 1)
127-
return failure();
128-
if (op.getNumParallelLoops() != op.getNumLoops())
129-
return failure();
130-
if (!op.getTiedIndexingMap(op.getOutputOperand(0)).isIdentity() ||
119+
op.getNumResults() != 1 ||
120+
op.getNumParallelLoops() != op.getNumLoops() ||
121+
!op.getTiedIndexingMap(op.getOutputOperand(0)).isIdentity() ||
131122
!op.getTiedIndexingMap(op.getInputOperand(0)).isIdentity() ||
132123
!op.getTiedIndexingMap(op.getInputOperand(1)).isIdentity())
133124
return failure();
@@ -143,15 +134,13 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
143134
// Check producer.
144135
auto prod = dyn_cast_or_null<GenericOp>(
145136
op.getInputOperand(other)->get().getDefiningOp());
146-
if (!prod || !prod.hasTensorSemantics() || prod.getNumResults() != 1)
147-
return failure();
148-
if (!prod.getResult(0).hasOneUse())
137+
if (!prod || !prod.hasTensorSemantics() || prod.getNumResults() != 1 ||
138+
!prod.getResult(0).hasOneUse())
149139
return failure();
150140
// Sampling consumer and sum of multiplication chain producer.
151141
if (!isEmptyInit(op.getOutputOperand(0)) ||
152-
!isEmptyInit(prod.getOutputOperand(0)))
153-
return failure();
154-
if (!isSampling(op) || !isSumOfMul(prod))
142+
!isEmptyInit(prod.getOutputOperand(0)) || !isSampling(op) ||
143+
!isSumOfMul(prod))
155144
return failure();
156145
// Modify operand structure of producer and consumer.
157146
Location loc = prod.getLoc();

0 commit comments

Comments
 (0)