@@ -331,7 +331,7 @@ bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
331
331
// / For `consumer` with buffer semantics, find the Linalg operation on buffers
332
332
// / that is the last writer of `consumerOpOperand`. For now the fusable
333
333
// / dependence is returned as an instance of the `dependenceGraph`.
334
- static Optional <LinalgDependenceGraph::LinalgDependenceGraphElem>
334
+ static FailureOr <LinalgDependenceGraph::LinalgDependenceGraphElem>
335
335
findFusableProducer (OpOperand &consumerOpOperand,
336
336
const LinalgDependenceGraph &dependenceGraph) {
337
337
LLVM_DEBUG (llvm::dbgs () << " findFusableProducer for: "
@@ -340,7 +340,7 @@ findFusableProducer(OpOperand &consumerOpOperand,
340
340
<< *consumerOpOperand.getOwner () << " \n " );
341
341
LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner ());
342
342
if (!consumerOp)
343
- return {} ;
343
+ return failure () ;
344
344
345
345
// Only consider RAW and WAW atm.
346
346
for (auto depType : {
@@ -386,37 +386,37 @@ findFusableProducer(OpOperand &consumerOpOperand,
386
386
}
387
387
}
388
388
}
389
- return {} ;
389
+ return failure () ;
390
390
}
391
391
392
- Optional <FusionInfo>
392
+ FailureOr <FusionInfo>
393
393
mlir::linalg::fuseProducerOfBuffer (OpBuilder &b, OpOperand &consumerOpOperand,
394
394
const LinalgDependenceGraph &graph) {
395
395
Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
396
396
findFusableProducer (consumerOpOperand, graph);
397
397
if (!fusableDependence)
398
- return llvm::None ;
398
+ return failure () ;
399
399
400
400
LinalgOp producerOp = dyn_cast<LinalgOp>(fusableDependence->getDependentOp ());
401
401
if (!producerOp)
402
- return llvm::None ;
402
+ return failure () ;
403
403
404
404
// If producer is already in the same block as consumer, we are done.
405
405
if (consumerOpOperand.get ().getParentBlock () ==
406
406
fusableDependence->getDependentValue ().getParentBlock ())
407
- return llvm::None ;
407
+ return failure () ;
408
408
409
409
Optional<AffineMap> producerMap =
410
410
fusableDependence->getDependentOpViewIndexingMap ();
411
411
if (!producerMap)
412
- return llvm::None ;
412
+ return failure () ;
413
413
414
414
// Must be a subview or an extract_slice to guarantee there are loops we can
415
415
// fuse into.
416
416
auto subView = consumerOpOperand.get ().getDefiningOp <memref::SubViewOp>();
417
417
if (!subView) {
418
418
LLVM_DEBUG (llvm::dbgs () << " \n Not fusable (not a subview)" );
419
- return llvm::None ;
419
+ return failure () ;
420
420
}
421
421
422
422
// Fuse `producer` just before `consumer`.
@@ -459,28 +459,28 @@ static void getProducerOfTensor(Value tensor, OpResult &opResult) {
459
459
}
460
460
}
461
461
462
- Optional <FusionInfo>
462
+ FailureOr <FusionInfo>
463
463
mlir::linalg::fuseProducerOfTensor (OpBuilder &b, OpOperand &consumerOpOperand) {
464
464
Value inputTensor = consumerOpOperand.get ();
465
465
OpResult producerOpResult;
466
466
getProducerOfTensor (inputTensor, producerOpResult);
467
467
if (!producerOpResult) {
468
468
LLVM_DEBUG (llvm::dbgs () << " \n Unable to find producer" );
469
- return {} ;
469
+ return failure () ;
470
470
}
471
471
return fuseProducerOfTensor (b, producerOpResult, consumerOpOperand);
472
472
}
473
473
474
- Optional <FusionInfo>
474
+ FailureOr <FusionInfo>
475
475
mlir::linalg::fuseProducerOfTensor (OpBuilder &b, OpResult producerOpResult,
476
476
OpOperand &consumerOpOperand) {
477
477
auto producerOp = dyn_cast<LinalgOp>(producerOpResult.getOwner ());
478
478
if (!producerOp)
479
- return llvm::None ;
479
+ return failure () ;
480
480
481
481
LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner ());
482
482
if (!consumerOp)
483
- return llvm::None ;
483
+ return failure () ;
484
484
485
485
Value inputTensor = consumerOpOperand.get ();
486
486
@@ -489,13 +489,13 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
489
489
if (!sliceOp) {
490
490
LLVM_DEBUG (llvm::dbgs ()
491
491
<< " \n Not fusable, not an extract_slice op: " << inputTensor);
492
- return {} ;
492
+ return failure () ;
493
493
}
494
494
495
495
// If producer is already in the same block as consumer, we are done.
496
496
if (consumerOpOperand.get ().getParentBlock () ==
497
497
producerOpResult.getParentBlock ())
498
- return {} ;
498
+ return failure () ;
499
499
500
500
// Insert fused `producer` just before `consumer`.
501
501
OpBuilder::InsertionGuard g (b);
@@ -537,27 +537,27 @@ static AffineMap pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes,
537
537
// / - indexing map of the fused view in the producer : producerIndexMap
538
538
// / consumerLoopToProducerLoop =
539
539
// / inverse(producerIndexMap).compose(consumerIndexMap)
540
- static Optional <AffineMap> getConsumerLoopToProducerLoopMap (
540
+ static FailureOr <AffineMap> getConsumerLoopToProducerLoopMap (
541
541
LinalgDependenceGraph::LinalgDependenceGraphElem dependence) {
542
542
auto producer = dyn_cast<LinalgOp>(dependence.getDependentOp ());
543
543
if (!producer)
544
- return None ;
544
+ return failure () ;
545
545
546
546
Optional<AffineMap> producerIndexingMap =
547
547
dependence.getDependentOpViewIndexingMap ();
548
548
Optional<AffineMap> consumerIndexingMap =
549
549
dependence.getIndexingOpViewIndexingMap ();
550
550
if (!producerIndexingMap || !consumerIndexingMap)
551
- return None ;
551
+ return failure () ;
552
552
553
553
AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap (
554
554
producer.iterator_types ().getValue (), *producerIndexingMap);
555
555
if (!prunedProducerIndexingMap.isPermutation ())
556
- return None ;
556
+ return failure () ;
557
557
558
558
if (consumerIndexingMap->getNumResults () !=
559
559
prunedProducerIndexingMap.getNumResults ())
560
- return None ;
560
+ return failure () ;
561
561
562
562
LLVM_DEBUG ({
563
563
llvm::dbgs () << " \t producerMap : " ;
@@ -572,7 +572,7 @@ static Optional<AffineMap> getConsumerLoopToProducerLoopMap(
572
572
573
573
AffineMap invProducerIndexMap = inversePermutation (prunedProducerIndexingMap);
574
574
if (!invProducerIndexMap)
575
- return None ;
575
+ return failure () ;
576
576
577
577
return invProducerIndexMap.compose (*consumerIndexingMap);
578
578
}
@@ -776,7 +776,7 @@ FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
776
776
777
777
// / Tile the fused loops in the root operation, by setting the tile sizes for
778
778
// / all other loops to zero (those will be tiled later).
779
- static Optional <TiledLinalgOp>
779
+ static FailureOr <TiledLinalgOp>
780
780
tileRootOperation (OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizeVector,
781
781
const LinalgTilingOptions &options,
782
782
const std::set<unsigned > &fusedLoops) {
@@ -871,12 +871,12 @@ fuseOperations(OpBuilder &b, LinalgOp rootOp, TiledLinalgOp tiledLinalgOp,
871
871
return fusedOps;
872
872
}
873
873
874
- static Optional <TiledAndFusedLinalgOps>
874
+ static FailureOr <TiledAndFusedLinalgOps>
875
875
tileAndFuseLinalgOpsImpl (OpBuilder &b, ArrayRef<LinalgOp> ops,
876
876
const LinalgDependenceGraph &dependenceGraph,
877
877
const LinalgTilingOptions &tilingOptions) {
878
878
if (ops.size () < 2 )
879
- return llvm::None ;
879
+ return failure () ;
880
880
LinalgOp rootOp = ops.back ();
881
881
if (!llvm::all_of (
882
882
ops,
@@ -887,13 +887,13 @@ tileAndFuseLinalgOpsImpl(OpBuilder &b, ArrayRef<LinalgOp> ops,
887
887
rootOp.emitError (
888
888
" unable to fuse operations that have tensor semantics with operations "
889
889
" that have buffer semantics and viceversa." );
890
- return llvm::None ;
890
+ return failure () ;
891
891
}
892
892
// TODO: Support interchange with tile + fuse. This might actually help do
893
893
// better fusion.
894
894
if (!tilingOptions.interchangeVector .empty ()) {
895
895
rootOp.emitRemark (" unable to handle tile and fuse with interchange" );
896
- return llvm::None ;
896
+ return failure () ;
897
897
}
898
898
899
899
OpBuilder::InsertionGuard guard (b);
@@ -905,7 +905,7 @@ tileAndFuseLinalgOpsImpl(OpBuilder &b, ArrayRef<LinalgOp> ops,
905
905
findAllFusableDependences (ops, dependenceGraph);
906
906
if (fusableDependences.empty ()) {
907
907
LLVM_DEBUG (llvm::dbgs () << " no fusable dependencies found\n " );
908
- return llvm::None ;
908
+ return failure () ;
909
909
}
910
910
911
911
TiledAndFusedLinalgOps ret;
@@ -917,17 +917,17 @@ tileAndFuseLinalgOpsImpl(OpBuilder &b, ArrayRef<LinalgOp> ops,
917
917
// just return.
918
918
if (ret.fusedLoopDims .empty ()) {
919
919
LLVM_DEBUG (llvm::dbgs () << " no fusable loops found\n " );
920
- return llvm::None ;
920
+ return failure () ;
921
921
}
922
922
923
923
// Tile the fused loops in the last operation in the list.
924
924
SmallVector<Value, 4 > tileSizeVector =
925
925
tilingOptions.tileSizeComputationFunction (b, rootOp);
926
- Optional <TiledLinalgOp> tiledRootOp = tileRootOperation (
926
+ FailureOr <TiledLinalgOp> tiledRootOp = tileRootOperation (
927
927
b, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims );
928
- if (! tiledRootOp) {
928
+ if (failed ( tiledRootOp) ) {
929
929
rootOp.emitRemark (" failed to tile the fused loops" );
930
- return llvm::None ;
930
+ return failure () ;
931
931
}
932
932
ret.op = tiledRootOp->op ;
933
933
ret.fusedLoops .assign (tiledRootOp->loops .begin (), tiledRootOp->loops .end ());
@@ -939,7 +939,7 @@ tileAndFuseLinalgOpsImpl(OpBuilder &b, ArrayRef<LinalgOp> ops,
939
939
return ret;
940
940
}
941
941
942
- Optional <TiledAndFusedLinalgOps>
942
+ FailureOr <TiledAndFusedLinalgOps>
943
943
mlir::linalg::tileAndFuseLinalgOps (OpBuilder &b, ArrayRef<LinalgOp> ops,
944
944
const LinalgDependenceGraph &dependenceGraph,
945
945
const LinalgTilingOptions &tilingOptions) {
@@ -950,5 +950,5 @@ mlir::linalg::tileAndFuseLinalgOps(OpBuilder &b, ArrayRef<LinalgOp> ops,
950
950
return tileAndFuseLinalgOpsImpl (b, ops, dependenceGraph, tilingOptions);
951
951
default :;
952
952
}
953
- return llvm::None ;
953
+ return failure () ;
954
954
}
0 commit comments