@@ -23,6 +23,8 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/FoldUtils.h"
 
+#include "llvm/ADT/TypeSwitch.h"
+
 using namespace mlir;
 using namespace mlir::edsc;
 using namespace mlir::edsc::intrinsics;
@@ -65,7 +67,7 @@ static void inlineRegionAndEmitStore(OpType op, ArrayRef<Value> indexedValues,
   assert(op.getOperation()->getNumRegions() == 1 &&
          "Expected single region op");
   auto &b = ScopedContext::getBuilderRef();
-  auto &block = op.region().front();
+  auto &block = op.getOperation()->getRegion(0).front();
   BlockAndValueMapping map;
   map.map(block.getArguments(), indexedValues);
   for (auto &op : block.without_terminator()) {
@@ -102,8 +104,6 @@ static InputAndOutputIndices getInputAndOutputIndices(ArrayRef<Value> allIvs,
           makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
 }
 
-namespace {
-
 /// Emits the MLIR for the scalar part of the generic op by:
 /// 1. Emitting load ops for each input and output view in order. This is
 ///    achieved by applying the appropriate input or output map to the
@@ -134,10 +134,9 @@ namespace {
 /// }
 /// }
 /// ```
-// TODO: need a LinalgStructuredOpInterface.
-template <typename IndexedValueType, typename LinalgStructuredOpType>
-void emitScalarImplementation(ArrayRef<Value> allIvs,
-                              LinalgStructuredOpType linalgOp) {
+template <typename IndexedValueType>
+static void emitScalarImplementation(ArrayRef<Value> allIvs,
+                                     LinalgOp linalgOp) {
   assert(linalgOp.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
   auto &b = ScopedContext::getBuilderRef();
@@ -150,7 +149,7 @@ void emitScalarImplementation(ArrayRef<Value> allIvs,
   auto attr = linalgOp.template getAttrOfType<IntegerAttr>("symbol_source");
   auto allIvsPlusDims = SmallVector<Value, 4>(allIvs.begin(), allIvs.end());
   if (attr) {
-    auto operand = linalgOp.getOperand(attr.getInt());
+    auto operand = linalgOp.getOperation()->getOperand(attr.getInt());
     auto shapedType = operand.getType().template cast<ShapedType>();
     allIvsPlusDims.reserve(allIvs.size() + shapedType.getRank());
     for (unsigned idx = 0, e = shapedType.getRank(); idx < e; ++idx)
@@ -190,7 +189,7 @@ void emitScalarImplementation(ArrayRef<Value> allIvs,
 }
 
 template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
+static void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
   assert(copyOp.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
   auto nPar = copyOp.getNumParallelLoops();
@@ -211,7 +210,7 @@ void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
 }
 
 template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
+static void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
   assert(fillOp.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
   auto nPar = fillOp.getNumParallelLoops();
@@ -224,8 +223,8 @@ void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
 }
 
 template <typename IndexedValueType>
-Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
-                     MutableArrayRef<Value> imIdx) {
+static Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
+                            MutableArrayRef<Value> imIdx) {
   // TODO: add a level of indirection to linalg.generic.
   if (!convOp.padding())
     return im(imIdx);
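Several file-local helpers in this patch gain `static`, giving them internal linkage: each translation unit keeps its own private copy, so the symbols cannot collide with, or be referenced from, other object files. A minimal sketch of the effect, with hypothetical file and function names:

```cpp
// helpers.cpp (hypothetical): `static` confines helper() to this
// translation unit, so another .cpp file may define its own helper()
// without an ODR/link-time conflict.
static int helper(int x) { return x + 1; }

// The non-static entry point remains visible to the linker.
int increment(int x) { return helper(x); }
```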
@@ -311,39 +310,44 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) {
   }
 }
 
-template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) {
+template <typename IndexedValueType, typename OpType>
+static void emitPoolingMinMaxScalarImplementation(ArrayRef<Value> allIvs,
+                                                  OpType op) {
   InputAndOutputIndices indices = getInputAndOutputIndices(allIvs, op);
   // Emit scalar form.
   IndexedValueType output(op.output());
   IndexedValueType input(op.input());
   Value lhs = output(indices.outputs);
   Value rhs = input(indices.inputs);
   using edsc::op::sgt;
-  Value maxValue = std_select(sgt(lhs, rhs), lhs, rhs);
-  output(indices.outputs) = maxValue;
+  using edsc::op::slt;
+  Value value = std::is_same<OpType, PoolingMinOp>()
+                    ? std_select(slt(lhs, rhs), lhs, rhs)
+                    : std_select(sgt(lhs, rhs), lhs, rhs);
+  output(indices.outputs) = value;
 }
 
 template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMinOp op) {
-  InputAndOutputIndices indices = getInputAndOutputIndices(allIvs, op);
-  // Emit scalar form.
-  IndexedValueType output(op.output());
-  IndexedValueType input(op.input());
-  Value lhs = output(indices.outputs);
-  Value rhs = input(indices.inputs);
-  using edsc::op::slt;
-  Value minValue = std_select(slt(lhs, rhs), lhs, rhs);
-  output(indices.outputs) = minValue;
+static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) {
+  emitPoolingMinMaxScalarImplementation<IndexedValueType, PoolingMaxOp>(allIvs,
+                                                                        op);
 }
+
 template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) {
+static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMinOp op) {
+  emitPoolingMinMaxScalarImplementation<IndexedValueType, PoolingMinOp>(allIvs,
+                                                                        op);
+}
+
+template <typename IndexedValueType>
+static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) {
   auto indices = getInputAndOutputIndices(allIvs, op);
   IndexedValueType input(op.input()), output(op.output());
 
   // Emit scalar form.
   output(indices.outputs) += input(indices.inputs);
 }
+
 /// Emits the MLIR for the scalar part of the indexed generic op by:
 /// 1. Emitting load ops for each input and output view in order. This is
 ///    achieved by applying the appropriate input or output map to the
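The pooling refactor folds the two bodies into one helper that picks its comparison at compile time: `std::is_same<OpType, PoolingMinOp>()` is a compile-time constant, so each instantiation effectively keeps a single `std_select` branch. A standalone sketch of the same idiom in plain C++, with invented names:

```cpp
#include <algorithm>
#include <cstdio>
#include <type_traits>

struct MinTag {};
struct MaxTag {};

// One template body serves both reductions; the branch on the
// std::is_same constant is resolved per instantiation, exactly like
// emitPoolingMinMaxScalarImplementation above.
template <typename Tag>
int combine(int lhs, int rhs) {
  return std::is_same<Tag, MinTag>::value ? std::min(lhs, rhs)
                                          : std::max(lhs, rhs);
}

int main() {
  std::printf("%d %d\n", combine<MinTag>(2, 5), combine<MaxTag>(2, 5));
  return 0; // prints "2 5"
}
```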
@@ -422,15 +426,16 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs,
                                 indexing, outputBuffers);
 }
 
-template <typename LoopTy, typename ConcreteOpTy>
-Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
+template <typename LoopTy>
+static Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op,
+                                                 OpBuilder &builder) {
   using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;
 
   ScopedContext scope(builder, op->getLoc());
 
   // The flattened loopToOperandRangesMaps is expected to be an invertible
   // permutation map (which is asserted in the inverse calculation).
-  auto linalgOp = cast<ConcreteOpTy>(op);
+  auto linalgOp = cast<LinalgOp>(op);
   assert(linalgOp.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
   auto mapsRange =
@@ -447,7 +452,12 @@ Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
       [&](ValueRange ivs, ValueRange iterArgs) -> scf::ValueVector {
        assert(iterArgs.empty() && "unexpected iterArgs");
        allIvs.append(ivs.begin(), ivs.end());
-        emitScalarImplementation<IndexedValueTy>(allIvs, linalgOp);
+        llvm::TypeSwitch<Operation *>(op)
+            .Case<CopyOp, FillOp, ConvOp, PoolingMaxOp, PoolingMinOp,
+                  PoolingSumOp, IndexedGenericOp, LinalgOp>([&](auto op) {
+              emitScalarImplementation<IndexedValueTy>(allIvs, op);
+            })
+            .Default([&](Operation *op) { assert(false && "unexpected op"); });
        return scf::ValueVector{};
      });
   // Number of loop ops might be different from the number of ivs since some
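`llvm::TypeSwitch` dispatches on the dynamic type of a value via `dyn_cast`, trying the listed types in order, which is why the catch-all `LinalgOp` interface appears last: the specialized copy/fill/conv/pooling overloads of `emitScalarImplementation` get first pick, and `LinalgOp` catches every remaining named op. A compilable sketch of the mechanism over a toy LLVM-RTTI hierarchy; the `Shape` types are invented for illustration, and only an LLVM development setup is assumed for the header:

```cpp
#include "llvm/ADT/TypeSwitch.h"
#include <cstdio>

struct Shape {
  enum Kind { K_Circle, K_Square } kind;
  Shape(Kind k) : kind(k) {}
};
struct Circle : Shape {
  Circle() : Shape(K_Circle) {}
  // classof is what dyn_cast, and therefore TypeSwitch, hooks into.
  static bool classof(const Shape *s) { return s->kind == K_Circle; }
};
struct Square : Shape {
  Square() : Shape(K_Square) {}
  static bool classof(const Shape *s) { return s->kind == K_Square; }
};

static void describe(Shape *s) {
  // Cases are tried in order against the dynamic type; Default handles
  // anything no Case matched, like the patch's assert.
  llvm::TypeSwitch<Shape *>(s)
      .Case<Circle>([](Circle *) { std::puts("circle"); })
      .Case<Square>([](Square *) { std::puts("square"); })
      .Default([](Shape *) { std::puts("unknown shape"); });
}

int main() {
  Circle c;
  Square sq;
  describe(&c);  // prints "circle"
  describe(&sq); // prints "square"
}
```

In the patch, the single generic `[&](auto op)` callable is tried against each listed type in turn, and the first type that matches fixes the static type of `op`, which in turn selects the matching `emitScalarImplementation` overload.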
@@ -467,32 +477,38 @@ Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
   return loops;
 }
 
-template <typename LoopType, typename ConcreteOp>
+namespace {
+template <typename LoopType>
 class LinalgRewritePattern : public RewritePattern {
 public:
-  explicit LinalgRewritePattern(MLIRContext *context)
-      : RewritePattern(ConcreteOp::getOperationName(), 1, context) {}
+  LinalgRewritePattern() : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
 
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    if (!linalgOpToLoopsImpl<LoopType, ConcreteOp>(op, rewriter))
+    if (!isa<LinalgOp>(op))
+      return failure();
+    if (!linalgOpToLoopsImpl<LoopType>(op, rewriter))
       return failure();
     rewriter.eraseOp(op);
     return success();
   }
 };
 
-template <typename LoopType, typename ConcreteOp>
-void insertOnePattern(OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  patterns.insert<LinalgRewritePattern<LoopType, ConcreteOp>>(ctx);
-}
+struct FoldAffineOp;
+} // namespace
 
-template <typename LoopType, typename... Args>
-void insertPatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  (void)std::initializer_list<int>{
-      0, (insertOnePattern<LoopType, Args>(patterns, ctx), 0)...};
+template <typename LoopType>
+static void lowerLinalgToLoopsImpl(FuncOp funcOp, MLIRContext *context) {
+  OwningRewritePatternList patterns;
+  patterns.insert<LinalgRewritePattern<LoopType>>();
+  DimOp::getCanonicalizationPatterns(patterns, context);
+  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
+  patterns.insert<FoldAffineOp>(context);
+  // Just apply the patterns greedily.
+  applyPatternsAndFoldGreedily(funcOp, patterns);
 }
 
+namespace {
 /// Local folding pattern for AffineApplyOp that we can apply greedily.
 /// This replaces AffineApplyOp by the proper value in cases where the
 /// associated map is trivial.
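The new `struct FoldAffineOp;` forward declaration is enough even though `lowerLinalgToLoopsImpl` already mentions the type in `patterns.insert<FoldAffineOp>(context)`: that function is a template, so nothing is instantiated until it is first called, and by then `FoldAffineOp`, defined further down in the same anonymous namespace, is complete. A standalone sketch of the same C++ mechanism, with invented names:

```cpp
#include <memory>

namespace {
struct Widget; // forward declaration, as with FoldAffineOp above
} // namespace

// A function template may mention the incomplete type: nothing is
// instantiated until the first call, which happens only after Widget
// is complete, mirroring lowerLinalgToLoopsImpl's use of FoldAffineOp.
template <typename T>
std::unique_ptr<T> make() {
  return std::make_unique<T>();
}

namespace {
struct Widget {
  int value = 42;
};
} // namespace

int main() { return make<Widget>()->value == 42 ? 0 : 1; }
```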
@@ -529,38 +545,20 @@ struct FoldAffineOp : public RewritePattern {
     return failure();
   }
 };
-} // namespace
-
-template <typename LoopType>
-static void lowerLinalgToLoopsImpl(FuncOp funcOp, MLIRContext *context) {
-  OwningRewritePatternList patterns;
-  // Canonicalization and folding patterns applied greedily allow cleaning up
-  // the emitted IR on the fly.
-  // TODO: fold view and subview ops?
-  insertPatterns<LoopType,
-#define GET_OP_LIST
-#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
-                 >(patterns, context);
 
-  DimOp::getCanonicalizationPatterns(patterns, context);
-  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
-  patterns.insert<FoldAffineOp>(context);
-  // Just apply the patterns greedily.
-  applyPatternsAndFoldGreedily(funcOp, patterns);
-}
-
-namespace {
 struct LowerToAffineLoops
     : public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> {
   void runOnFunction() override {
     lowerLinalgToLoopsImpl<AffineForOp>(getFunction(), &getContext());
   }
 };
+
 struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> {
   void runOnFunction() override {
     lowerLinalgToLoopsImpl<scf::ForOp>(getFunction(), &getContext());
   }
 };
+
 struct LowerToParallelLoops
     : public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> {
   void runOnFunction() override {
@@ -583,60 +581,6 @@ mlir::createConvertLinalgToAffineLoopsPass() {
   return std::make_unique<LowerToAffineLoops>();
 }
 
-// TODO: gradually remove this layer as more ops become "named".
-template <typename LoopTy>
-static Optional<LinalgLoops> linalgOpToLoopsImplSwitch(Operation *op,
-                                                       OpBuilder &builder) {
-  assert(isa<LinalgOp>(op) && "LinalgOp expected");
-  if (isa<CopyOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, CopyOp>(op, builder);
-  if (isa<FillOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, FillOp>(op, builder);
-  if (isa<ConvOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvOp>(op, builder);
-  if (isa<PoolingMaxOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, PoolingMaxOp>(op, builder);
-  if (isa<PoolingMinOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, PoolingMinOp>(op, builder);
-  if (isa<PoolingSumOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, PoolingSumOp>(op, builder);
-  if (isa<IndexedGenericOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, IndexedGenericOp>(op, builder);
-
-  // TODO: Cases below are generic and need a LinalgStructuredOpInterface.
-  if (isa<GenericOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, GenericOp>(op, builder);
-  if (isa<MatmulOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, MatmulOp>(op, builder);
-  if (isa<MatvecOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, MatvecOp>(op, builder);
-  if (isa<VecmatOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, VecmatOp>(op, builder);
-  if (isa<DotOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, DotOp>(op, builder);
-  if (isa<BatchMatmulOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, BatchMatmulOp>(op, builder);
-  if (isa<ConvWOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvWOp>(op, builder);
-  if (isa<ConvNWCOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvNWCOp>(op, builder);
-  if (isa<ConvNCWOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvNCWOp>(op, builder);
-  if (isa<ConvHWOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvHWOp>(op, builder);
-  if (isa<ConvNHWCOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvNHWCOp>(op, builder);
-  if (isa<ConvNCHWOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvNCHWOp>(op, builder);
-  if (isa<ConvDHWOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvDHWOp>(op, builder);
-  if (isa<ConvNDHWCOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvNDHWCOp>(op, builder);
-  if (isa<ConvNCDHWOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvNCDHWOp>(op, builder);
-  llvm_unreachable("Unexpected op in linalgOpToLoopsImpl");
-}
-
 SmallVector<Range, 4> mlir::linalg::emitLoopRanges(OpBuilder &b, Location loc,
                                                    AffineMap map,
                                                    ValueRange viewSizes) {
@@ -705,7 +649,7 @@ SmallVector<Range, 4> mlir::linalg::emitLoopRanges(OpBuilder &b, Location loc,
 template <typename LoopTy>
 Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder,
                                                          Operation *op) {
-  return linalgOpToLoopsImplSwitch<LoopTy>(op, builder);
+  return linalgOpToLoopsImpl<LoopTy>(op, builder);
 }
 
 template Optional<LinalgLoops>