@@ -45,15 +45,8 @@ static bool isSparseTensor(OpOperand *op) {
45
45
// Helper method to find zero or empty initialization.
46
46
static bool isEmptyInit (OpOperand *op) {
47
47
Value val = op->get ();
48
- if (matchPattern (val, m_Zero ()))
49
- return true ;
50
- if (matchPattern (val, m_AnyZeroFloat ()))
51
- return true ;
52
- if (val.getDefiningOp <InitTensorOp>())
53
- return true ;
54
- if (val.getDefiningOp <InitOp>())
55
- return true ;
56
- return false ;
48
+ return matchPattern (val, m_Zero ()) || matchPattern (val, m_AnyZeroFloat ()) ||
49
+ val.getDefiningOp <InitTensorOp>() || val.getDefiningOp <InitOp>();
57
50
}
58
51
59
52
// Helper to detect sampling operation.
@@ -123,11 +116,9 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
123
116
PatternRewriter &rewriter) const override {
124
117
// Check consumer.
125
118
if (!op.hasTensorSemantics () || op.getNumInputs () != 2 ||
126
- op.getNumResults () != 1 )
127
- return failure ();
128
- if (op.getNumParallelLoops () != op.getNumLoops ())
129
- return failure ();
130
- if (!op.getTiedIndexingMap (op.getOutputOperand (0 )).isIdentity () ||
119
+ op.getNumResults () != 1 ||
120
+ op.getNumParallelLoops () != op.getNumLoops () ||
121
+ !op.getTiedIndexingMap (op.getOutputOperand (0 )).isIdentity () ||
131
122
!op.getTiedIndexingMap (op.getInputOperand (0 )).isIdentity () ||
132
123
!op.getTiedIndexingMap (op.getInputOperand (1 )).isIdentity ())
133
124
return failure ();
@@ -143,15 +134,13 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
143
134
// Check producer.
144
135
auto prod = dyn_cast_or_null<GenericOp>(
145
136
op.getInputOperand (other)->get ().getDefiningOp ());
146
- if (!prod || !prod.hasTensorSemantics () || prod.getNumResults () != 1 )
147
- return failure ();
148
- if (!prod.getResult (0 ).hasOneUse ())
137
+ if (!prod || !prod.hasTensorSemantics () || prod.getNumResults () != 1 ||
138
+ !prod.getResult (0 ).hasOneUse ())
149
139
return failure ();
150
140
// Sampling consumer and sum of multiplication chain producer.
151
141
if (!isEmptyInit (op.getOutputOperand (0 )) ||
152
- !isEmptyInit (prod.getOutputOperand (0 )))
153
- return failure ();
154
- if (!isSampling (op) || !isSumOfMul (prod))
142
+ !isEmptyInit (prod.getOutputOperand (0 )) || !isSampling (op) ||
143
+ !isSumOfMul (prod))
155
144
return failure ();
156
145
// Modify operand structure of producer and consumer.
157
146
Location loc = prod.getLoc ();
0 commit comments