@@ -5199,7 +5199,11 @@ struct CancelLinearizeOfDelinearizePortion final
5199
5199
return rewriter.notifyMatchFailure (
5200
5200
linearizeOp, " no run of delinearize outputs to deal with" );
5201
5201
5202
- SmallVector<std::tuple<Value, Value>> delinearizeReplacements;
5202
+ // Record all the delinearize replacements so we can do them after creating
5203
+ // the new linearization operation, since the new operation might use
5204
+ // outputs of something we're replacing.
5205
+ SmallVector<SmallVector<Value>> delinearizeReplacements;
5206
+
5203
5207
SmallVector<Value> newIndex;
5204
5208
newIndex.reserve (numLinArgs);
5205
5209
SmallVector<OpFoldResult> newBasis;
@@ -5212,18 +5216,26 @@ struct CancelLinearizeOfDelinearizePortion final
5212
5216
// Update here so we don't forget this during early continues
5213
5217
prevMatchEnd = m.linStart + m.length ;
5214
5218
5219
+ PatternRewriter::InsertionGuard g (rewriter);
5220
+ rewriter.setInsertionPoint (m.delinearize );
5221
+
5222
+ ArrayRef<OpFoldResult> basisToMerge =
5223
+ linBasisRef.slice (m.linStart , m.length );
5215
5224
// We use the slice from the linearize's basis above because of the
5216
5225
// "bounds inferred from `disjoint`" case above.
5217
5226
OpFoldResult newSize =
5218
- computeProduct (linearizeOp.getLoc (), rewriter,
5219
- linBasisRef.slice (m.linStart , m.length ));
5227
+ computeProduct (linearizeOp.getLoc (), rewriter, basisToMerge);
5220
5228
5221
5229
// Trivial case where we can just skip past the delinearize all together
5222
5230
if (m.length == m.delinearize .getNumResults ()) {
5223
5231
newIndex.push_back (m.delinearize .getLinearIndex ());
5224
5232
newBasis.push_back (newSize);
5233
+ // Pad out set of replacements so we don't do anything with this one.
5234
+ delinearizeReplacements.push_back (SmallVector<Value>());
5225
5235
continue ;
5226
5236
}
5237
+
5238
+ SmallVector<Value> newDelinResults;
5227
5239
SmallVector<OpFoldResult> newDelinBasis = m.delinearize .getPaddedBasis ();
5228
5240
newDelinBasis.erase (newDelinBasis.begin () + m.delinStart ,
5229
5241
newDelinBasis.begin () + m.delinStart + m.length );
@@ -5232,31 +5244,39 @@ struct CancelLinearizeOfDelinearizePortion final
5232
5244
m.delinearize .getLoc (), m.delinearize .getLinearIndex (),
5233
5245
newDelinBasis);
5234
5246
5247
+ // Since there may be other uses of the indices we just merged together,
5248
+ // create a residual affine.delinearize_index that delinearizes the
5249
+ // merged output into its component parts.
5250
+ Value combinedElem = newDelinearize.getResult (m.delinStart );
5251
+ auto residualDelinearize = rewriter.create <AffineDelinearizeIndexOp>(
5252
+ m.delinearize .getLoc (), combinedElem, basisToMerge);
5253
+
5235
5254
// Swap all the uses of the unaffected delinearize outputs to the new
5236
5255
// delinearization so that the old code can be removed if this
5237
5256
// linearize_index is the only user of the merged results.
5257
+ llvm::append_range (newDelinResults,
5258
+ newDelinearize.getResults ().take_front (m.delinStart ));
5259
+ llvm::append_range (newDelinResults, residualDelinearize.getResults ());
5238
5260
llvm::append_range (
5239
- delinearizeReplacements,
5240
- llvm::zip_equal (
5241
- m.delinearize .getResults ().take_front (m.delinStart ),
5242
- newDelinearize.getResults ().take_front (m.delinStart )));
5243
- llvm::append_range (
5244
- delinearizeReplacements,
5245
- llvm::zip_equal (
5246
- m.delinearize .getResults ().drop_front (m.delinStart + m.length ),
5247
- newDelinearize.getResults ().drop_front (m.delinStart + 1 )));
5261
+ newDelinResults,
5262
+ newDelinearize.getResults ().drop_front (m.delinStart + 1 ));
5248
5263
5249
- Value newLinArg = newDelinearize. getResult (m. delinStart );
5250
- newIndex.push_back (newLinArg );
5264
+ delinearizeReplacements. push_back (newDelinResults );
5265
+ newIndex.push_back (combinedElem );
5251
5266
newBasis.push_back (newSize);
5252
5267
}
5253
5268
llvm::append_range (newIndex, multiIndex.drop_front (prevMatchEnd));
5254
5269
llvm::append_range (newBasis, linBasisRef.drop_front (prevMatchEnd));
5255
5270
rewriter.replaceOpWithNewOp <AffineLinearizeIndexOp>(
5256
5271
linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint ());
5257
5272
5258
- for (auto [from, to] : delinearizeReplacements)
5259
- rewriter.replaceAllUsesWith (from, to);
5273
+ for (auto [m, newResults] :
5274
+ llvm::zip_equal (matches, delinearizeReplacements)) {
5275
+ if (newResults.empty ())
5276
+ continue ;
5277
+ rewriter.replaceOp (m.delinearize , newResults);
5278
+ }
5279
+
5260
5280
return success ();
5261
5281
}
5262
5282
};
0 commit comments