Skip to content

Commit 26821f7

Browse files
committed
[mlir][NFC] accept plain OpBuidler in folded construction helpers
A group of functions in the Affine dialect provides a mechanism for buliding folded-by-construction operations. These functions used to accept a `RewriterBase` reference because they may need to erase the operations that were folded and notify the rewriter when called from rewrite patterns. Adopt a different approach: postpone the builder notification of the op creation until we are certain that the op will not be folded away. This removes the need to notify the rewriter about op deletion following op construction in case of successful folding, and removes a bunch of one-off `IRRewriter` instances in transform code that may mess up insertion points. Reviewed By: springerm, mravishankar Differential Revision: https://reviews.llvm.org/D130616
1 parent ad16268 commit 26821f7

File tree

4 files changed

+58
-48
lines changed

4 files changed

+58
-48
lines changed

mlir/include/mlir/Dialect/Affine/IR/AffineOps.h

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ namespace mlir {
2525
class AffineApplyOp;
2626
class AffineBound;
2727
class AffineValueMap;
28-
class RewriterBase;
2928

3029
/// TODO: These should be renamed if they are on the mlir namespace.
3130
/// Ideally, they should go in a mlir::affine:: namespace.
@@ -384,21 +383,20 @@ AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineExpr e,
384383
/// Constructs an AffineApplyOp that applies `map` to `operands` after composing
385384
/// the map with the maps of any other AffineApplyOp supplying the operands,
386385
/// then immediately attempts to fold it. If folding results in a constant
387-
/// value, erases all created ops. The `map` must be a single-result affine map.
388-
OpFoldResult makeComposedFoldedAffineApply(RewriterBase &b, Location loc,
386+
/// value, no ops are actually created. The `map` must be a single-result affine
387+
/// map.
388+
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
389389
AffineMap map,
390390
ArrayRef<OpFoldResult> operands);
391391
/// Variant of `makeComposedFoldedAffineApply` that applies to an expression.
392-
OpFoldResult makeComposedFoldedAffineApply(RewriterBase &b, Location loc,
392+
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
393393
AffineExpr expr,
394394
ArrayRef<OpFoldResult> operands);
395395
/// Variant of `makeComposedFoldedAffineApply` suitable for multi-result maps.
396396
/// Note that this may create as many affine.apply operations as the map has
397397
/// results given that affine.apply must be single-result.
398-
SmallVector<OpFoldResult>
399-
makeComposedFoldedMultiResultAffineApply(RewriterBase &b, Location loc,
400-
AffineMap map,
401-
ArrayRef<OpFoldResult> operands);
398+
SmallVector<OpFoldResult> makeComposedFoldedMultiResultAffineApply(
399+
OpBuilder &b, Location loc, AffineMap map, ArrayRef<OpFoldResult> operands);
402400

403401
/// Returns an AffineMinOp obtained by composing `map` and `operands` with
404402
/// AffineApplyOps supplying those operands.
@@ -407,15 +405,15 @@ Value makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map,
407405

408406
/// Constructs an AffineMinOp that computes a minimum across the results of
409407
/// applying `map` to `operands`, then immediately attempts to fold it. If
410-
/// folding results in a constant value, erases all created ops.
411-
OpFoldResult makeComposedFoldedAffineMin(RewriterBase &b, Location loc,
408+
/// folding results in a constant value, no ops are actually created.
409+
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc,
412410
AffineMap map,
413411
ArrayRef<OpFoldResult> operands);
414412

415413
/// Constructs an AffineMinOp that computes a maximum across the results of
416414
/// applying `map` to `operands`, then immediately attempts to fold it. If
417-
/// folding results in a constant value, erases all created ops.
418-
OpFoldResult makeComposedFoldedAffineMax(RewriterBase &b, Location loc,
415+
/// folding results in a constant value, no ops are actually created.
416+
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc,
419417
AffineMap map,
420418
ArrayRef<OpFoldResult> operands);
421419

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/IR/OpDefinition.h"
1818
#include "mlir/IR/PatternMatch.h"
1919
#include "mlir/Transforms/InliningUtils.h"
20+
#include "llvm/ADT/ScopeExit.h"
2021
#include "llvm/ADT/SmallBitVector.h"
2122
#include "llvm/ADT/TypeSwitch.h"
2223
#include "llvm/Support/Debug.h"
@@ -709,11 +710,19 @@ void mlir::fullyComposeAffineMapAndOperands(AffineMap *map,
709710
/// Given a list of `OpFoldResult`, build the necessary operations to populate
710711
/// `actualValues` with values produced by operations. In particular, for any
711712
/// attribute-typed element in `values`, call the constant materializer
712-
/// associated with the Affine dialect to produce an operation.
713+
/// associated with the Affine dialect to produce an operation. Do NOT notify
714+
/// the builder listener about the constant ops being created as they are
715+
/// intended to be removed after being folded into affine constructs; this is
716+
/// not suitable for use beyond the Affine dialect.
713717
static void materializeConstants(OpBuilder &b, Location loc,
714718
ArrayRef<OpFoldResult> values,
715719
SmallVectorImpl<Operation *> &constants,
716720
SmallVectorImpl<Value> &actualValues) {
721+
OpBuilder::Listener *listener = b.getListener();
722+
b.setListener(nullptr);
723+
auto listenerResetter =
724+
llvm::make_scope_exit([listener, &b] { b.setListener(listener); });
725+
717726
actualValues.reserve(values.size());
718727
auto *dialect = b.getContext()->getLoadedDialect<AffineDialect>();
719728
for (OpFoldResult ofr : values) {
@@ -742,7 +751,7 @@ static void materializeConstants(OpBuilder &b, Location loc,
742751
template <typename OpTy, typename... Args>
743752
static std::enable_if_t<OpTy::template hasTrait<OpTrait::OneResult>(),
744753
OpFoldResult>
745-
createOrFold(RewriterBase &b, Location loc, ValueRange operands,
754+
createOrFold(OpBuilder &b, Location loc, ValueRange operands,
746755
Args &&...leadingArguments) {
747756
// Identify the constant operands and extract their values as attributes.
748757
// Note that we cannot use the original values directly because the list of
@@ -759,17 +768,30 @@ createOrFold(RewriterBase &b, Location loc, ValueRange operands,
759768

760769
// Create the operation and immediately attempt to fold it. On success,
761770
// delete the operation and prepare the (unmaterialized) value for being
762-
// returned. On failure, return the operation result value.
771+
// returned. On failure, return the operation result value. Temporarily remove
772+
// the listener to avoid notifying it when the op is created as it may be
773+
// removed immediately and there is no way of notifying the caller about that
774+
// without resorting to RewriterBase.
775+
//
763776
// TODO: arguably, the main folder (createOrFold) API should support this use
764777
// case instead of indiscriminately materializing constants.
778+
OpBuilder::Listener *listener = b.getListener();
779+
b.setListener(nullptr);
780+
auto listenerResetter =
781+
llvm::make_scope_exit([listener, &b] { b.setListener(listener); });
765782
OpTy op =
766783
b.create<OpTy>(loc, std::forward<Args>(leadingArguments)..., operands);
767784
SmallVector<OpFoldResult, 1> foldResults;
768785
if (succeeded(op->fold(constantOperands, foldResults)) &&
769786
!foldResults.empty()) {
770-
b.eraseOp(op);
787+
op->erase();
771788
return foldResults.front();
772789
}
790+
791+
// Notify the listener now that we definitely know that the operation will
792+
// persist. Use the original listener stored in the variable.
793+
if (listener)
794+
listener->notifyOperationInserted(op);
773795
return op->getResult(0);
774796
}
775797

@@ -821,8 +843,7 @@ static void composeMultiResultAffineMap(AffineMap &map,
821843
}
822844

823845
OpFoldResult
824-
mlir::makeComposedFoldedAffineApply(RewriterBase &b, Location loc,
825-
AffineMap map,
846+
mlir::makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map,
826847
ArrayRef<OpFoldResult> operands) {
827848
assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");
828849

@@ -835,21 +856,20 @@ mlir::makeComposedFoldedAffineApply(RewriterBase &b, Location loc,
835856
// Constants are always folded into affine min/max because they can be
836857
// represented as constant expressions, so delete them.
837858
for (Operation *op : constants)
838-
b.eraseOp(op);
859+
op->erase();
839860
return result;
840861
}
841862

842863
OpFoldResult
843-
mlir::makeComposedFoldedAffineApply(RewriterBase &b, Location loc,
844-
AffineExpr expr,
864+
mlir::makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineExpr expr,
845865
ArrayRef<OpFoldResult> operands) {
846866
return makeComposedFoldedAffineApply(
847867
b, loc, AffineMap::inferFromExprList(ArrayRef<AffineExpr>{expr}).front(),
848868
operands);
849869
}
850870

851871
SmallVector<OpFoldResult> mlir::makeComposedFoldedMultiResultAffineApply(
852-
RewriterBase &b, Location loc, AffineMap map,
872+
OpBuilder &b, Location loc, AffineMap map,
853873
ArrayRef<OpFoldResult> operands) {
854874
return llvm::to_vector(llvm::map_range(
855875
llvm::seq<unsigned>(0, map.getNumResults()), [&](unsigned i) {
@@ -866,7 +886,7 @@ Value mlir::makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map,
866886
}
867887

868888
template <typename OpTy>
869-
static OpFoldResult makeComposedFoldedMinMax(RewriterBase &b, Location loc,
889+
static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc,
870890
AffineMap map,
871891
ArrayRef<OpFoldResult> operands) {
872892
SmallVector<Operation *> constants;
@@ -879,18 +899,18 @@ static OpFoldResult makeComposedFoldedMinMax(RewriterBase &b, Location loc,
879899
// Constants are always folded into affine min/max because they can be
880900
// represented as constant expressions, so delete them.
881901
for (Operation *op : constants)
882-
b.eraseOp(op);
902+
op->erase();
883903
return result;
884904
}
885905

886906
OpFoldResult
887-
mlir::makeComposedFoldedAffineMin(RewriterBase &b, Location loc, AffineMap map,
907+
mlir::makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map,
888908
ArrayRef<OpFoldResult> operands) {
889909
return makeComposedFoldedMinMax<AffineMinOp>(b, loc, map, operands);
890910
}
891911

892912
OpFoldResult
893-
mlir::makeComposedFoldedAffineMax(RewriterBase &b, Location loc, AffineMap map,
913+
mlir::makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map,
894914
ArrayRef<OpFoldResult> operands) {
895915
return makeComposedFoldedMinMax<AffineMaxOp>(b, loc, map, operands);
896916
}

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -132,12 +132,11 @@ mlir::linalg::computeMultiTileSizes(OpBuilder &builder, LinalgOp op,
132132
SmallVector<OpFoldResult> allShapes =
133133
op.createFlatListOfOperandDims(b, b.getLoc());
134134
AffineMap shapesToLoops = op.getShapesToLoopsMap();
135-
IRRewriter rewriter(b);
136135
SmallVector<OpFoldResult> loopRanges =
137-
makeComposedFoldedMultiResultAffineApply(rewriter, op.getLoc(),
138-
shapesToLoops, allShapes);
136+
makeComposedFoldedMultiResultAffineApply(b, op.getLoc(), shapesToLoops,
137+
allShapes);
139138
Value tripCount =
140-
materializeOpFoldResult(rewriter, op.getLoc(), loopRanges[dimension]);
139+
materializeOpFoldResult(b, op.getLoc(), loopRanges[dimension]);
141140

142141
// Compute the tile sizes and the respective numbers of tiles.
143142
AffineExpr s0 = b.getAffineSymbolExpr(0);
@@ -206,19 +205,17 @@ static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
206205
/// Build an `affine_max` of all the `vals`.
207206
static OpFoldResult buildMax(OpBuilder &b, Location loc,
208207
ArrayRef<OpFoldResult> vals) {
209-
IRRewriter rewriter(b);
210208
return makeComposedFoldedAffineMax(
211-
rewriter, loc,
212-
AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()), vals);
209+
b, loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()),
210+
vals);
213211
}
214212

215213
/// Build an `affine_min` of all the `vals`.
216214
static OpFoldResult buildMin(OpBuilder &b, Location loc,
217215
ArrayRef<OpFoldResult> vals) {
218-
IRRewriter rewriter(b);
219216
return makeComposedFoldedAffineMin(
220-
rewriter, loc,
221-
AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()), vals);
217+
b, loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()),
218+
vals);
222219
}
223220

224221
/// Rewrite a TilingInterface `op` to a tiled `scf.foreach_thread`. The
@@ -386,7 +383,7 @@ linalg::tileToForeachThreadOpUsingTileSizes(
386383
// Insert a tile `source` into the destination tensor `dest`. The position at
387384
// which the tile is inserted (as well as size of tile) is taken from a given
388385
// ExtractSliceOp `sliceOp`.
389-
static Value insertSliceIntoTensor(RewriterBase &b, Location loc,
386+
static Value insertSliceIntoTensor(OpBuilder &b, Location loc,
390387
tensor::ExtractSliceOp sliceOp, Value source,
391388
Value dest) {
392389
return b.create<tensor::InsertSliceOp>(
@@ -478,10 +475,9 @@ tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
478475
static_cast<size_t>(op.getNumInputsAndOutputs()) &&
479476
"expect the number of operands and inputs and outputs to match");
480477
SmallVector<Value> valuesToTile = operandValuesToUse;
481-
IRRewriter rewriter(b);
482478
SmallVector<OpFoldResult> sizeBounds =
483-
makeComposedFoldedMultiResultAffineApply(
484-
rewriter, loc, shapeSizesToLoopsMap, allShapeSizes);
479+
makeComposedFoldedMultiResultAffineApply(b, loc, shapeSizesToLoopsMap,
480+
allShapeSizes);
485481
SmallVector<Value> tiledOperands = makeTiledShapes(
486482
b, loc, op, valuesToTile, getAsOpFoldResult(interchangedIvs), tileSizes,
487483
sizeBounds,
@@ -616,10 +612,8 @@ static LogicalResult tilePadOp(RewriterBase &builder, tensor::PadOp op,
616612
auto sliceOp = tiledOutput.getDefiningOp<tensor::ExtractSliceOp>();
617613
assert(sliceOp && "expected ExtractSliceOp");
618614
// Insert the tile into the output tensor.
619-
// TODO: Propagate RewriterBase everywhere.
620-
IRRewriter rewriter(b);
621615
Value yieldValue =
622-
insertSliceIntoTensor(rewriter, loc, sliceOp, sliceOp, iterArgs[0]);
616+
insertSliceIntoTensor(b, loc, sliceOp, sliceOp, iterArgs[0]);
623617
return scf::ValueVector({yieldValue});
624618
});
625619
return success();

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,10 @@ struct LinalgOpTilingInterface
108108
linalgOp.createFlatListOfOperandDims(b, loc);
109109
AffineMap map = linalgOp.getShapesToLoopsMap();
110110

111-
IRRewriter rewriter(b);
112111
return llvm::to_vector(
113112
llvm::map_range(map.getResults(), [&](AffineExpr loopExpr) {
114-
OpFoldResult ofr = makeComposedFoldedAffineApply(
115-
rewriter, loc, loopExpr, allShapesSizes);
113+
OpFoldResult ofr =
114+
makeComposedFoldedAffineApply(b, loc, loopExpr, allShapesSizes);
116115
return Range{b.getIndexAttr(0), ofr, b.getIndexAttr(1)};
117116
}));
118117
}
@@ -156,10 +155,9 @@ struct LinalgOpTilingInterface
156155

157156
AffineExpr d0;
158157
bindDims(b.getContext(), d0);
159-
IRRewriter rewriter(b);
160158
SmallVector<OpFoldResult> subShapeSizes =
161159
llvm::to_vector(llvm::map_range(sizes, [&](OpFoldResult ofr) {
162-
return makeComposedFoldedAffineApply(rewriter, loc, d0 - 1, ofr);
160+
return makeComposedFoldedAffineApply(b, loc, d0 - 1, ofr);
163161
}));
164162

165163
OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber);

0 commit comments

Comments
 (0)