Skip to content

Commit 80f0785

Browse files
[mlir][Linalg] NFC - Refactor fusion APIs
This revision uniformizes the fusion APIs so that they accept OpOperand and OpResult, and adds a finer level of control over fusion. Differential Revision: https://reviews.llvm.org/D94493
1 parent 2ed914c commit 80f0785

File tree

7 files changed

+128
-107
lines changed

7 files changed

+128
-107
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -726,6 +726,18 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
726726
getNumShapedOperands());
727727
}]
728728
>,
729+
InterfaceMethod<
730+
/*desc=*/[{
731+
Return the `i`-th shaped OpOperand.
732+
}],
733+
/*retTy=*/" OpOperand&",
734+
/*methodName=*/"getShapedOpOperand",
735+
/*args=*/(ins "unsigned":$i),
736+
/*methodBody=*/"",
737+
/*defaultImplementation=*/[{
738+
return *(this->getShapedOpOperands().begin() + i);
739+
}]
740+
>,
729741
InterfaceMethod<
730742
/*desc=*/[{
731743
Return the range over input and output operands.

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ struct TiledLinalgOp {
3535
LinalgOp op;
3636
SmallVector<Operation *, 8> loops;
3737
SmallVector<Value, 4> tensorResults;
38+
TiledLinalgOp &operator=(const TiledLinalgOp &) = default;
3839
};
3940

4041
/// Populates patterns for vectorization of all ConvN-D ops.
@@ -412,9 +413,8 @@ struct LinalgBaseTilingPattern : public RewritePattern {
412413
LinalgTilingOptions options,
413414
LinalgMarker marker = LinalgMarker(),
414415
PatternBenefit benefit = 1);
415-
LogicalResult
416-
matchAndRewriteBase(Operation *op, PatternRewriter &rewriter,
417-
SmallVectorImpl<Value> &tensorResults) const;
416+
LogicalResult matchAndRewriteBase(Operation *op, PatternRewriter &rewriter,
417+
TiledLinalgOp &result) const;
418418

419419
private:
420420
/// LinalgTransformMarker handles special attribute manipulations.
@@ -432,14 +432,14 @@ struct LinalgTilingPattern : public LinalgBaseTilingPattern {
432432
marker, benefit) {}
433433
LogicalResult matchAndRewrite(Operation *op,
434434
PatternRewriter &rewriter) const override {
435-
SmallVector<Value, 4> tensorResults;
435+
TiledLinalgOp tiledLinalgOp;
436436
if (failed(LinalgBaseTilingPattern::matchAndRewriteBase(op, rewriter,
437-
tensorResults)))
437+
tiledLinalgOp)))
438438
return failure();
439-
if (tensorResults.empty())
439+
if (tiledLinalgOp.tensorResults.empty())
440440
rewriter.eraseOp(op);
441441
else
442-
rewriter.replaceOp(op, tensorResults);
442+
rewriter.replaceOp(op, tiledLinalgOp.tensorResults);
443443
return success();
444444
}
445445
};

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -92,26 +92,31 @@ findAllFusableDependences(ArrayRef<LinalgOp> ops,
9292

9393
/// Fuses producer into consumer if the producer is structurally feasible and
9494
/// the fusion would not violate dependencies.
95-
/// Implements the fusion part of the "tileAndFuse on buffers"
96-
/// transformation and thus requires the `consumerdIdx`^th operand of `consumer`
97-
/// to be a `subview` op (generally obtained by applying the tiling
98-
/// transformation).
99-
Optional<FusionInfo> fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
100-
unsigned consumerIdx,
95+
/// Implements the fusion part of the "tileAndFuse on buffers" transformation
96+
/// and thus requires the `consumerOpOperand` to be a `subview` op (generally
97+
/// obtained by applying the tiling transformation).
98+
Optional<FusionInfo> fuseProducerOfBuffer(OpBuilder &b,
99+
OpOperand &consumerOpOperand,
101100
const LinalgDependenceGraph &graph);
102101
/// Tensor counterpart of `fuseProducerOfBuffer`.
103102
/// This implements the fusion part of the "tileAndFuse on tensors"
104-
/// transformation and thus requires the `consumerdIdx`^th operand of `consumer`
105-
/// to be the result of a `subtensor` op (generally obtained by applying the
106-
/// tiling transformation).
107-
Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b, LinalgOp consumer,
108-
unsigned consumerIdx);
103+
/// transformation and thus requires the `consumerOpOperand` to be a `subtensor`
104+
/// op (generally obtained by applying the tiling transformation).
105+
Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
106+
OpOperand &consumerOpOperand);
107+
/// Tensor counterpart of `fuseProducerOfBuffer`.
108+
/// This implements the fusion part of the "tileAndFuse on tensors"
109+
/// transformation and thus requires the `consumerOpOperand` to be a `subtensor`
110+
/// op (generally obtained by applying the tiling transformation).
111+
/// Assumes `producerOpResult` is a result of the Linalg op that produces `consumerOpOperand`.
112+
Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
113+
OpResult producerOpResult,
114+
OpOperand &consumerOpOperand);
109115

110116
/// Fuse linalg operation on tensors, with the producer of the operand at
111117
/// position given by `consumerOpOperand` in the consumer.
112118
Optional<SmallVector<Value, 1>> fuseTensorOps(PatternRewriter &rewriter,
113-
Operation *consumer,
114-
unsigned consumerIdx);
119+
OpOperand &consumerOpOperand);
115120

116121
/// Like `getShape`, but only returns statically-known information, without
117122
/// generating any new IR. For each shape dimension, returns >=0 if that

mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp

Lines changed: 70 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -258,20 +258,19 @@ static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
258258
/// `producer.getOutputBuffers()`.
259259
/// 2. Tensor case: `producerIdx` is the index of the tensor in
260260
/// `producer.getResults()`.
261-
static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
262-
LinalgOp consumer, unsigned consumerIdx) {
263-
AffineMap producerMap = producer.getOutputIndexingMap(producerIdx);
264-
LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx
261+
static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp,
262+
unsigned producerOutNumber, OpOperand &consumerOpOperand) {
263+
AffineMap producerMap = producerOp.getOutputIndexingMap(producerOutNumber);
264+
LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerOutNumber
265265
<< ", producer map: " << producerMap << "\n");
266266
DenseMap<unsigned, Range> fusedLoopsAndRanges;
267-
Location loc = consumer.getLoc();
268-
Value shapedOperand = consumer.getShapedOperand(consumerIdx);
267+
Value shapedOperand = consumerOpOperand.get();
269268
for (auto en : llvm::enumerate(producerMap.getResults())) {
270269
unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
271-
fusedLoopsAndRanges[posInProducerLoop] =
272-
getRangeFromOperandShape(b, loc, shapedOperand, en.index());
270+
fusedLoopsAndRanges[posInProducerLoop] = getRangeFromOperandShape(
271+
b, consumerOpOperand.getOwner()->getLoc(), shapedOperand, en.index());
273272
}
274-
return fuse(b, producer, fusedLoopsAndRanges);
273+
return fuse(b, producerOp, fusedLoopsAndRanges);
275274
}
276275

277276
// Encode structural fusion safety preconditions.
@@ -378,31 +377,27 @@ static bool isSameSubView(Value a, Value b) {
378377
}
379378

380379
static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
381-
findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
380+
findFusableProducer(OpOperand &consumerOpOperand,
382381
const LinalgDependenceGraph &dependenceGraph) {
383-
assert(consumer.hasBufferSemantics() && "revisit usage of shaped operand");
382+
LinalgOp consumerOp = cast<LinalgOp>(consumerOpOperand.getOwner());
383+
assert(consumerOp.hasBufferSemantics() && "revisit usage of shaped operand");
384384

385385
// Only consider RAW and WAW atm.
386386
for (auto depType : {
387387
LinalgDependenceGraph::DependenceType::RAW,
388388
LinalgDependenceGraph::DependenceType::WAW,
389389
}) {
390390
for (auto dependence : llvm::make_filter_range(
391-
dependenceGraph.getDependencesInto(consumer, depType),
392-
[consumerIdx](
393-
LinalgDependenceGraph::LinalgDependenceGraphElem elem) {
394-
return elem.indexingOpView->getOperandNumber() == consumerIdx;
391+
dependenceGraph.getDependencesInto(consumerOp, depType),
392+
[&](LinalgDependenceGraph::LinalgDependenceGraphElem elem) {
393+
return elem.indexingOpView->get() == consumerOpOperand.get() &&
394+
elem.indexingOpView->getOperandNumber() ==
395+
consumerOpOperand.getOperandNumber();
395396
})) {
396397

397-
// Check that the dependence is indeed on the input `consumerIdx` view.
398-
Value consumedView = dependence.indexingOpView->get();
399-
if (!isSameSubView(consumer.getShapedOperand(consumerIdx), consumedView))
400-
continue;
401-
402398
// Consumer consumes this view, `isStructurallyFusableProducer` also
403399
// checks whether it is a strict subview of the producer view.
404400
auto producer = cast<LinalgOp>(dependence.dependentOpView->getOwner());
405-
Value producedView = dependence.dependentOpView->get();
406401
LLVM_DEBUG(llvm::dbgs()
407402
<< "\n"
408403
<< LinalgDependenceGraph::getDependenceTypeStr(depType)
@@ -412,10 +407,10 @@ findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
412407
<< dependence.dependentOpView->getOperandNumber() -
413408
producer.getNumInputs()
414409
<< "\n");
415-
(void)producedView;
416410

417411
// Simple fusability checks.
418-
if (!isFusableInto(dependenceGraph, consumer, consumedView, producer))
412+
if (!isFusableInto(dependenceGraph, consumerOp, consumerOpOperand.get(),
413+
producer))
419414
continue;
420415

421416
return dependence;
@@ -425,55 +420,54 @@ findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
425420
}
426421

427422
Optional<FusionInfo>
428-
mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
429-
unsigned consumerIdx,
423+
mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, OpOperand &consumerOpOperand,
430424
const LinalgDependenceGraph &graph) {
431425
Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
432-
findFusableProducer(consumer, consumerIdx, graph);
426+
findFusableProducer(consumerOpOperand, graph);
433427
if (!fusableDependence)
434428
return {};
435429

436430
LinalgOp producerOp =
437431
cast<LinalgOp>(fusableDependence->dependentOpView->getOwner());
438432
// If producer is already in the same block as consumer, we are done.
439-
if (consumer->getBlock() == producerOp->getBlock())
433+
if (consumerOpOperand.get().getParentBlock() ==
434+
fusableDependence->dependentOpView->get().getParentBlock())
440435
return {};
441436

442437
unsigned producerIdx =
443438
fusableDependence->dependentOpView->getOperandNumber() -
444439
producerOp.getNumInputs();
445-
Value consumerView = consumer.getShapedOperand(consumerIdx);
446440

447441
// Must be a subview or a slice to guarantee there are loops we can fuse
448442
// into.
449-
auto subView = consumerView.getDefiningOp<SubViewOp>();
450-
auto slice = consumerView.getDefiningOp<SliceOp>();
443+
auto subView = consumerOpOperand.get().getDefiningOp<SubViewOp>();
444+
auto slice = consumerOpOperand.get().getDefiningOp<SliceOp>();
451445
if (!subView && !slice) {
452446
LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview or slice)");
453447
return {};
454448
}
455449

456450
// Fuse `producer` just before `consumer`.
457451
OpBuilder::InsertionGuard g(b);
458-
b.setInsertionPoint(consumer.getOperation());
459-
ScopedContext scope(b, consumer.getLoc());
460-
LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumer << "\n");
452+
b.setInsertionPoint(consumerOpOperand.getOwner());
453+
ScopedContext scope(b, consumerOpOperand.getOwner()->getLoc());
454+
LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: "
455+
<< *consumerOpOperand.getOwner() << "\n");
461456

462-
auto fusedProducer = fuse(b, producerOp, producerIdx, consumer, consumerIdx);
457+
auto fusedProducer = fuse(b, producerOp, producerIdx, consumerOpOperand);
463458
return FusionInfo{producerOp, fusedProducer};
464459
}
465460

466461
/// Walk back use-def chain through scf::For yields.
467462
/// Sets `opResult` if it finds a producer LinalgOp
468-
static void getProducerOfTensor(Value tensor, LinalgOp &producer,
469-
unsigned &outputIndex) {
463+
static void getProducerOfTensor(Value tensor, OpResult &opResult) {
470464
if (!tensor.getType().isa<RankedTensorType>())
471465
return;
472466

473467
while (true) {
468+
LLVM_DEBUG(llvm::dbgs() << "\ngetProducerOfTensor: " << tensor);
474469
if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) {
475-
producer = linalgOp;
476-
outputIndex = tensor.cast<OpResult>().getResultNumber();
470+
opResult = tensor.cast<OpResult>();
477471
return;
478472
}
479473
if (auto subTensorOp = tensor.getDefiningOp<SubTensorOp>()) {
@@ -482,53 +476,66 @@ static void getProducerOfTensor(Value tensor, LinalgOp &producer,
482476
}
483477
if (auto blockArg = tensor.dyn_cast<BlockArgument>()) {
484478
if (auto forOp = blockArg.getDefiningOp<scf::ForOp>()) {
485-
tensor = forOp.getResult(blockArg.getArgNumber());
479+
tensor = *(forOp.getIterOperands().begin() + blockArg.getArgNumber());
486480
continue;
487481
}
488482
}
489483
return;
490484
}
491485
}
492486

493-
Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b,
494-
LinalgOp consumer,
495-
unsigned consumerIdx) {
496-
Value inputTensor = consumer.getInput(consumerIdx);
497-
LinalgOp producerOp;
498-
unsigned producerIdx;
499-
getProducerOfTensor(inputTensor, producerOp, producerIdx);
487+
Optional<FusionInfo>
488+
mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
489+
Value inputTensor = consumerOpOperand.get();
490+
OpResult producerOpResult;
491+
getProducerOfTensor(inputTensor, producerOpResult);
492+
if (!producerOpResult) {
493+
LLVM_DEBUG(llvm::dbgs() << "\nUnable to find producer");
494+
return {};
495+
}
496+
return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand);
497+
}
498+
499+
Optional<FusionInfo>
500+
mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
501+
OpOperand &consumerOpOperand) {
502+
auto producerOp = dyn_cast<LinalgOp>(producerOpResult.getOwner());
503+
assert(producerOp && "expected Linalg producer");
504+
LinalgOp consumerOp = cast<LinalgOp>(consumerOpOperand.getOwner());
505+
Value inputTensor = consumerOpOperand.get();
500506

501507
// Must be a subtensor to guarantee there are loops we can fuse into.
502508
auto subTensor = inputTensor.getDefiningOp<SubTensorOp>();
503-
if (!subTensor || !producerOp) {
504-
LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subtensor)");
509+
if (!subTensor) {
510+
LLVM_DEBUG(llvm::dbgs()
511+
<< "\nNot fusable, not a subtensor: " << inputTensor);
505512
return {};
506513
}
507514

508515
// If producer is already in the same block as consumer, we are done.
509-
if (consumer->getBlock() == producerOp->getBlock())
516+
if (consumerOpOperand.get().getParentBlock() ==
517+
producerOpResult.getParentBlock())
510518
return {};
511519

512520
// Insert fused `producer` just before `consumer`.
513521
OpBuilder::InsertionGuard g(b);
514-
b.setInsertionPoint(consumer.getOperation());
515-
ScopedContext scope(b, consumer.getLoc());
516-
LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumer << "\n");
517-
LinalgOp fusedProducer =
518-
fuse(b, producerOp, producerIdx, consumer, consumerIdx);
522+
b.setInsertionPoint(consumerOp);
523+
ScopedContext scope(b, consumerOp->getLoc());
524+
LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumerOp << "\n");
525+
LinalgOp fusedProducer = fuse(
526+
b, producerOp, producerOpResult.getResultNumber(), consumerOpOperand);
519527

520528
// Replace use.
521529
// Canonicalizations are not guaranteed to have happened before constructing
522530
// `fusedProducer`. In the tensor case this can result in temporary type
523531
// mismatches. Insert a `tensor.cast` op to propagate the transformation
524532
// invariant that types are compatible.
525-
Value def = fusedProducer->getResult(producerIdx);
526-
OpOperand &use = consumer->getOpOperand(consumerIdx);
527-
Type consumerType = use.get().getType();
533+
Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
534+
Type consumerType = consumerOpOperand.get().getType();
528535
if (consumerType != def.getType())
529536
def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def);
530-
use.set(def);
531-
return FusionInfo{producerOp, fusedProducer};
537+
consumerOpOperand.set(def);
538+
return FusionInfo{cast<LinalgOp>(producerOpResult.getOwner()), fusedProducer};
532539
}
533540

534541
/// Prune all dimensions that are of reduction iterator type from `map`.
@@ -734,11 +741,9 @@ FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
734741
// in the meanwhile disallow such a fusion.
735742
DenseMap<Operation *, AffineMap> fusedProducerIndexingMap;
736743
for (LinalgOp op : reverse(ops)) {
737-
for (auto operandIndex :
738-
llvm::seq<unsigned>(0, op.getNumShapedOperands())) {
744+
for (OpOperand &opOperand : op.getShapedOpOperands()) {
739745
Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
740-
fusableDependence =
741-
findFusableProducer(op, operandIndex, dependenceGraph);
746+
fusableDependence = findFusableProducer(opOperand, dependenceGraph);
742747
if (!fusableDependence)
743748
continue;
744749
LinalgOp producerOp =
@@ -759,7 +764,7 @@ FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
759764
op.emitRemark(
760765
"unhandled non permutation indexing map for fused view in "
761766
"producer for operand at index ")
762-
<< operandIndex;
767+
<< opOperand.getOperandNumber();
763768
return FusableOpDependencesTy{};
764769
}
765770

@@ -770,7 +775,7 @@ FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
770775
op.emitRemark(
771776
"unhandled case where indexing map for fused view in the consumer "
772777
"is not a projected permutation while fusing at index ")
773-
<< operandIndex;
778+
<< opOperand.getOperandNumber();
774779
return FusableOpDependencesTy{};
775780
}
776781

0 commit comments

Comments
 (0)