Skip to content

Commit 78f37b7

Browse files
author
MaheshRavishankar
committed
[mlir][Linalg] Miscellaneous enhancements to cover more fusion cases.
Adds support for - Dropping unit dimension loops for indexed_generic ops. - Folding consecutive collapsing (or expanding) reshapes when the result (or src) is a scalar. - Fixes to indexed_generic -> generic fusion when zero-dim tensors are involved. Differential Revision: https://reviews.llvm.org/D90118
1 parent 0b2f4cd commit 78f37b7

File tree

6 files changed

+252
-43
lines changed

6 files changed

+252
-43
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,10 @@ static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); }
461461
static ArrayAttr collapseReassociationMaps(ArrayRef<AffineMap> mapsProducer,
462462
ArrayRef<AffineMap> mapsConsumer,
463463
MLIRContext *context) {
464+
// Handle the corner case of the result being a rank 0 shaped type. Return an
465+
// empty ArrayAttr.
466+
if (mapsConsumer.empty() && !mapsProducer.empty())
467+
return ArrayAttr::get(ArrayRef<Attribute>(), context);
464468
if (mapsProducer.empty() || mapsConsumer.empty() ||
465469
mapsProducer[0].getNumDims() < mapsConsumer[0].getNumDims() ||
466470
mapsProducer.size() != mapsConsumer[0].getNumDims())
@@ -500,8 +504,7 @@ struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
500504
ShapedType intermediateType,
501505
ShapedType smallerType) -> bool {
502506
return largerType.getRank() > intermediateType.getRank() &&
503-
intermediateType.getRank() > smallerType.getRank() &&
504-
smallerType.getRank() > 0;
507+
intermediateType.getRank() > smallerType.getRank();
505508
};
506509
// Check if producer and consumer are both expanding dims.
507510
if (areReshapeOpsFoldable(reshapeOp.getResultType(), reshapeOp.getSrcType(),

mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp

Lines changed: 69 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
#include "llvm/Support/CommandLine.h"
2727
#include "llvm/Support/Debug.h"
2828

29+
#include <set>
30+
2931
#define DEBUG_TYPE "linalg-drop-unit-dims"
3032

3133
using namespace mlir;
@@ -145,15 +147,42 @@ static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
145147
context);
146148
}
147149

150+
/// Modify the region of indexed generic op to drop arguments corresponding to
151+
/// loops that are unit trip count.
152+
template <typename OpTy>
153+
static LogicalResult
154+
replaceBlockArgForUnitDimLoops(OpTy op, const DenseSet<unsigned> &unitDims,
155+
PatternRewriter &rewriterp) {
156+
return success();
157+
}
158+
159+
template <>
160+
LogicalResult replaceBlockArgForUnitDimLoops<IndexedGenericOp>(
161+
IndexedGenericOp op, const DenseSet<unsigned> &unitDims,
162+
PatternRewriter &rewriter) {
163+
OpBuilder::InsertionGuard guard(rewriter);
164+
Block *entryBlock = &op.getOperation()->getRegion(0).front();
165+
rewriter.setInsertionPointToStart(entryBlock);
166+
Value zero = rewriter.create<ConstantIndexOp>(op.getLoc(), 0);
167+
for (unsigned unitDimLoop : unitDims) {
168+
entryBlock->getArgument(unitDimLoop).replaceAllUsesWith(zero);
169+
}
170+
std::set<unsigned> orderedUnitDims(unitDims.begin(), unitDims.end());
171+
for (unsigned i : llvm::reverse(orderedUnitDims))
172+
entryBlock->eraseArgument(i);
173+
return success();
174+
}
175+
148176
namespace {
149177
/// Pattern to fold unit-trip count loops in GenericOps.
150178
// TODO: Generalize this to indexed-generic as well by modifying the region args
151179
// as well.
152-
struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
153-
using OpRewritePattern<GenericOp>::OpRewritePattern;
154-
LogicalResult matchAndRewrite(GenericOp genericOp,
180+
template <typename GenericOpTy>
181+
struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
182+
using OpRewritePattern<GenericOpTy>::OpRewritePattern;
183+
LogicalResult matchAndRewrite(GenericOpTy op,
155184
PatternRewriter &rewriter) const override {
156-
SmallVector<AffineMap, 4> indexingMaps = genericOp.getIndexingMaps();
185+
SmallVector<AffineMap, 4> indexingMaps = op.getIndexingMaps();
157186
if (indexingMaps.empty())
158187
return failure();
159188

@@ -164,10 +193,10 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
164193
if (!invertedMap)
165194
return failure();
166195
SmallVector<int64_t, 4> dims;
167-
for (ShapedType shapedType : genericOp.getInputOutputShapedTypes())
196+
for (ShapedType shapedType : op.getInputOutputShapedTypes())
168197
dims.append(shapedType.getShape().begin(), shapedType.getShape().end());
169198
DenseSet<unsigned> unitDims;
170-
ArrayAttr iteratorTypes = genericOp.iterator_types();
199+
ArrayAttr iteratorTypes = op.iterator_types();
171200
for (auto expr : enumerate(invertedMap.getResults())) {
172201
if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
173202
if (dims[dimExpr.getPosition()] == 1 &&
@@ -183,7 +212,7 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
183212
ArrayAttr newIndexingMapAttr =
184213
replaceUnitDims(unitDims, indexingMaps, context);
185214
if (!newIndexingMapAttr)
186-
return genericOp.emitError("unable to compute modified indexing_maps");
215+
return op.emitError("unable to compute modified indexing_maps");
187216

188217
// Compute the iterator types of the modified op by dropping the one-trip
189218
// count loops.
@@ -193,10 +222,11 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
193222
newIteratorTypes.push_back(attr.value());
194223
}
195224

196-
rewriter.startRootUpdate(genericOp);
197-
genericOp.indexing_mapsAttr(newIndexingMapAttr);
198-
genericOp.iterator_typesAttr(ArrayAttr::get(newIteratorTypes, context));
199-
rewriter.finalizeRootUpdate(genericOp);
225+
rewriter.startRootUpdate(op);
226+
op.indexing_mapsAttr(newIndexingMapAttr);
227+
op.iterator_typesAttr(ArrayAttr::get(newIteratorTypes, context));
228+
replaceBlockArgForUnitDimLoops(op, unitDims, rewriter);
229+
rewriter.finalizeRootUpdate(op);
200230
return success();
201231
}
202232
};
@@ -263,25 +293,27 @@ static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
263293
namespace {
264294

265295
/// Pattern to replace tensors operands/results that are unit extents.
266-
struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
267-
using OpRewritePattern<GenericOp>::OpRewritePattern;
268-
LogicalResult matchAndRewrite(GenericOp genericOp,
296+
template <typename GenericOpTy>
297+
struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOpTy> {
298+
using OpRewritePattern<GenericOpTy>::OpRewritePattern;
299+
LogicalResult matchAndRewrite(GenericOpTy op,
269300
PatternRewriter &rewriter) const override {
270301
// TODO: support init_tensors and reductions.
271-
if (!genericOp.hasTensorSemantics() || !genericOp.init_tensors().empty())
302+
if (!op.hasTensorSemantics() || !op.init_tensors().empty())
272303
return failure();
273304

274305
MLIRContext *context = rewriter.getContext();
275-
Location loc = genericOp.getLoc();
306+
Location loc = op.getLoc();
276307

277308
SmallVector<AffineMap, 4> newIndexingMaps;
278309
SmallVector<ArrayAttr, 4> reassociationMaps;
279310
SmallVector<ShapedType, 4> newInputOutputTypes;
280311
bool doCanonicalization = false;
281-
for (auto it : llvm::zip(genericOp.getIndexingMaps(),
282-
genericOp.getInputOutputShapedTypes())) {
312+
for (auto it :
313+
llvm::zip(op.getIndexingMaps(), op.getInputOutputShapedTypes())) {
283314
auto replacementInfo = replaceUnitExtents(
284-
std::get<0>(it), std::get<1>(it).cast<RankedTensorType>(), context);
315+
std::get<0>(it), std::get<1>(it).template cast<RankedTensorType>(),
316+
context);
285317
reassociationMaps.push_back(replacementInfo.reassociation);
286318
newIndexingMaps.push_back(replacementInfo.indexMap);
287319
newInputOutputTypes.push_back(replacementInfo.type);
@@ -313,41 +345,40 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
313345
return res;
314346
};
315347

316-
SmallVector<Value, 4> newInputs = insertReshapes(genericOp.inputs());
348+
SmallVector<Value, 4> newInputs = insertReshapes(op.inputs());
317349
SmallVector<Value, 4> newOutputBuffers =
318-
insertReshapes(genericOp.output_buffers());
319-
SmallVector<Value, 4> newInitTensors =
320-
insertReshapes(genericOp.init_tensors());
350+
insertReshapes(op.output_buffers());
351+
SmallVector<Value, 4> newInitTensors = insertReshapes(op.init_tensors());
321352

322353
// If any result type change, insert a reshape to convert from the original
323354
// type to the new type.
324355
SmallVector<Type, 4> resultTypes;
325-
resultTypes.reserve(genericOp.getNumResults());
326-
for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
327-
resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]);
328-
GenericOp replacementOp = rewriter.create<GenericOp>(
356+
resultTypes.reserve(op.getNumResults());
357+
for (unsigned i : llvm::seq<unsigned>(0, op.getNumResults()))
358+
resultTypes.push_back(newInputOutputTypes[i + op.getNumInputs()]);
359+
GenericOpTy replacementOp = rewriter.create<GenericOpTy>(
329360
loc, resultTypes, newInputs, newOutputBuffers, newInitTensors,
330361
newIndexingMaps,
331362
llvm::to_vector<4>(
332-
genericOp.iterator_types().getAsValueRange<StringAttr>()));
333-
rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(),
363+
op.iterator_types().template getAsValueRange<StringAttr>()));
364+
rewriter.inlineRegionBefore(op.region(), replacementOp.region(),
334365
replacementOp.region().begin());
335366

336367
// If any result tensor has a modified shape, then add reshape to recover
337368
// the original shape.
338369
SmallVector<Value, 4> resultReplacements;
339370
for (auto result : llvm::enumerate(replacementOp.getResults())) {
340371
unsigned index = result.index() + replacementOp.getNumOperands();
341-
RankedTensorType origResultType = genericOp.getResult(result.index())
372+
RankedTensorType origResultType = op.getResult(result.index())
342373
.getType()
343-
.cast<RankedTensorType>();
374+
.template cast<RankedTensorType>();
344375
if (origResultType != result.value().getType())
345376
resultReplacements.push_back(rewriter.create<linalg::TensorReshapeOp>(
346377
loc, origResultType, result.value(), reassociationMaps[index]));
347378
else
348379
resultReplacements.push_back(result.value());
349380
}
350-
rewriter.replaceOp(genericOp, resultReplacements);
381+
rewriter.replaceOp(op, resultReplacements);
351382
return success();
352383
}
353384
};
@@ -467,7 +498,10 @@ struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> {
467498
/// broadcasting.
468499
void mlir::populateLinalgFoldUnitExtentDimsPatterns(
469500
MLIRContext *context, OwningRewritePatternList &patterns) {
470-
patterns.insert<FoldUnitDimLoops, ReplaceUnitExtentTensors>(context);
501+
patterns
502+
.insert<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>,
503+
ReplaceUnitExtentTensors<GenericOp>,
504+
ReplaceUnitExtentTensors<IndexedGenericOp>>(context);
471505
TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
472506
patterns.insert<FoldReshapeOpWithUnitExtent>(context);
473507
}
@@ -481,7 +515,8 @@ struct LinalgFoldUnitExtentDimsPass
481515
FuncOp funcOp = getFunction();
482516
MLIRContext *context = funcOp.getContext();
483517
if (foldOneTripLoopsOnly)
484-
patterns.insert<FoldUnitDimLoops>(context);
518+
patterns.insert<FoldUnitDimLoops<GenericOp>,
519+
FoldUnitDimLoops<IndexedGenericOp>>(context);
485520
else
486521
populateLinalgFoldUnitExtentDimsPatterns(context, patterns);
487522
applyPatternsAndFoldGreedily(funcOp.getBody(), patterns);

mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -109,13 +109,19 @@ static void generateFusedTensorOpRegion(PatternRewriter &rewriter,
109109
// consumer's operand.
110110
// If both `numProducerIndices` and `numConsumerIndices` are zero, this is a
111111
// generic op. In this case, there are no indices in block arguments.
112-
unsigned numProducerIndices =
113-
isa<IndexedGenericOp>(producer.getOperation()) ? nloops : 0;
114-
unsigned numConsumerIndices =
115-
isa<IndexedGenericOp>(consumer.getOperation()) ? nloops : 0;
112+
unsigned numProducerIndices = isa<IndexedGenericOp>(producer.getOperation())
113+
? producer.getNumLoops()
114+
: 0;
115+
unsigned numConsumerIndices = isa<IndexedGenericOp>(consumer.getOperation())
116+
? consumer.getNumLoops()
117+
: 0;
118+
unsigned numFusedOpIndices =
119+
(isa<IndexedGenericOp>(producer.getOperation()) ||
120+
isa<IndexedGenericOp>(consumer.getOperation()))
121+
? std::max(producer.getNumLoops(), consumer.getNumLoops())
122+
: 0;
116123
// Firstly, add all the indices to the block arguments.
117-
for (unsigned i = 0, e = std::max(numProducerIndices, numConsumerIndices);
118-
i < e; ++i)
124+
for (unsigned i = 0, e = numFusedOpIndices; i < e; ++i)
119125
fusedBlock->addArgument(rewriter.getIndexType());
120126
// Map the arguments for the unmodified args from the consumer.
121127
for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) {
@@ -129,7 +135,7 @@ static void generateFusedTensorOpRegion(PatternRewriter &rewriter,
129135
auto newIndex = rewriter.create<mlir::AffineApplyOp>(
130136
producer.getLoc(),
131137
consumerToProducerLoopsMap.getSubMap(producerArg.index()),
132-
fusedBlock->getArguments().take_front(nloops));
138+
fusedBlock->getArguments().take_front(numFusedOpIndices));
133139
mapper.map(producerArg.value(), newIndex);
134140
} else {
135141
mapper.map(producerArg.value(),

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,34 @@ func @collapsing_tensor_reshapes(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x?xf3
4343

4444
// -----
4545

46+
// -----
47+
48+
func @collapsing_tensor_reshapes_to_zero_dim(%arg0 : tensor<1x1x1xf32>)
49+
-> tensor<f32> {
50+
%0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] :
51+
tensor<1x1x1xf32> into tensor<1xf32>
52+
%1 = linalg.tensor_reshape %0 [] : tensor<1xf32> into tensor<f32>
53+
return %1 : tensor<f32>
54+
}
55+
// CHECK-LABEL: collapsing_tensor_reshapes_to_zero
56+
// CHECK: linalg.tensor_reshape %{{.*}} []
57+
// CHECK-SAME: tensor<1x1x1xf32> into tensor<f32>
58+
59+
// -----
60+
61+
func @collapsing_memref_reshapes_to_zero_dim(%arg0 : memref<1x1x1xf32>)
62+
-> memref<f32> {
63+
%0 = linalg.reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] :
64+
memref<1x1x1xf32> into memref<1xf32>
65+
%1 = linalg.reshape %0 [] : memref<1xf32> into memref<f32>
66+
return %1 : memref<f32>
67+
}
68+
// CHECK-LABEL: collapsing_memref_reshapes_to_zero
69+
// CHECK: linalg.reshape %{{.*}} []
70+
// CHECK-SAME: memref<1x1x1xf32> into memref<f32>
71+
72+
// -----
73+
4674
func @expanding_tensor_reshapes(%arg0 : tensor<?x?xf32>) -> tensor<?x?x?x?x?xf32>
4775
{
4876
%0 = linalg.tensor_reshape %arg0
@@ -106,6 +134,33 @@ func @expanding_memref_reshapes(%arg0 : memref<?x?xf32>) -> memref<?x?x?x?x?xf32
106134

107135
// -----
108136

137+
func @expanding_tensor_reshapes_to_zero_dim(%arg0 : tensor<f32>)
138+
-> tensor<1x1x1xf32> {
139+
%0 = linalg.tensor_reshape %arg0 [] : tensor<f32> into tensor<1xf32>
140+
%1 = linalg.tensor_reshape %0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] :
141+
tensor<1xf32> into tensor<1x1x1xf32>
142+
return %1 : tensor<1x1x1xf32>
143+
}
144+
// CHECK-LABEL: expanding_tensor_reshapes_to_zero
145+
// CHECK: linalg.tensor_reshape %{{.*}} []
146+
// CHECK-SAME: tensor<f32> into tensor<1x1x1xf32>
147+
148+
// -----
149+
150+
func @expanding_memref_reshapes_to_zero_dim(%arg0 : memref<f32>)
151+
-> memref<1x1x1xf32> {
152+
%0 = linalg.reshape %arg0 [] : memref<f32> into memref<1xf32>
153+
%1 = linalg.reshape %0
154+
[affine_map<(d0, d1, d2) -> (d0, d1, d2)>] :
155+
memref<1xf32> into memref<1x1x1xf32>
156+
return %1 : memref<1x1x1xf32>
157+
}
158+
// CHECK-LABEL: expanding_memref_reshapes_to_zero
159+
// CHECK: linalg.reshape %{{.*}} []
160+
// CHECK-SAME: memref<f32> into memref<1x1x1xf32>
161+
162+
// -----
163+
109164
func @fold_tensor_reshape(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32>
110165
{
111166
%0 = linalg.tensor_reshape %arg0

0 commit comments

Comments
 (0)