Skip to content

Commit e99fae8

Browse files
committed
[mlir] more aggressive folding in tiling/fusion transformations
Combine the recently added utilities for folded-by-construction affine operations with the attribute-based Range to enable more folding. This decreases the amount of emitted code but has little effect on tests precisely because the tests are not checking for the spurious constants. The difference in the shape of affine maps comes from the internals of affine folding. Depends on D129633 Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D130167
1 parent 70e99f3 commit e99fae8

21 files changed

+353
-322
lines changed

mlir/include/mlir/Dialect/Affine/IR/AffineOps.h

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,13 @@ OpFoldResult makeComposedFoldedAffineApply(RewriterBase &b, Location loc,
392392
OpFoldResult makeComposedFoldedAffineApply(RewriterBase &b, Location loc,
393393
AffineExpr expr,
394394
ArrayRef<OpFoldResult> operands);
395+
/// Variant of `makeComposedFoldedAffineApply` suitable for multi-result maps.
396+
/// Note that this may create as many affine.apply operations as the map has
397+
/// results given that affine.apply must be single-result.
398+
SmallVector<OpFoldResult>
399+
makeComposedFoldedMultiResultAffineApply(RewriterBase &b, Location loc,
400+
AffineMap map,
401+
ArrayRef<OpFoldResult> operands);
395402

396403
/// Returns an AffineMinOp obtained by composing `map` and `operands` with
397404
/// AffineApplyOps supplying those operands.
@@ -405,16 +412,17 @@ OpFoldResult makeComposedFoldedAffineMin(RewriterBase &b, Location loc,
405412
AffineMap map,
406413
ArrayRef<OpFoldResult> operands);
407414

415+
/// Constructs an AffineMaxOp that computes a maximum across the results of
416+
/// applying `map` to `operands`, then immediately attempts to fold it. If
417+
/// folding results in a constant value, erases all created ops.
418+
OpFoldResult makeComposedFoldedAffineMax(RewriterBase &b, Location loc,
419+
AffineMap map,
420+
ArrayRef<OpFoldResult> operands);
421+
408422
/// Returns the values obtained by applying `map` to the list of values.
409423
SmallVector<Value, 4> applyMapToValues(OpBuilder &b, Location loc,
410424
AffineMap map, ValueRange values);
411425

412-
/// Returns the values obtained by applying `map` to the list of values, which
413-
/// may be known constants.
414-
SmallVector<OpFoldResult> applyMapToValues(RewriterBase &b, Location loc,
415-
AffineMap map,
416-
ArrayRef<OpFoldResult> values);
417-
418426
/// Given an affine map `map` and its input `operands`, this method composes
419427
/// into `map`, maps of AffineApplyOps whose results are the values in
420428
/// `operands`, iteratively until no more of `operands` are the result of an

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1133,7 +1133,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
11331133
let extraClassDeclaration = [{
11341134
/// Return the flat list of all operand dimension sizes in the order they
11351135
/// appear in the operands.
1136-
SmallVector<Value, 4> createFlatListOfOperandDims(OpBuilder &, Location);
1136+
SmallVector<OpFoldResult> createFlatListOfOperandDims(OpBuilder &, Location);
11371137

11381138
/// Return the flat list of all operands' static dimension sizes in the
11391139
/// order they appear in the operands. All operand dimension sizes have to

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,8 @@ using TileSizeComputationFunction =
410410
using LoopIndexToRangeIndexMap = DenseMap<int, int>;
411411
std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
412412
makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
413-
ValueRange allShapeSizes, ValueRange allTileSizes);
413+
ArrayRef<OpFoldResult> allShapeSizes,
414+
ArrayRef<OpFoldResult> allTileSizes);
414415

415416
/// A description of a multi-size tiling comprising tile sizes and numbers of
416417
/// tiles, expressed as Values which may or may not be constant. Multi-size

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ bool isPermutation(ArrayRef<int64_t> permutation);
4848
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
4949
/// the type of `source`.
5050
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
51+
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value source,
52+
int64_t dim);
5153

5254
/// Given an operation, retrieves the value of each dynamic dimension through
5355
/// constructing the necessary DimOp operators.
@@ -179,16 +181,17 @@ bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer,
179181

180182
/// Computes tile offsets, given a list of loop `ivs` and `tileSizes`. In case a
181183
/// tile size is zero (i.e., no tiling), the corresponding offset is also zero.
182-
SmallVector<Value> computeTileOffsets(OpBuilder &b, Location loc,
183-
ValueRange ivs, ValueRange tileSizes);
184+
SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
185+
ArrayRef<OpFoldResult> ivs,
186+
ArrayRef<OpFoldResult> tileSizes);
184187

185188
/// Computes tile sizes, given a list of `tileSizes` and dimension
186189
/// sizes (`sizeBounds`). In case a tile size is zero (i.e., no tiling), the
187190
/// corresponding result size is the corresponding value from `sizeBounds`.
188191
/// Note: The returned tile sizes are closed intervals.
189-
SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc,
190-
ValueRange tileSizes,
191-
ArrayRef<Value> sizeBounds);
192+
SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc,
193+
ArrayRef<OpFoldResult> tileSizes,
194+
ArrayRef<OpFoldResult> sizeBounds);
192195

193196
/// Returns the list of tensor output types produced when the given structured
194197
/// operation `op` is applied to the given `operands`. Note that `operands` are
@@ -217,8 +220,9 @@ Value materializeOpFoldResult(OpBuilder &b, Location loc,
217220
/// controls whether to omit the partial/boundary tile condition check in cases
218221
/// where we statically know that it is unnecessary.
219222
Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
220-
ValueRange tileSizes, AffineMap map, ValueRange lbs,
221-
ValueRange ubs, ValueRange subShapeSizes,
223+
ArrayRef<OpFoldResult> tileSizes, AffineMap map,
224+
ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
225+
ArrayRef<OpFoldResult> subShapeSizes,
222226
bool omitPartialTileCheck);
223227

224228
/// Creates extract_slice/subview ops for all `valuesToTile` of the given
@@ -232,18 +236,20 @@ Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
232236
/// Note that a constant zero in `tileSizes` means no tiling at that implicit
233237
/// loop. The number of non-zero values in `tileSizes` should be equal to the
234238
/// number of values in `ivs`.
235-
SmallVector<Value, 4> makeTiledShapes(OpBuilder &builder, Location loc,
236-
LinalgOp linalgOp,
237-
ArrayRef<Value> valuesToTile,
238-
ValueRange ivs, ValueRange tileSizes,
239-
ArrayRef<Value> sizeBounds,
240-
bool omitPartialTileCheck);
239+
SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc,
240+
LinalgOp linalgOp, ValueRange valuesToTile,
241+
ArrayRef<OpFoldResult> ivs,
242+
ArrayRef<OpFoldResult> tileSizes,
243+
ArrayRef<OpFoldResult> sizeBounds,
244+
bool omitPartialTileCheck);
241245

242246
/// Add the specified offsets to any `linalg.index` ops contained in the given
243247
/// `linalgOp`. The offsets are provided in the same order as iteration space
244248
/// dimensions. Null offsets are assumed to be zero.
245-
void offsetIndices(OpBuilder &b, LinalgOp linalgOp, ArrayRef<Value> offests);
246-
void offsetIndices(RewriterBase &b, LinalgOp linalgOp, ArrayRef<Value> offests);
249+
void offsetIndices(OpBuilder &b, LinalgOp linalgOp,
250+
ArrayRef<OpFoldResult> offests);
251+
void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
252+
ArrayRef<OpFoldResult> offests);
247253

248254
using FusableOpDependencesTy = llvm::MapVector<
249255
Operation *,

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 60 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -790,33 +790,6 @@ AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc,
790790
values);
791791
}
792792

793-
OpFoldResult
794-
mlir::makeComposedFoldedAffineApply(RewriterBase &b, Location loc,
795-
AffineMap map,
796-
ArrayRef<OpFoldResult> operands) {
797-
assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");
798-
799-
SmallVector<Operation *> constants;
800-
SmallVector<Value> actualValues;
801-
materializeConstants(b, loc, operands, constants, actualValues);
802-
composeAffineMapAndOperands(&map, &actualValues);
803-
OpFoldResult result = createOrFold<AffineApplyOp>(b, loc, actualValues, map);
804-
if (result.is<Attribute>()) {
805-
for (Operation *op : constants)
806-
b.eraseOp(op);
807-
}
808-
return result;
809-
}
810-
811-
OpFoldResult
812-
mlir::makeComposedFoldedAffineApply(RewriterBase &b, Location loc,
813-
AffineExpr expr,
814-
ArrayRef<OpFoldResult> operands) {
815-
return makeComposedFoldedAffineApply(
816-
b, loc, AffineMap::inferFromExprList(ArrayRef<AffineExpr>{expr}).front(),
817-
operands);
818-
}
819-
820793
/// Composes the given affine map with the given list of operands, pulling in
821794
/// the maps from any affine.apply operations that supply the operands.
822795
static void composeMultiResultAffineMap(AffineMap &map,
@@ -847,29 +820,81 @@ static void composeMultiResultAffineMap(AffineMap &map,
847820
canonicalizeMapAndOperands(&map, &operands);
848821
}
849822

823+
OpFoldResult
824+
mlir::makeComposedFoldedAffineApply(RewriterBase &b, Location loc,
825+
AffineMap map,
826+
ArrayRef<OpFoldResult> operands) {
827+
assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");
828+
829+
SmallVector<Operation *> constants;
830+
SmallVector<Value> actualValues;
831+
materializeConstants(b, loc, operands, constants, actualValues);
832+
composeAffineMapAndOperands(&map, &actualValues);
833+
OpFoldResult result = createOrFold<AffineApplyOp>(b, loc, actualValues, map);
834+
835+
// Constants are always folded into the affine.apply map because they can be
836+
// represented as constant expressions, so delete them.
837+
for (Operation *op : constants)
838+
b.eraseOp(op);
839+
return result;
840+
}
841+
842+
OpFoldResult
843+
mlir::makeComposedFoldedAffineApply(RewriterBase &b, Location loc,
844+
AffineExpr expr,
845+
ArrayRef<OpFoldResult> operands) {
846+
return makeComposedFoldedAffineApply(
847+
b, loc, AffineMap::inferFromExprList(ArrayRef<AffineExpr>{expr}).front(),
848+
operands);
849+
}
850+
851+
SmallVector<OpFoldResult> mlir::makeComposedFoldedMultiResultAffineApply(
852+
RewriterBase &b, Location loc, AffineMap map,
853+
ArrayRef<OpFoldResult> operands) {
854+
return llvm::to_vector(llvm::map_range(
855+
llvm::seq<unsigned>(0, map.getNumResults()), [&](unsigned i) {
856+
return makeComposedFoldedAffineApply(b, loc, map.getSubMap({i}),
857+
operands);
858+
}));
859+
}
860+
850861
Value mlir::makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map,
851862
ValueRange operands) {
852863
SmallVector<Value> allOperands = llvm::to_vector(operands);
853864
composeMultiResultAffineMap(map, allOperands);
854865
return b.createOrFold<AffineMinOp>(loc, b.getIndexType(), map, allOperands);
855866
}
856867

857-
OpFoldResult
858-
mlir::makeComposedFoldedAffineMin(RewriterBase &b, Location loc, AffineMap map,
859-
ArrayRef<OpFoldResult> operands) {
868+
template <typename OpTy>
869+
static OpFoldResult makeComposedFoldedMinMax(RewriterBase &b, Location loc,
870+
AffineMap map,
871+
ArrayRef<OpFoldResult> operands) {
860872
SmallVector<Operation *> constants;
861873
SmallVector<Value> actualValues;
862874
materializeConstants(b, loc, operands, constants, actualValues);
863875
composeMultiResultAffineMap(map, actualValues);
864876
OpFoldResult result =
865-
createOrFold<AffineMinOp>(b, loc, actualValues, b.getIndexType(), map);
866-
if (result.is<Attribute>()) {
867-
for (Operation *op : constants)
868-
b.eraseOp(op);
869-
}
877+
createOrFold<OpTy>(b, loc, actualValues, b.getIndexType(), map);
878+
879+
// Constants are always folded into affine min/max because they can be
880+
// represented as constant expressions, so delete them.
881+
for (Operation *op : constants)
882+
b.eraseOp(op);
870883
return result;
871884
}
872885

886+
OpFoldResult
887+
mlir::makeComposedFoldedAffineMin(RewriterBase &b, Location loc, AffineMap map,
888+
ArrayRef<OpFoldResult> operands) {
889+
return makeComposedFoldedMinMax<AffineMinOp>(b, loc, map, operands);
890+
}
891+
892+
OpFoldResult
893+
mlir::makeComposedFoldedAffineMax(RewriterBase &b, Location loc, AffineMap map,
894+
ArrayRef<OpFoldResult> operands) {
895+
return makeComposedFoldedMinMax<AffineMaxOp>(b, loc, map, operands);
896+
}
897+
873898
/// Fully compose map with operands and canonicalize the result.
874899
/// Return the `createOrFold`'ed AffineApply op.
875900
static Value createFoldedComposedAffineApply(OpBuilder &b, Location loc,
@@ -896,40 +921,6 @@ SmallVector<Value, 4> mlir::applyMapToValues(OpBuilder &b, Location loc,
896921
return res;
897922
}
898923

899-
SmallVector<OpFoldResult>
900-
mlir::applyMapToValues(RewriterBase &b, Location loc, AffineMap map,
901-
ArrayRef<OpFoldResult> values) {
902-
// Materialize constants and keep track of produced operations so we can clean
903-
// them up later.
904-
SmallVector<Operation *> constants;
905-
SmallVector<Value> actualValues;
906-
materializeConstants(b, loc, values, constants, actualValues);
907-
908-
// Compose, fold and construct maps for each result independently because they
909-
// may simplify more effectively.
910-
SmallVector<OpFoldResult> results;
911-
results.reserve(map.getNumResults());
912-
bool foldedAll = true;
913-
for (auto i : llvm::seq<unsigned>(0, map.getNumResults())) {
914-
AffineMap submap = map.getSubMap({i});
915-
SmallVector<Value> operands = actualValues;
916-
fullyComposeAffineMapAndOperands(&submap, &operands);
917-
canonicalizeMapAndOperands(&submap, &operands);
918-
results.push_back(createOrFold<AffineApplyOp>(b, loc, operands, submap));
919-
if (!results.back().is<Attribute>())
920-
foldedAll = false;
921-
}
922-
923-
// If the entire map could be folded, remove the constants that were used in
924-
// the initial ops.
925-
if (foldedAll) {
926-
for (Operation *constant : constants)
927-
b.eraseOp(constant);
928-
}
929-
930-
return results;
931-
}
932-
933924
// A symbol may appear as a dim in affine.apply operations. This function
934925
// canonicalizes dims that are valid symbols into actual symbols.
935926
template <class MapOrSet>

mlir/lib/Dialect/Linalg/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRLinalgDialect
1616
LINK_LIBS PUBLIC
1717
MLIRAffineDialect
1818
MLIRArithmeticDialect
19+
MLIRArithmeticUtils
1920
MLIRBufferizationDialect
2021
MLIRDialectUtils
2122
MLIRInferTypeOpInterface

mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1212
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13+
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
1314
#include "mlir/Dialect/Complex/IR/Complex.h"
1415
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1516
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -486,13 +487,20 @@ static Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
486487
return b.createOrFold<tensor::DimOp>(loc, source, dim);
487488
llvm_unreachable("Expected MemRefType or TensorType");
488489
}
490+
static OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value source,
491+
int64_t dim) {
492+
auto shapedType = source.getType().cast<ShapedType>();
493+
if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
494+
return createOrFoldDimOp(b, loc, source, dim);
495+
return b.getIndexAttr(shapedType.getDimSize(dim));
496+
}
489497

490-
SmallVector<Value, 4> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
491-
Location loc) {
492-
SmallVector<Value, 4> res;
498+
SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
499+
Location loc) {
500+
SmallVector<OpFoldResult> res;
493501
for (OpOperand *opOperand : getInputAndOutputOperands()) {
494502
for (int64_t i = 0, e = getRank(opOperand); i < e; ++i)
495-
res.push_back(createOrFoldDimOp(b, loc, opOperand->get(), i));
503+
res.push_back(createFoldedDimOp(b, loc, opOperand->get(), i));
496504
}
497505
return res;
498506
}
@@ -510,14 +518,13 @@ SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
510518
unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
511519
auto viewSizes = createFlatListOfOperandDims(b, loc);
512520
SmallVector<Range, 4> res(numDims);
513-
Value zeroVal = b.create<arith::ConstantIndexOp>(loc, 0);
514-
Value oneVal = b.create<arith::ConstantIndexOp>(loc, 1);
515521
for (unsigned idx = 0; idx < numRes; ++idx) {
516522
auto result = map.getResult(idx);
517523
if (auto d = result.dyn_cast<AffineDimExpr>()) {
518524
if (res[d.getPosition()].offset)
519525
continue;
520-
res[d.getPosition()] = Range{zeroVal, viewSizes[idx], oneVal};
526+
res[d.getPosition()] =
527+
Range{b.getIndexAttr(0), viewSizes[idx], b.getIndexAttr(1)};
521528
}
522529
}
523530
return res;
@@ -591,9 +598,11 @@ LinalgOp::reifyResultShapes(OpBuilder &b,
591598
outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
592599
HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
593600
Location loc = getOperation()->getLoc();
594-
auto allResultDimValues =
595-
applyMapToValues(b, loc, resultShapesFromInputShapesMap,
596-
createFlatListOfOperandDims(b, loc));
601+
IRRewriter rewriter(b);
602+
SmallVector<OpFoldResult> allResultDimValues =
603+
makeComposedFoldedMultiResultAffineApply(
604+
rewriter, loc, resultShapesFromInputShapesMap,
605+
createFlatListOfOperandDims(b, loc));
597606
int64_t pos = 0;
598607
ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
599608
for (OpOperand *opOperand : getOutputOperands()) {
@@ -602,7 +611,8 @@ LinalgOp::reifyResultShapes(OpBuilder &b,
602611
if (checkDimExpr.visit(shapeExprs[pos]))
603612
shapes.push_back(createOrFoldDimOp(b, loc, opOperand->get(), dim));
604613
else
605-
shapes.push_back(allResultDimValues[pos]);
614+
shapes.push_back(
615+
getValueOrCreateConstantIndexOp(b, loc, allResultDimValues[pos]));
606616
pos++;
607617
}
608618
reifiedReturnShapes.emplace_back(std::move(shapes));

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -630,12 +630,8 @@ struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
630630
// plus low padding sizes.
631631
SmallVector<OpFoldResult, 4> newOffsets;
632632
for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
633-
Value padValue = getValueOrCreateConstantIndexOp(
634-
rewriter, srcPadOp.getLoc(), std::get<0>(p));
635-
Value offsetValue = getValueOrCreateConstantIndexOp(
636-
rewriter, insertOp.getLoc(), std::get<1>(p));
637-
newOffsets.push_back(
638-
applyMapToValues(rewriter, loc, addMap, {offsetValue, padValue})[0]);
633+
newOffsets.push_back(makeComposedFoldedAffineApply(
634+
rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
639635
}
640636

641637
SmallVector<OpFoldResult, 4> newSizes;

0 commit comments

Comments
 (0)