Skip to content

Commit 489fec2

Browse files
[mlir][Linalg] NFC - Drop Optional in favor of FailureOr
Differential revision: https://reviews.llvm.org/D112332
1 parent 58e7ec4 commit 489fec2

File tree

7 files changed

+111
-105
lines changed

7 files changed

+111
-105
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,8 @@ struct TiledLinalgOp {
158158
SmallVector<Operation *, 8> loops;
159159
SmallVector<Value, 4> tensorResults;
160160
};
161-
Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
162-
const LinalgTilingOptions &options);
161+
FailureOr<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
162+
const LinalgTilingOptions &options);
163163

164164
/// Fuse a sequence of linalg operations (`ops`) using tile-and-fuse. This
165165
/// proceeds as follows:
@@ -221,7 +221,7 @@ struct TiledAndFusedLinalgOps {
221221
/// The fused loop generated.
222222
SmallVector<Operation *, 4> fusedLoops;
223223
};
224-
Optional<TiledAndFusedLinalgOps>
224+
FailureOr<TiledAndFusedLinalgOps>
225225
tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
226226
const LinalgDependenceGraph &dependenceGraph,
227227
const LinalgTilingOptions &tilingOptions);
@@ -344,7 +344,7 @@ struct PromotionInfo {
344344
Value fullLocalView;
345345
Value partialLocalView;
346346
};
347-
Optional<PromotionInfo>
347+
FailureOr<PromotionInfo>
348348
promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, memref::SubViewOp subView,
349349
AllocBufferCallbackFn allocationFn,
350350
DataLayout &layout);
@@ -359,24 +359,24 @@ promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, memref::SubViewOp subView,
359359
///
360360
/// Returns the modified linalg op (the modification happens in place) as well
361361
/// as all the copy ops created.
362-
Optional<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
363-
LinalgPromotionOptions options);
362+
FailureOr<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
363+
LinalgPromotionOptions options);
364364

365365
/// Emit a suitable vector form for a Linalg op with fully static shape.
366366
LogicalResult vectorizeLinalgOp(OpBuilder &builder, Operation *op,
367367
SmallVectorImpl<Value> &newResults);
368368

369369
/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
370-
Optional<LinalgLoops> linalgOpToLoops(PatternRewriter &rewriter,
371-
LinalgOp linalgOp);
370+
FailureOr<LinalgLoops> linalgOpToLoops(PatternRewriter &rewriter,
371+
LinalgOp linalgOp);
372372

373373
/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
374-
Optional<LinalgLoops> linalgOpToParallelLoops(PatternRewriter &rewriter,
375-
LinalgOp linalgOp);
374+
FailureOr<LinalgLoops> linalgOpToParallelLoops(PatternRewriter &rewriter,
375+
LinalgOp linalgOp);
376376

377377
/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
378-
Optional<LinalgLoops> linalgOpToAffineLoops(PatternRewriter &rewriter,
379-
LinalgOp linalgOp);
378+
FailureOr<LinalgLoops> linalgOpToAffineLoops(PatternRewriter &rewriter,
379+
LinalgOp linalgOp);
380380

381381
//===----------------------------------------------------------------------===//
382382
// Preconditions that ensure the corresponding transformation succeeds and can
@@ -961,15 +961,15 @@ struct LinalgLoweringPattern : public RewritePattern {
961961
// TODO: Move lowering to library calls here.
962962
return failure();
963963
case LinalgLoweringType::Loops:
964-
if (!linalgOpToLoops(rewriter, op))
964+
if (failed(linalgOpToLoops(rewriter, op)))
965965
return failure();
966966
break;
967967
case LinalgLoweringType::AffineLoops:
968-
if (!linalgOpToAffineLoops(rewriter, op))
968+
if (failed(linalgOpToAffineLoops(rewriter, op)))
969969
return failure();
970970
break;
971971
case LinalgLoweringType::ParallelLoops:
972-
if (!linalgOpToParallelLoops(rewriter, op))
972+
if (failed(linalgOpToParallelLoops(rewriter, op)))
973973
return failure();
974974
break;
975975
}

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -164,25 +164,25 @@ struct FusionInfo {
164164
/// Implements the fusion part of the "tileAndFuse on buffers" transformation
165165
/// and thus requires the `consumerOpOperand` to be a `subview` op (generally
166166
/// obtained by applying the tiling transformation).
167-
Optional<FusionInfo> fuseProducerOfBuffer(OpBuilder &b,
168-
OpOperand &consumerOpOperand,
169-
const LinalgDependenceGraph &graph);
167+
FailureOr<FusionInfo> fuseProducerOfBuffer(OpBuilder &b,
168+
OpOperand &consumerOpOperand,
169+
const LinalgDependenceGraph &graph);
170170
/// Tensor counterpart of `fuseProducerOfBuffer`.
171171
/// This implements the fusion part of the "tileAndFuse on tensors"
172172
/// transformation and thus requires the `consumerOpOperand` to be a
173173
/// `extract_slice` op (generally obtained by applying the tiling
174174
/// transformation).
175-
Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
176-
OpOperand &consumerOpOperand);
175+
FailureOr<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
176+
OpOperand &consumerOpOperand);
177177
/// Tensor counterpart of `fuseProducerOfBuffer`.
178178
/// This implements the fusion part of the "tileAndFuse on tensors"
179179
/// transformation and thus requires the `consumerOpOperand` to be a
180180
/// `extract_slice` op (generally obtained by applying the tiling
181181
/// transformation). Assumes `producerOfTensor` is a Linalg op that produces
182182
/// `consumerOpOperand`.
183-
Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
184-
OpResult producerOpResult,
185-
OpOperand &consumerOpOperand);
183+
FailureOr<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
184+
OpResult producerOpResult,
185+
OpOperand &consumerOpOperand);
186186

187187
//===----------------------------------------------------------------------===//
188188
// Fusion on tensor utilities

mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
331331
/// For `consumer` with buffer semantics, find the Linalg operation on buffers
332332
/// that is the last writer of `consumerOpOperand`. For now the fusable
333333
/// dependence is returned as an instance of the `dependenceGraph`.
334-
static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
334+
static FailureOr<LinalgDependenceGraph::LinalgDependenceGraphElem>
335335
findFusableProducer(OpOperand &consumerOpOperand,
336336
const LinalgDependenceGraph &dependenceGraph) {
337337
LLVM_DEBUG(llvm::dbgs() << "findFusableProducer for: "
@@ -340,7 +340,7 @@ findFusableProducer(OpOperand &consumerOpOperand,
340340
<< *consumerOpOperand.getOwner() << "\n");
341341
LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner());
342342
if (!consumerOp)
343-
return {};
343+
return failure();
344344

345345
// Only consider RAW and WAW atm.
346346
for (auto depType : {
@@ -386,37 +386,37 @@ findFusableProducer(OpOperand &consumerOpOperand,
386386
}
387387
}
388388
}
389-
return {};
389+
return failure();
390390
}
391391

392-
Optional<FusionInfo>
392+
FailureOr<FusionInfo>
393393
mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, OpOperand &consumerOpOperand,
394394
const LinalgDependenceGraph &graph) {
395395
Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
396396
findFusableProducer(consumerOpOperand, graph);
397397
if (!fusableDependence)
398-
return llvm::None;
398+
return failure();
399399

400400
LinalgOp producerOp = dyn_cast<LinalgOp>(fusableDependence->getDependentOp());
401401
if (!producerOp)
402-
return llvm::None;
402+
return failure();
403403

404404
// If producer is already in the same block as consumer, we are done.
405405
if (consumerOpOperand.get().getParentBlock() ==
406406
fusableDependence->getDependentValue().getParentBlock())
407-
return llvm::None;
407+
return failure();
408408

409409
Optional<AffineMap> producerMap =
410410
fusableDependence->getDependentOpViewIndexingMap();
411411
if (!producerMap)
412-
return llvm::None;
412+
return failure();
413413

414414
// Must be a subview or an extract_slice to guarantee there are loops we can
415415
// fuse into.
416416
auto subView = consumerOpOperand.get().getDefiningOp<memref::SubViewOp>();
417417
if (!subView) {
418418
LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview)");
419-
return llvm::None;
419+
return failure();
420420
}
421421

422422
// Fuse `producer` just before `consumer`.
@@ -459,28 +459,28 @@ static void getProducerOfTensor(Value tensor, OpResult &opResult) {
459459
}
460460
}
461461

462-
Optional<FusionInfo>
462+
FailureOr<FusionInfo>
463463
mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
464464
Value inputTensor = consumerOpOperand.get();
465465
OpResult producerOpResult;
466466
getProducerOfTensor(inputTensor, producerOpResult);
467467
if (!producerOpResult) {
468468
LLVM_DEBUG(llvm::dbgs() << "\nUnable to find producer");
469-
return {};
469+
return failure();
470470
}
471471
return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand);
472472
}
473473

474-
Optional<FusionInfo>
474+
FailureOr<FusionInfo>
475475
mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
476476
OpOperand &consumerOpOperand) {
477477
auto producerOp = dyn_cast<LinalgOp>(producerOpResult.getOwner());
478478
if (!producerOp)
479-
return llvm::None;
479+
return failure();
480480

481481
LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner());
482482
if (!consumerOp)
483-
return llvm::None;
483+
return failure();
484484

485485
Value inputTensor = consumerOpOperand.get();
486486

@@ -489,13 +489,13 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
489489
if (!sliceOp) {
490490
LLVM_DEBUG(llvm::dbgs()
491491
<< "\nNot fusable, not an extract_slice op: " << inputTensor);
492-
return {};
492+
return failure();
493493
}
494494

495495
// If producer is already in the same block as consumer, we are done.
496496
if (consumerOpOperand.get().getParentBlock() ==
497497
producerOpResult.getParentBlock())
498-
return {};
498+
return failure();
499499

500500
// Insert fused `producer` just before `consumer`.
501501
OpBuilder::InsertionGuard g(b);
@@ -537,27 +537,27 @@ static AffineMap pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes,
537537
/// - indexing map of the fused view in the producer : producerIndexMap
538538
/// consumerLoopToProducerLoop =
539539
/// inverse(producerIndexMap).compose(consumerIndexMap)
540-
static Optional<AffineMap> getConsumerLoopToProducerLoopMap(
540+
static FailureOr<AffineMap> getConsumerLoopToProducerLoopMap(
541541
LinalgDependenceGraph::LinalgDependenceGraphElem dependence) {
542542
auto producer = dyn_cast<LinalgOp>(dependence.getDependentOp());
543543
if (!producer)
544-
return None;
544+
return failure();
545545

546546
Optional<AffineMap> producerIndexingMap =
547547
dependence.getDependentOpViewIndexingMap();
548548
Optional<AffineMap> consumerIndexingMap =
549549
dependence.getIndexingOpViewIndexingMap();
550550
if (!producerIndexingMap || !consumerIndexingMap)
551-
return None;
551+
return failure();
552552

553553
AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
554554
producer.iterator_types().getValue(), *producerIndexingMap);
555555
if (!prunedProducerIndexingMap.isPermutation())
556-
return None;
556+
return failure();
557557

558558
if (consumerIndexingMap->getNumResults() !=
559559
prunedProducerIndexingMap.getNumResults())
560-
return None;
560+
return failure();
561561

562562
LLVM_DEBUG({
563563
llvm::dbgs() << "\t producerMap : ";
@@ -572,7 +572,7 @@ static Optional<AffineMap> getConsumerLoopToProducerLoopMap(
572572

573573
AffineMap invProducerIndexMap = inversePermutation(prunedProducerIndexingMap);
574574
if (!invProducerIndexMap)
575-
return None;
575+
return failure();
576576

577577
return invProducerIndexMap.compose(*consumerIndexingMap);
578578
}
@@ -776,7 +776,7 @@ FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
776776

777777
/// Tile the fused loops in the root operation, by setting the tile sizes for
778778
/// all other loops to zero (those will be tiled later).
779-
static Optional<TiledLinalgOp>
779+
static FailureOr<TiledLinalgOp>
780780
tileRootOperation(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizeVector,
781781
const LinalgTilingOptions &options,
782782
const std::set<unsigned> &fusedLoops) {
@@ -871,12 +871,12 @@ fuseOperations(OpBuilder &b, LinalgOp rootOp, TiledLinalgOp tiledLinalgOp,
871871
return fusedOps;
872872
}
873873

874-
static Optional<TiledAndFusedLinalgOps>
874+
static FailureOr<TiledAndFusedLinalgOps>
875875
tileAndFuseLinalgOpsImpl(OpBuilder &b, ArrayRef<LinalgOp> ops,
876876
const LinalgDependenceGraph &dependenceGraph,
877877
const LinalgTilingOptions &tilingOptions) {
878878
if (ops.size() < 2)
879-
return llvm::None;
879+
return failure();
880880
LinalgOp rootOp = ops.back();
881881
if (!llvm::all_of(
882882
ops,
@@ -887,13 +887,13 @@ tileAndFuseLinalgOpsImpl(OpBuilder &b, ArrayRef<LinalgOp> ops,
887887
rootOp.emitError(
888888
"unable to fuse operations that have tensor semantics with operations "
889889
"that have buffer semantics and viceversa.");
890-
return llvm::None;
890+
return failure();
891891
}
892892
// TODO: Support interchange with tile + fuse. This might actually help do
893893
// better fusion.
894894
if (!tilingOptions.interchangeVector.empty()) {
895895
rootOp.emitRemark("unable to handle tile and fuse with interchange");
896-
return llvm::None;
896+
return failure();
897897
}
898898

899899
OpBuilder::InsertionGuard guard(b);
@@ -905,7 +905,7 @@ tileAndFuseLinalgOpsImpl(OpBuilder &b, ArrayRef<LinalgOp> ops,
905905
findAllFusableDependences(ops, dependenceGraph);
906906
if (fusableDependences.empty()) {
907907
LLVM_DEBUG(llvm::dbgs() << "no fusable dependencies found\n");
908-
return llvm::None;
908+
return failure();
909909
}
910910

911911
TiledAndFusedLinalgOps ret;
@@ -917,17 +917,17 @@ tileAndFuseLinalgOpsImpl(OpBuilder &b, ArrayRef<LinalgOp> ops,
917917
// just return.
918918
if (ret.fusedLoopDims.empty()) {
919919
LLVM_DEBUG(llvm::dbgs() << "no fusable loops found\n");
920-
return llvm::None;
920+
return failure();
921921
}
922922

923923
// Tile the fused loops in the last operation in the list.
924924
SmallVector<Value, 4> tileSizeVector =
925925
tilingOptions.tileSizeComputationFunction(b, rootOp);
926-
Optional<TiledLinalgOp> tiledRootOp = tileRootOperation(
926+
FailureOr<TiledLinalgOp> tiledRootOp = tileRootOperation(
927927
b, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims);
928-
if (!tiledRootOp) {
928+
if (failed(tiledRootOp)) {
929929
rootOp.emitRemark("failed to tile the fused loops");
930-
return llvm::None;
930+
return failure();
931931
}
932932
ret.op = tiledRootOp->op;
933933
ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
@@ -939,7 +939,7 @@ tileAndFuseLinalgOpsImpl(OpBuilder &b, ArrayRef<LinalgOp> ops,
939939
return ret;
940940
}
941941

942-
Optional<TiledAndFusedLinalgOps>
942+
FailureOr<TiledAndFusedLinalgOps>
943943
mlir::linalg::tileAndFuseLinalgOps(OpBuilder &b, ArrayRef<LinalgOp> ops,
944944
const LinalgDependenceGraph &dependenceGraph,
945945
const LinalgTilingOptions &tilingOptions) {
@@ -950,5 +950,5 @@ mlir::linalg::tileAndFuseLinalgOps(OpBuilder &b, ArrayRef<LinalgOp> ops,
950950
return tileAndFuseLinalgOpsImpl(b, ops, dependenceGraph, tilingOptions);
951951
default:;
952952
}
953-
return llvm::None;
953+
return failure();
954954
}

0 commit comments

Comments
 (0)