@@ -5109,13 +5109,19 @@ struct CancelLinearizeOfDelinearizePortion final
5109
5109
: OpRewritePattern<affine::AffineLinearizeIndexOp> {
5110
5110
using OpRewritePattern::OpRewritePattern;
5111
5111
5112
+ private:
5113
+ // Struct representing a case where the cancellation pattern
5114
+ // applies. A `Match` means that `length` inputs to the linearize operation
5115
+ // starting at `linStart` can be cancelled with `length` outputs of
5116
+ // `delinearize`, starting from `delinStart`.
5112
5117
struct Match {
5113
5118
AffineDelinearizeIndexOp delinearize;
5114
5119
unsigned linStart = 0 ;
5115
5120
unsigned delinStart = 0 ;
5116
5121
unsigned length = 0 ;
5117
5122
};
5118
5123
5124
+ public:
5119
5125
LogicalResult matchAndRewrite (affine::AffineLinearizeIndexOp linearizeOp,
5120
5126
PatternRewriter &rewriter) const override {
5121
5127
SmallVector<Match> matches;
@@ -5128,7 +5134,7 @@ struct CancelLinearizeOfDelinearizePortion final
5128
5134
unsigned linArgIdx = 0 ;
5129
5135
// We only want to replace one run from the same delinearize op per
5130
5136
// pattern invocation lest we run into invalidation issues.
5131
- llvm::SmallPtrSet<Operation *, 2 > seen ;
5137
+ llvm::SmallPtrSet<Operation *, 2 > alreadyMatchedDelinearize ;
5132
5138
while (linArgIdx < numLinArgs) {
5133
5139
auto asResult = dyn_cast<OpResult>(multiIndex[linArgIdx]);
5134
5140
if (!asResult) {
@@ -5155,37 +5161,37 @@ struct CancelLinearizeOfDelinearizePortion final
5155
5161
// / - The delinearization doesn't specify a bound, but the linearization
5156
5162
// / is `disjoint`, which asserts that the bound on the linearization is
5157
5163
// / correct.
5158
- unsigned firstDelinArg = asResult.getResultNumber ();
5164
+ unsigned delinArgIdx = asResult.getResultNumber ();
5159
5165
SmallVector<OpFoldResult> delinBasis = delinearizeOp.getPaddedBasis ();
5160
- OpFoldResult firstDelinBound = delinBasis[firstDelinArg ];
5166
+ OpFoldResult firstDelinBound = delinBasis[delinArgIdx ];
5161
5167
OpFoldResult firstLinBound = linBasis[linArgIdx];
5162
5168
bool boundsMatch = firstDelinBound == firstLinBound;
5163
- bool bothAtFront = linArgIdx == 0 && firstDelinArg == 0 ;
5169
+ bool bothAtFront = linArgIdx == 0 && delinArgIdx == 0 ;
5164
5170
bool knownByDisjoint =
5165
- linearizeOp.getDisjoint () && firstDelinArg == 0 && !firstDelinBound;
5171
+ linearizeOp.getDisjoint () && delinArgIdx == 0 && !firstDelinBound;
5166
5172
if (!boundsMatch && !bothAtFront && !knownByDisjoint) {
5167
5173
linArgIdx++;
5168
5174
continue ;
5169
5175
}
5170
5176
5171
5177
unsigned j = 1 ;
5172
5178
unsigned numDelinOuts = delinearizeOp.getNumResults ();
5173
- for (; j + linArgIdx < numLinArgs && j + firstDelinArg < numDelinOuts;
5179
+ for (; j + linArgIdx < numLinArgs && j + delinArgIdx < numDelinOuts;
5174
5180
++j) {
5175
5181
if (multiIndex[linArgIdx + j] !=
5176
- delinearizeOp.getResult (firstDelinArg + j))
5182
+ delinearizeOp.getResult (delinArgIdx + j))
5177
5183
break ;
5178
- if (linBasis[linArgIdx + j] != delinBasis[firstDelinArg + j])
5184
+ if (linBasis[linArgIdx + j] != delinBasis[delinArgIdx + j])
5179
5185
break ;
5180
5186
}
5181
5187
// If there're multiple matches against the same delinearize_index,
5182
5188
// only rewrite the first one we find to prevent invalidations. The next
5183
- // ones will be taken caer of by subsequent pattern invocations.
5184
- if (j <= 1 || !seen .insert (delinearizeOp).second ) {
5189
+ // ones will be taken care of by subsequent pattern invocations.
5190
+ if (j <= 1 || !alreadyMatchedDelinearize .insert (delinearizeOp).second ) {
5185
5191
linArgIdx++;
5186
5192
continue ;
5187
5193
}
5188
- matches.push_back (Match{delinearizeOp, linArgIdx, firstDelinArg , j});
5194
+ matches.push_back (Match{delinearizeOp, linArgIdx, delinArgIdx , j});
5189
5195
linArgIdx += j;
5190
5196
}
5191
5197
0 commit comments