Commit c303d9b
[mlir][Linalg] NFC - Cleanup explicitly instantiated patterns 2/n - Loops.cpp
This revision belongs to a series of patches that reduce the reliance of Linalg transformations on templated rewrite and conversion patterns. Instead, it uses a single MatchAnyOpTypeTag pattern for the vast majority of cases and dispatches on the concrete op type internally.

Differential Revision: https://reviews.llvm.org/D89133
1 parent e0dc3db commit c303d9b
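At its core, the patch swaps per-op pattern instantiations for one type-erased pattern that matches any operation, filters to LinalgOp, and dispatches on the concrete op type inside the loop-emission callback. A minimal sketch of that idiom, condensed from the diff below (the MLIR includes and the linalgOpToLoopsImpl helper shown in the diff are assumed to be in scope):

// One pattern registered with MatchAnyOpTypeTag replaces the per-op template
// instantiations: the driver tries it on every operation, a cheap isa<>
// filter rejects non-Linalg ops, and llvm::TypeSwitch dispatches on the
// concrete type where op-specific behavior is needed.
#include "llvm/ADT/TypeSwitch.h"

template <typename LoopType>
class LinalgRewritePattern : public RewritePattern {
public:
  // No root op name is given, so the pattern is tried against all ops.
  LinalgRewritePattern() : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!isa<LinalgOp>(op)) // Cheap filter: bail out on non-Linalg ops.
      return failure();
    if (!linalgOpToLoopsImpl<LoopType>(op, rewriter)) // Internal dispatch.
      return failure();
    rewriter.eraseOp(op);
    return success();
  }
};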

File tree

1 file changed (+62, -118)

mlir/lib/Dialect/Linalg/Transforms/Loops.cpp

Lines changed: 62 additions & 118 deletions
@@ -23,6 +23,8 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/FoldUtils.h"
 
+#include "llvm/ADT/TypeSwitch.h"
+
 using namespace mlir;
 using namespace mlir::edsc;
 using namespace mlir::edsc::intrinsics;
@@ -65,7 +67,7 @@ static void inlineRegionAndEmitStore(OpType op, ArrayRef<Value> indexedValues,
   assert(op.getOperation()->getNumRegions() == 1 &&
          "Expected single region op");
   auto &b = ScopedContext::getBuilderRef();
-  auto &block = op.region().front();
+  auto &block = op.getOperation()->getRegion(0).front();
   BlockAndValueMapping map;
   map.map(block.getArguments(), indexedValues);
   for (auto &op : block.without_terminator()) {
@@ -102,8 +104,6 @@ static InputAndOutputIndices getInputAndOutputIndices(ArrayRef<Value> allIvs,
       makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
 }
 
-namespace {
-
 /// Emits the MLIR for the scalar part of the generic op by:
 ///   1. Emitting load ops for each input and output view in order. This is
 ///      achieved by applying the appropriate input or output map to the
@@ -134,10 +134,9 @@ namespace {
 ///    }
 ///  }
 /// ```
-// TODO: need a LinalgStructuredOpInterface.
-template <typename IndexedValueType, typename LinalgStructuredOpType>
-void emitScalarImplementation(ArrayRef<Value> allIvs,
-                              LinalgStructuredOpType linalgOp) {
+template <typename IndexedValueType>
+static void emitScalarImplementation(ArrayRef<Value> allIvs,
+                                     LinalgOp linalgOp) {
   assert(linalgOp.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
   auto &b = ScopedContext::getBuilderRef();
@@ -150,7 +149,7 @@ void emitScalarImplementation(ArrayRef<Value> allIvs,
   auto attr = linalgOp.template getAttrOfType<IntegerAttr>("symbol_source");
   auto allIvsPlusDims = SmallVector<Value, 4>(allIvs.begin(), allIvs.end());
   if (attr) {
-    auto operand = linalgOp.getOperand(attr.getInt());
+    auto operand = linalgOp.getOperation()->getOperand(attr.getInt());
     auto shapedType = operand.getType().template cast<ShapedType>();
     allIvsPlusDims.reserve(allIvs.size() + shapedType.getRank());
     for (unsigned idx = 0, e = shapedType.getRank(); idx < e; ++idx)
@@ -190,7 +189,7 @@ void emitScalarImplementation(ArrayRef<Value> allIvs,
 }
 
 template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
+static void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
   assert(copyOp.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
   auto nPar = copyOp.getNumParallelLoops();
@@ -211,7 +210,7 @@ void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
 }
 
 template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
+static void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
   assert(fillOp.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
   auto nPar = fillOp.getNumParallelLoops();
@@ -224,8 +223,8 @@ void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
 }
 
 template <typename IndexedValueType>
-Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
-                     MutableArrayRef<Value> imIdx) {
+static Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
+                            MutableArrayRef<Value> imIdx) {
   // TODO: add a level of indirection to linalg.generic.
   if (!convOp.padding())
     return im(imIdx);
@@ -311,39 +310,44 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) {
   }
 }
 
-template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) {
+template <typename IndexedValueType, typename OpType>
+static void emitPoolingMinMaxScalarImplementation(ArrayRef<Value> allIvs,
+                                                  OpType op) {
   InputAndOutputIndices indices = getInputAndOutputIndices(allIvs, op);
   // Emit scalar form.
   IndexedValueType output(op.output());
   IndexedValueType input(op.input());
   Value lhs = output(indices.outputs);
   Value rhs = input(indices.inputs);
   using edsc::op::sgt;
-  Value maxValue = std_select(sgt(lhs, rhs), lhs, rhs);
-  output(indices.outputs) = maxValue;
+  using edsc::op::slt;
+  Value value = std::is_same<OpType, PoolingMinOp>()
+                    ? std_select(slt(lhs, rhs), lhs, rhs)
+                    : std_select(sgt(lhs, rhs), lhs, rhs);
+  output(indices.outputs) = value;
 }
 
 template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMinOp op) {
-  InputAndOutputIndices indices = getInputAndOutputIndices(allIvs, op);
-  // Emit scalar form.
-  IndexedValueType output(op.output());
-  IndexedValueType input(op.input());
-  Value lhs = output(indices.outputs);
-  Value rhs = input(indices.inputs);
-  using edsc::op::slt;
-  Value minValue = std_select(slt(lhs, rhs), lhs, rhs);
-  output(indices.outputs) = minValue;
+static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) {
+  emitPoolingMinMaxScalarImplementation<IndexedValueType, PoolingMaxOp>(allIvs,
+                                                                        op);
 }
+
 template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) {
+static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMinOp op) {
+  emitPoolingMinMaxScalarImplementation<IndexedValueType, PoolingMinOp>(allIvs,
+                                                                        op);
+}
+
+template <typename IndexedValueType>
+static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) {
   auto indices = getInputAndOutputIndices(allIvs, op);
   IndexedValueType input(op.input()), output(op.output());
 
   // Emit scalar form.
   output(indices.outputs) += input(indices.inputs);
 }
+
 /// Emits the MLIR for the scalar part of the indexed generic op by:
 ///   1. Emitting load ops for each input and output view in order. This is
 ///      achieved by applying the appropriate input or output map to the
@@ -422,15 +426,16 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs,
                            indexing, outputBuffers);
 }
 
-template <typename LoopTy, typename ConcreteOpTy>
-Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
+template <typename LoopTy>
+static Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op,
+                                                 OpBuilder &builder) {
   using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;
 
   ScopedContext scope(builder, op->getLoc());
 
   // The flattened loopToOperandRangesMaps is expected to be an invertible
   // permutation map (which is asserted in the inverse calculation).
-  auto linalgOp = cast<ConcreteOpTy>(op);
+  auto linalgOp = cast<LinalgOp>(op);
   assert(linalgOp.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
   auto mapsRange =
@@ -447,7 +452,12 @@ Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
       [&](ValueRange ivs, ValueRange iterArgs) -> scf::ValueVector {
        assert(iterArgs.empty() && "unexpected iterArgs");
        allIvs.append(ivs.begin(), ivs.end());
-       emitScalarImplementation<IndexedValueTy>(allIvs, linalgOp);
+       llvm::TypeSwitch<Operation *>(op)
+           .Case<CopyOp, FillOp, ConvOp, PoolingMaxOp, PoolingMinOp,
+                 PoolingSumOp, IndexedGenericOp, LinalgOp>([&](auto op) {
+             emitScalarImplementation<IndexedValueTy>(allIvs, op);
+           })
+           .Default([&](Operation *op) { assert(false && "unexpected op"); });
        return scf::ValueVector{};
       });
   // Number of loop ops might be different from the number of ivs since some
@@ -467,32 +477,38 @@ Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
   return loops;
 }
 
-template <typename LoopType, typename ConcreteOp>
+namespace {
+template <typename LoopType>
 class LinalgRewritePattern : public RewritePattern {
 public:
-  explicit LinalgRewritePattern(MLIRContext *context)
-      : RewritePattern(ConcreteOp::getOperationName(), 1, context) {}
+  LinalgRewritePattern() : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
 
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    if (!linalgOpToLoopsImpl<LoopType, ConcreteOp>(op, rewriter))
+    if (!isa<LinalgOp>(op))
+      return failure();
+    if (!linalgOpToLoopsImpl<LoopType>(op, rewriter))
      return failure();
    rewriter.eraseOp(op);
    return success();
  }
 };
 
-template <typename LoopType, typename ConcreteOp>
-void insertOnePattern(OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  patterns.insert<LinalgRewritePattern<LoopType, ConcreteOp>>(ctx);
-}
+struct FoldAffineOp;
+} // namespace
 
-template <typename LoopType, typename... Args>
-void insertPatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  (void)std::initializer_list<int>{
-      0, (insertOnePattern<LoopType, Args>(patterns, ctx), 0)...};
+template <typename LoopType>
+static void lowerLinalgToLoopsImpl(FuncOp funcOp, MLIRContext *context) {
+  OwningRewritePatternList patterns;
+  patterns.insert<LinalgRewritePattern<LoopType>>();
+  DimOp::getCanonicalizationPatterns(patterns, context);
+  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
+  patterns.insert<FoldAffineOp>(context);
+  // Just apply the patterns greedily.
+  applyPatternsAndFoldGreedily(funcOp, patterns);
 }
 
+namespace {
 /// Local folding pattern for AffineApplyOp that we can apply greedily.
 /// This replaces AffineApplyOp by the proper value in cases where the
 /// associated map is trivial.
@@ -529,38 +545,20 @@ struct FoldAffineOp : public RewritePattern {
     return failure();
   }
 };
-} // namespace
-
-template <typename LoopType>
-static void lowerLinalgToLoopsImpl(FuncOp funcOp, MLIRContext *context) {
-  OwningRewritePatternList patterns;
-  // Canonicalization and folding patterns applied greedily allow cleaning up
-  // the emitted IR on the fly.
-  // TODO: fold view and subview ops?
-  insertPatterns<LoopType,
-#define GET_OP_LIST
-#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
-                 >(patterns, context);
 
-  DimOp::getCanonicalizationPatterns(patterns, context);
-  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
-  patterns.insert<FoldAffineOp>(context);
-  // Just apply the patterns greedily.
-  applyPatternsAndFoldGreedily(funcOp, patterns);
-}
-
-namespace {
 struct LowerToAffineLoops
     : public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> {
   void runOnFunction() override {
     lowerLinalgToLoopsImpl<AffineForOp>(getFunction(), &getContext());
   }
 };
+
 struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> {
   void runOnFunction() override {
     lowerLinalgToLoopsImpl<scf::ForOp>(getFunction(), &getContext());
   }
 };
+
 struct LowerToParallelLoops
     : public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> {
   void runOnFunction() override {
@@ -583,60 +581,6 @@ mlir::createConvertLinalgToAffineLoopsPass() {
   return std::make_unique<LowerToAffineLoops>();
 }
 
-// TODO: gradually remove this layer as more ops become "named".
-template <typename LoopTy>
-static Optional<LinalgLoops> linalgOpToLoopsImplSwitch(Operation *op,
-                                                       OpBuilder &builder) {
-  assert(isa<LinalgOp>(op) && "LinalgOp expected");
-  if (isa<CopyOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, CopyOp>(op, builder);
-  if (isa<FillOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, FillOp>(op, builder);
-  if (isa<ConvOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvOp>(op, builder);
-  if (isa<PoolingMaxOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, PoolingMaxOp>(op, builder);
-  if (isa<PoolingMinOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, PoolingMinOp>(op, builder);
-  if (isa<PoolingSumOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, PoolingSumOp>(op, builder);
-  if (isa<IndexedGenericOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, IndexedGenericOp>(op, builder);
-
-  // TODO: Cases below are generic and need a LinalgStructuredOpInterface.
-  if (isa<GenericOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, GenericOp>(op, builder);
-  if (isa<MatmulOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, MatmulOp>(op, builder);
-  if (isa<MatvecOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, MatvecOp>(op, builder);
-  if (isa<VecmatOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, VecmatOp>(op, builder);
-  if (isa<DotOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, DotOp>(op, builder);
-  if (isa<BatchMatmulOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, BatchMatmulOp>(op, builder);
-  if (isa<ConvWOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvWOp>(op, builder);
-  if (isa<ConvNWCOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvNWCOp>(op, builder);
-  if (isa<ConvNCWOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvNCWOp>(op, builder);
-  if (isa<ConvHWOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvHWOp>(op, builder);
-  if (isa<ConvNHWCOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvNHWCOp>(op, builder);
-  if (isa<ConvNCHWOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvNCHWOp>(op, builder);
-  if (isa<ConvDHWOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvDHWOp>(op, builder);
-  if (isa<ConvNDHWCOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvNDHWCOp>(op, builder);
-  if (isa<ConvNCDHWOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvNCDHWOp>(op, builder);
-  llvm_unreachable("Unexpected op in linalgOpToLoopsImpl");
-}
-
 SmallVector<Range, 4> mlir::linalg::emitLoopRanges(OpBuilder &b, Location loc,
                                                    AffineMap map,
                                                    ValueRange viewSizes) {
@@ -705,7 +649,7 @@ SmallVector<Range, 4> mlir::linalg::emitLoopRanges(OpBuilder &b, Location loc,
 template <typename LoopTy>
 Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder,
                                                          Operation *op) {
-  return linalgOpToLoopsImplSwitch<LoopTy>(op, builder);
+  return linalgOpToLoopsImpl<LoopTy>(op, builder);
 }
 
 template Optional<LinalgLoops>
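A side note on the pooling change above: merging the PoolingMaxOp and PoolingMinOp bodies into emitPoolingMinMaxScalarImplementation relies on ordinary compile-time selection, where both branches of the conditional must compile for every OpType and std::is_same picks one per instantiation. A standalone sketch of the idiom, with hypothetical names (MinTag, MaxTag, and pool are illustrative, not from the patch):

#include <algorithm>
#include <iostream>
#include <type_traits>

struct MinTag {};
struct MaxTag {};

// Mirrors the shape of emitPoolingMinMaxScalarImplementation: one template
// body, with the comparison selected per instantiation via std::is_same.
template <typename Tag>
int pool(int lhs, int rhs) {
  return std::is_same<Tag, MinTag>::value ? std::min(lhs, rhs)
                                          : std::max(lhs, rhs);
}

int main() {
  std::cout << pool<MinTag>(2, 5) << "\n"; // prints 2
  std::cout << pool<MaxTag>(2, 5) << "\n"; // prints 5
  return 0;
}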
