26
26
#include " llvm/Support/CommandLine.h"
27
27
#include " llvm/Support/Debug.h"
28
28
29
+ #include < set>
30
+
29
31
#define DEBUG_TYPE " linalg-drop-unit-dims"
30
32
31
33
using namespace mlir ;
@@ -145,15 +147,42 @@ static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
145
147
context);
146
148
}
147
149
150
+ // / Modify the region of indexed generic op to drop arguments corresponding to
151
+ // / loops that are unit trip count.
152
+ template <typename OpTy>
153
+ static LogicalResult
154
+ replaceBlockArgForUnitDimLoops (OpTy op, const DenseSet<unsigned > &unitDims,
155
+ PatternRewriter &rewriterp) {
156
+ return success ();
157
+ }
158
+
159
+ template <>
160
+ LogicalResult replaceBlockArgForUnitDimLoops<IndexedGenericOp>(
161
+ IndexedGenericOp op, const DenseSet<unsigned > &unitDims,
162
+ PatternRewriter &rewriter) {
163
+ OpBuilder::InsertionGuard guard (rewriter);
164
+ Block *entryBlock = &op.getOperation ()->getRegion (0 ).front ();
165
+ rewriter.setInsertionPointToStart (entryBlock);
166
+ Value zero = rewriter.create <ConstantIndexOp>(op.getLoc (), 0 );
167
+ for (unsigned unitDimLoop : unitDims) {
168
+ entryBlock->getArgument (unitDimLoop).replaceAllUsesWith (zero);
169
+ }
170
+ std::set<unsigned > orderedUnitDims (unitDims.begin (), unitDims.end ());
171
+ for (unsigned i : llvm::reverse (orderedUnitDims))
172
+ entryBlock->eraseArgument (i);
173
+ return success ();
174
+ }
175
+
148
176
namespace {
149
177
/// Pattern to fold unit-trip-count loops in GenericOps and IndexedGenericOps.
/// For indexed generics, the region's loop-index block arguments for the
/// dropped loops are rewritten via replaceBlockArgForUnitDimLoops.
152
- struct FoldUnitDimLoops : public OpRewritePattern <GenericOp> {
153
- using OpRewritePattern<GenericOp>::OpRewritePattern;
154
- LogicalResult matchAndRewrite (GenericOp genericOp,
180
+ template <typename GenericOpTy>
181
+ struct FoldUnitDimLoops : public OpRewritePattern <GenericOpTy> {
182
+ using OpRewritePattern<GenericOpTy>::OpRewritePattern;
183
+ LogicalResult matchAndRewrite (GenericOpTy op,
155
184
PatternRewriter &rewriter) const override {
156
- SmallVector<AffineMap, 4 > indexingMaps = genericOp .getIndexingMaps ();
185
+ SmallVector<AffineMap, 4 > indexingMaps = op .getIndexingMaps ();
157
186
if (indexingMaps.empty ())
158
187
return failure ();
159
188
@@ -164,10 +193,10 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
164
193
if (!invertedMap)
165
194
return failure ();
166
195
SmallVector<int64_t , 4 > dims;
167
- for (ShapedType shapedType : genericOp .getInputOutputShapedTypes ())
196
+ for (ShapedType shapedType : op .getInputOutputShapedTypes ())
168
197
dims.append (shapedType.getShape ().begin (), shapedType.getShape ().end ());
169
198
DenseSet<unsigned > unitDims;
170
- ArrayAttr iteratorTypes = genericOp .iterator_types ();
199
+ ArrayAttr iteratorTypes = op .iterator_types ();
171
200
for (auto expr : enumerate(invertedMap.getResults ())) {
172
201
if (AffineDimExpr dimExpr = expr.value ().dyn_cast <AffineDimExpr>())
173
202
if (dims[dimExpr.getPosition ()] == 1 &&
@@ -183,7 +212,7 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
183
212
ArrayAttr newIndexingMapAttr =
184
213
replaceUnitDims (unitDims, indexingMaps, context);
185
214
if (!newIndexingMapAttr)
186
- return genericOp .emitError (" unable to compute modified indexing_maps" );
215
+ return op .emitError (" unable to compute modified indexing_maps" );
187
216
188
217
// Compute the iterator types of the modified op by dropping the one-trip
189
218
// count loops.
@@ -193,10 +222,11 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
193
222
newIteratorTypes.push_back (attr.value ());
194
223
}
195
224
196
- rewriter.startRootUpdate (genericOp);
197
- genericOp.indexing_mapsAttr (newIndexingMapAttr);
198
- genericOp.iterator_typesAttr (ArrayAttr::get (newIteratorTypes, context));
199
- rewriter.finalizeRootUpdate (genericOp);
225
+ rewriter.startRootUpdate (op);
226
+ op.indexing_mapsAttr (newIndexingMapAttr);
227
+ op.iterator_typesAttr (ArrayAttr::get (newIteratorTypes, context));
228
+ replaceBlockArgForUnitDimLoops (op, unitDims, rewriter);
229
+ rewriter.finalizeRootUpdate (op);
200
230
return success ();
201
231
}
202
232
};
@@ -263,25 +293,27 @@ static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
263
293
namespace {
264
294
265
295
// / Pattern to replace tensors operands/results that are unit extents.
266
- struct ReplaceUnitExtentTensors : public OpRewritePattern <GenericOp> {
267
- using OpRewritePattern<GenericOp>::OpRewritePattern;
268
- LogicalResult matchAndRewrite (GenericOp genericOp,
296
+ template <typename GenericOpTy>
297
+ struct ReplaceUnitExtentTensors : public OpRewritePattern <GenericOpTy> {
298
+ using OpRewritePattern<GenericOpTy>::OpRewritePattern;
299
+ LogicalResult matchAndRewrite (GenericOpTy op,
269
300
PatternRewriter &rewriter) const override {
270
301
// TODO: support init_tensors and reductions.
271
- if (!genericOp .hasTensorSemantics () || !genericOp .init_tensors ().empty ())
302
+ if (!op .hasTensorSemantics () || !op .init_tensors ().empty ())
272
303
return failure ();
273
304
274
305
MLIRContext *context = rewriter.getContext ();
275
- Location loc = genericOp .getLoc ();
306
+ Location loc = op .getLoc ();
276
307
277
308
SmallVector<AffineMap, 4 > newIndexingMaps;
278
309
SmallVector<ArrayAttr, 4 > reassociationMaps;
279
310
SmallVector<ShapedType, 4 > newInputOutputTypes;
280
311
bool doCanonicalization = false ;
281
- for (auto it : llvm::zip (genericOp. getIndexingMaps (),
282
- genericOp .getInputOutputShapedTypes ())) {
312
+ for (auto it :
313
+ llvm::zip (op. getIndexingMaps (), op .getInputOutputShapedTypes ())) {
283
314
auto replacementInfo = replaceUnitExtents (
284
- std::get<0 >(it), std::get<1 >(it).cast <RankedTensorType>(), context);
315
+ std::get<0 >(it), std::get<1 >(it).template cast <RankedTensorType>(),
316
+ context);
285
317
reassociationMaps.push_back (replacementInfo.reassociation );
286
318
newIndexingMaps.push_back (replacementInfo.indexMap );
287
319
newInputOutputTypes.push_back (replacementInfo.type );
@@ -313,41 +345,40 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
313
345
return res;
314
346
};
315
347
316
- SmallVector<Value, 4 > newInputs = insertReshapes (genericOp .inputs ());
348
+ SmallVector<Value, 4 > newInputs = insertReshapes (op .inputs ());
317
349
SmallVector<Value, 4 > newOutputBuffers =
318
- insertReshapes (genericOp.output_buffers ());
319
- SmallVector<Value, 4 > newInitTensors =
320
- insertReshapes (genericOp.init_tensors ());
350
+ insertReshapes (op.output_buffers ());
351
+ SmallVector<Value, 4 > newInitTensors = insertReshapes (op.init_tensors ());
321
352
322
353
// If any result type change, insert a reshape to convert from the original
323
354
// type to the new type.
324
355
SmallVector<Type, 4 > resultTypes;
325
- resultTypes.reserve (genericOp .getNumResults ());
326
- for (unsigned i : llvm::seq<unsigned >(0 , genericOp .getNumResults ()))
327
- resultTypes.push_back (newInputOutputTypes[i + genericOp .getNumInputs ()]);
328
- GenericOp replacementOp = rewriter.create <GenericOp >(
356
+ resultTypes.reserve (op .getNumResults ());
357
+ for (unsigned i : llvm::seq<unsigned >(0 , op .getNumResults ()))
358
+ resultTypes.push_back (newInputOutputTypes[i + op .getNumInputs ()]);
359
+ GenericOpTy replacementOp = rewriter.create <GenericOpTy >(
329
360
loc, resultTypes, newInputs, newOutputBuffers, newInitTensors,
330
361
newIndexingMaps,
331
362
llvm::to_vector<4 >(
332
- genericOp .iterator_types ().getAsValueRange <StringAttr>()));
333
- rewriter.inlineRegionBefore (genericOp .region (), replacementOp.region (),
363
+ op .iterator_types ().template getAsValueRange <StringAttr>()));
364
+ rewriter.inlineRegionBefore (op .region (), replacementOp.region (),
334
365
replacementOp.region ().begin ());
335
366
336
367
// If any result tensor has a modified shape, then add reshape to recover
337
368
// the original shape.
338
369
SmallVector<Value, 4 > resultReplacements;
339
370
for (auto result : llvm::enumerate (replacementOp.getResults ())) {
340
371
unsigned index = result.index () + replacementOp.getNumOperands ();
341
- RankedTensorType origResultType = genericOp .getResult (result.index ())
372
+ RankedTensorType origResultType = op .getResult (result.index ())
342
373
.getType ()
343
- .cast <RankedTensorType>();
374
+ .template cast <RankedTensorType>();
344
375
if (origResultType != result.value ().getType ())
345
376
resultReplacements.push_back (rewriter.create <linalg::TensorReshapeOp>(
346
377
loc, origResultType, result.value (), reassociationMaps[index]));
347
378
else
348
379
resultReplacements.push_back (result.value ());
349
380
}
350
- rewriter.replaceOp (genericOp , resultReplacements);
381
+ rewriter.replaceOp (op , resultReplacements);
351
382
return success ();
352
383
}
353
384
};
@@ -467,7 +498,10 @@ struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> {
467
498
// / broadcasting.
468
499
void mlir::populateLinalgFoldUnitExtentDimsPatterns (
469
500
MLIRContext *context, OwningRewritePatternList &patterns) {
470
- patterns.insert <FoldUnitDimLoops, ReplaceUnitExtentTensors>(context);
501
+ patterns
502
+ .insert <FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>,
503
+ ReplaceUnitExtentTensors<GenericOp>,
504
+ ReplaceUnitExtentTensors<IndexedGenericOp>>(context);
471
505
TensorReshapeOp::getCanonicalizationPatterns (patterns, context);
472
506
patterns.insert <FoldReshapeOpWithUnitExtent>(context);
473
507
}
@@ -481,7 +515,8 @@ struct LinalgFoldUnitExtentDimsPass
481
515
FuncOp funcOp = getFunction ();
482
516
MLIRContext *context = funcOp.getContext ();
483
517
if (foldOneTripLoopsOnly)
484
- patterns.insert <FoldUnitDimLoops>(context);
518
+ patterns.insert <FoldUnitDimLoops<GenericOp>,
519
+ FoldUnitDimLoops<IndexedGenericOp>>(context);
485
520
else
486
521
populateLinalgFoldUnitExtentDimsPatterns (context, patterns);
487
522
applyPatternsAndFoldGreedily (funcOp.getBody (), patterns);
0 commit comments