@@ -258,20 +258,19 @@ static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
258
258
// / `producer.getOutputBuffers()`.
259
259
// / 2. Tensor case: `producerIdx` is the index of the tensor in
260
260
// / `producer.getResults()`.
261
- static LinalgOp fuse (OpBuilder &b, LinalgOp producer, unsigned producerIdx ,
262
- LinalgOp consumer, unsigned consumerIdx ) {
263
- AffineMap producerMap = producer .getOutputIndexingMap (producerIdx );
264
- LLVM_DEBUG (llvm::dbgs () << " Producer Idx: " << producerIdx
261
+ static LinalgOp fuse (OpBuilder &b, LinalgOp producerOp ,
262
+ unsigned producerOutNumber, OpOperand &consumerOpOperand ) {
263
+ AffineMap producerMap = producerOp .getOutputIndexingMap (producerOutNumber );
264
+ LLVM_DEBUG (llvm::dbgs () << " Producer Idx: " << producerOutNumber
265
265
<< " , producer map: " << producerMap << " \n " );
266
266
DenseMap<unsigned , Range> fusedLoopsAndRanges;
267
- Location loc = consumer.getLoc ();
268
- Value shapedOperand = consumer.getShapedOperand (consumerIdx);
267
+ Value shapedOperand = consumerOpOperand.get ();
269
268
for (auto en : llvm::enumerate (producerMap.getResults ())) {
270
269
unsigned posInProducerLoop = en.value ().cast <AffineDimExpr>().getPosition ();
271
- fusedLoopsAndRanges[posInProducerLoop] =
272
- getRangeFromOperandShape ( b, loc , shapedOperand, en.index ());
270
+ fusedLoopsAndRanges[posInProducerLoop] = getRangeFromOperandShape (
271
+ b, consumerOpOperand. getOwner ()-> getLoc () , shapedOperand, en.index ());
273
272
}
274
- return fuse (b, producer , fusedLoopsAndRanges);
273
+ return fuse (b, producerOp , fusedLoopsAndRanges);
275
274
}
276
275
277
276
// Encode structural fusion safety preconditions.
@@ -378,31 +377,27 @@ static bool isSameSubView(Value a, Value b) {
378
377
}
379
378
380
379
static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
381
- findFusableProducer (LinalgOp consumer, unsigned consumerIdx ,
380
+ findFusableProducer (OpOperand &consumerOpOperand ,
382
381
const LinalgDependenceGraph &dependenceGraph) {
383
- assert (consumer.hasBufferSemantics () && " revisit usage of shaped operand" );
382
+ LinalgOp consumerOp = cast<LinalgOp>(consumerOpOperand.getOwner ());
383
+ assert (consumerOp.hasBufferSemantics () && " revisit usage of shaped operand" );
384
384
385
385
// Only consider RAW and WAW atm.
386
386
for (auto depType : {
387
387
LinalgDependenceGraph::DependenceType::RAW,
388
388
LinalgDependenceGraph::DependenceType::WAW,
389
389
}) {
390
390
for (auto dependence : llvm::make_filter_range (
391
- dependenceGraph.getDependencesInto (consumer, depType),
392
- [consumerIdx](
393
- LinalgDependenceGraph::LinalgDependenceGraphElem elem) {
394
- return elem.indexingOpView ->getOperandNumber () == consumerIdx;
391
+ dependenceGraph.getDependencesInto (consumerOp, depType),
392
+ [&](LinalgDependenceGraph::LinalgDependenceGraphElem elem) {
393
+ return elem.indexingOpView ->get () == consumerOpOperand.get () &&
394
+ elem.indexingOpView ->getOperandNumber () ==
395
+ consumerOpOperand.getOperandNumber ();
395
396
})) {
396
397
397
- // Check that the dependence is indeed on the input `consumerIdx` view.
398
- Value consumedView = dependence.indexingOpView ->get ();
399
- if (!isSameSubView (consumer.getShapedOperand (consumerIdx), consumedView))
400
- continue ;
401
-
402
398
// Consumer consumes this view, `isStructurallyFusableProducer` also
403
399
// checks whether it is a strict subview of the producer view.
404
400
auto producer = cast<LinalgOp>(dependence.dependentOpView ->getOwner ());
405
- Value producedView = dependence.dependentOpView ->get ();
406
401
LLVM_DEBUG (llvm::dbgs ()
407
402
<< " \n "
408
403
<< LinalgDependenceGraph::getDependenceTypeStr (depType)
@@ -412,10 +407,10 @@ findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
412
407
<< dependence.dependentOpView ->getOperandNumber () -
413
408
producer.getNumInputs ()
414
409
<< " \n " );
415
- (void )producedView;
416
410
417
411
// Simple fusability checks.
418
- if (!isFusableInto (dependenceGraph, consumer, consumedView, producer))
412
+ if (!isFusableInto (dependenceGraph, consumerOp, consumerOpOperand.get (),
413
+ producer))
419
414
continue ;
420
415
421
416
return dependence;
@@ -425,55 +420,54 @@ findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
425
420
}
426
421
427
422
Optional<FusionInfo>
428
- mlir::linalg::fuseProducerOfBuffer (OpBuilder &b, LinalgOp consumer,
429
- unsigned consumerIdx,
423
+ mlir::linalg::fuseProducerOfBuffer (OpBuilder &b, OpOperand &consumerOpOperand,
430
424
const LinalgDependenceGraph &graph) {
431
425
Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
432
- findFusableProducer (consumer, consumerIdx , graph);
426
+ findFusableProducer (consumerOpOperand , graph);
433
427
if (!fusableDependence)
434
428
return {};
435
429
436
430
LinalgOp producerOp =
437
431
cast<LinalgOp>(fusableDependence->dependentOpView ->getOwner ());
438
432
// If producer is already in the same block as consumer, we are done.
439
- if (consumer->getBlock () == producerOp->getBlock ())
433
+ if (consumerOpOperand.get ().getParentBlock () ==
434
+ fusableDependence->dependentOpView ->get ().getParentBlock ())
440
435
return {};
441
436
442
437
unsigned producerIdx =
443
438
fusableDependence->dependentOpView ->getOperandNumber () -
444
439
producerOp.getNumInputs ();
445
- Value consumerView = consumer.getShapedOperand (consumerIdx);
446
440
447
441
// Must be a subview or a slice to guarantee there are loops we can fuse
448
442
// into.
449
- auto subView = consumerView .getDefiningOp <SubViewOp>();
450
- auto slice = consumerView .getDefiningOp <SliceOp>();
443
+ auto subView = consumerOpOperand. get () .getDefiningOp <SubViewOp>();
444
+ auto slice = consumerOpOperand. get () .getDefiningOp <SliceOp>();
451
445
if (!subView && !slice) {
452
446
LLVM_DEBUG (llvm::dbgs () << " \n Not fusable (not a subview or slice)" );
453
447
return {};
454
448
}
455
449
456
450
// Fuse `producer` just before `consumer`.
457
451
OpBuilder::InsertionGuard g (b);
458
- b.setInsertionPoint (consumer.getOperation ());
459
- ScopedContext scope (b, consumer.getLoc ());
460
- LLVM_DEBUG (llvm::dbgs () << " Fuse into consumer: " << *consumer << " \n " );
452
+ b.setInsertionPoint (consumerOpOperand.getOwner ());
453
+ ScopedContext scope (b, consumerOpOperand.getOwner ()->getLoc ());
454
+ LLVM_DEBUG (llvm::dbgs () << " Fuse into consumer: "
455
+ << *consumerOpOperand.getOwner () << " \n " );
461
456
462
- auto fusedProducer = fuse (b, producerOp, producerIdx, consumer, consumerIdx );
457
+ auto fusedProducer = fuse (b, producerOp, producerIdx, consumerOpOperand );
463
458
return FusionInfo{producerOp, fusedProducer};
464
459
}
465
460
466
461
// / Walk back use-def chain through scf::For yields.
467
462
// / Sets `producer` and `outputIndex` if it finds a producer LinalgOp
468
- static void getProducerOfTensor (Value tensor, LinalgOp &producer,
469
- unsigned &outputIndex) {
463
+ static void getProducerOfTensor (Value tensor, OpResult &opResult) {
470
464
if (!tensor.getType ().isa <RankedTensorType>())
471
465
return ;
472
466
473
467
while (true ) {
468
+ LLVM_DEBUG (llvm::dbgs () << " \n getProducerOfTensor: " << tensor);
474
469
if (auto linalgOp = tensor.getDefiningOp <LinalgOp>()) {
475
- producer = linalgOp;
476
- outputIndex = tensor.cast <OpResult>().getResultNumber ();
470
+ opResult = tensor.cast <OpResult>();
477
471
return ;
478
472
}
479
473
if (auto subTensorOp = tensor.getDefiningOp <SubTensorOp>()) {
@@ -482,53 +476,66 @@ static void getProducerOfTensor(Value tensor, LinalgOp &producer,
482
476
}
483
477
if (auto blockArg = tensor.dyn_cast <BlockArgument>()) {
484
478
if (auto forOp = blockArg.getDefiningOp <scf::ForOp>()) {
485
- tensor = forOp.getResult ( blockArg.getArgNumber ());
479
+ tensor = *( forOp.getIterOperands (). begin () + blockArg.getArgNumber ());
486
480
continue ;
487
481
}
488
482
}
489
483
return ;
490
484
}
491
485
}
492
486
493
- Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor (OpBuilder &b,
494
- LinalgOp consumer,
495
- unsigned consumerIdx) {
496
- Value inputTensor = consumer.getInput (consumerIdx);
497
- LinalgOp producerOp;
498
- unsigned producerIdx;
499
- getProducerOfTensor (inputTensor, producerOp, producerIdx);
487
+ Optional<FusionInfo>
488
+ mlir::linalg::fuseProducerOfTensor (OpBuilder &b, OpOperand &consumerOpOperand) {
489
+ Value inputTensor = consumerOpOperand.get ();
490
+ OpResult producerOpResult;
491
+ getProducerOfTensor (inputTensor, producerOpResult);
492
+ if (!producerOpResult) {
493
+ LLVM_DEBUG (llvm::dbgs () << " \n Unable to find producer" );
494
+ return {};
495
+ }
496
+ return fuseProducerOfTensor (b, producerOpResult, consumerOpOperand);
497
+ }
498
+
499
+ Optional<FusionInfo>
500
+ mlir::linalg::fuseProducerOfTensor (OpBuilder &b, OpResult producerOpResult,
501
+ OpOperand &consumerOpOperand) {
502
+ auto producerOp = dyn_cast<LinalgOp>(producerOpResult.getOwner ());
503
+ assert (producerOp && " expected Linalg producer" );
504
+ LinalgOp consumerOp = cast<LinalgOp>(consumerOpOperand.getOwner ());
505
+ Value inputTensor = consumerOpOperand.get ();
500
506
501
507
// Must be a subtensor to guarantee there are loops we can fuse into.
502
508
auto subTensor = inputTensor.getDefiningOp <SubTensorOp>();
503
- if (!subTensor || !producerOp) {
504
- LLVM_DEBUG (llvm::dbgs () << " \n Not fusable (not a subtensor)" );
509
+ if (!subTensor) {
510
+ LLVM_DEBUG (llvm::dbgs ()
511
+ << " \n Not fusable, not a subtensor: " << inputTensor);
505
512
return {};
506
513
}
507
514
508
515
// If producer is already in the same block as consumer, we are done.
509
- if (consumer->getBlock () == producerOp->getBlock ())
516
+ if (consumerOpOperand.get ().getParentBlock () ==
517
+ producerOpResult.getParentBlock ())
510
518
return {};
511
519
512
520
// Insert fused `producer` just before `consumer`.
513
521
OpBuilder::InsertionGuard g (b);
514
- b.setInsertionPoint (consumer. getOperation () );
515
- ScopedContext scope (b, consumer. getLoc ());
516
- LLVM_DEBUG (llvm::dbgs () << " Fuse into consumer: " << *consumer << " \n " );
517
- LinalgOp fusedProducer =
518
- fuse ( b, producerOp, producerIdx, consumer, consumerIdx );
522
+ b.setInsertionPoint (consumerOp );
523
+ ScopedContext scope (b, consumerOp-> getLoc ());
524
+ LLVM_DEBUG (llvm::dbgs () << " Fuse into consumer: " << *consumerOp << " \n " );
525
+ LinalgOp fusedProducer = fuse (
526
+ b, producerOp, producerOpResult. getResultNumber (), consumerOpOperand );
519
527
520
528
// Replace use.
521
529
// Canonicalizations are not guaranteed to have happened before constructing
522
530
// `fusedProducer`. In the tensor case this can result in temporary type
523
531
// mismatches. Insert a `tensor.cast` op to propagate the transformation
524
532
// invariant that types are compatible.
525
- Value def = fusedProducer->getResult (producerIdx);
526
- OpOperand &use = consumer->getOpOperand (consumerIdx);
527
- Type consumerType = use.get ().getType ();
533
+ Value def = fusedProducer->getResult (producerOpResult.getResultNumber ());
534
+ Type consumerType = consumerOpOperand.get ().getType ();
528
535
if (consumerType != def.getType ())
529
536
def = b.create <tensor::CastOp>(fusedProducer.getLoc (), consumerType, def);
530
- use .set (def);
531
- return FusionInfo{producerOp , fusedProducer};
537
+ consumerOpOperand .set (def);
538
+ return FusionInfo{cast<LinalgOp>(producerOpResult. getOwner ()) , fusedProducer};
532
539
}
533
540
534
541
// / Prune all dimensions that are of reduction iterator type from `map`.
@@ -734,11 +741,9 @@ FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
734
741
// in the meanwhile disallow such a fusion.
735
742
DenseMap<Operation *, AffineMap> fusedProducerIndexingMap;
736
743
for (LinalgOp op : reverse (ops)) {
737
- for (auto operandIndex :
738
- llvm::seq<unsigned >(0 , op.getNumShapedOperands ())) {
744
+ for (OpOperand &opOperand : op.getShapedOpOperands ()) {
739
745
Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
740
- fusableDependence =
741
- findFusableProducer (op, operandIndex, dependenceGraph);
746
+ fusableDependence = findFusableProducer (opOperand, dependenceGraph);
742
747
if (!fusableDependence)
743
748
continue ;
744
749
LinalgOp producerOp =
@@ -759,7 +764,7 @@ FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
759
764
op.emitRemark (
760
765
" unhandled non permutation indexing map for fused view in "
761
766
" producer for operand at index " )
762
- << operandIndex ;
767
+ << opOperand. getOperandNumber () ;
763
768
return FusableOpDependencesTy{};
764
769
}
765
770
@@ -770,7 +775,7 @@ FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
770
775
op.emitRemark (
771
776
" unhandled case where indexing map for fused view in the consumer "
772
777
" is not a projected permutation while fusing at index " )
773
- << operandIndex ;
778
+ << opOperand. getOperandNumber () ;
774
779
return FusableOpDependencesTy{};
775
780
}
776
781
0 commit comments