Skip to content

Commit 77f5b33

Browse files
[mlir][SCF] Retire SCF-specific to_memref/to_tensor canonicalization patterns (#74551)
The partial bufferization framework has been replaced with One-Shot Bufferize. SCF-specific canonicalization patterns for `to_memref`/`to_tensor` are no longer needed.
1 parent 23d402e commit 77f5b33

File tree

4 files changed

+4
-182
lines changed

4 files changed

+4
-182
lines changed

mlir/lib/Dialect/SCF/IR/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@ add_mlir_dialect_library(MLIRSCFDialect
1111

1212
LINK_LIBS PUBLIC
1313
MLIRArithDialect
14-
MLIRBufferizationDialect
1514
MLIRControlFlowDialect
15+
MLIRDialectUtils
1616
MLIRFunctionInterfaces
1717
MLIRIR
1818
MLIRLoopLikeInterface
1919
MLIRSideEffectInterfaces
20+
MLIRTensorDialect
2021
MLIRValueBoundsOpInterface
2122
)
2223

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 2 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#include "mlir/Dialect/SCF/IR/SCF.h"
1010
#include "mlir/Dialect/Arith/IR/Arith.h"
1111
#include "mlir/Dialect/Arith/Utils/Utils.h"
12-
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1312
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
1413
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1514
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
@@ -1082,139 +1081,12 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
10821081
}
10831082
};
10841083

1085-
/// Canonicalize the iter_args of an scf::ForOp that involve a
1086-
/// `bufferization.to_tensor` and for which only the last loop iteration is
1087-
/// actually visible outside of the loop. The canonicalization looks for a
1088-
/// pattern such as:
1089-
/// ```
1090-
/// %t0 = ... : tensor_type
1091-
/// %0 = scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
1092-
/// ...
1093-
/// // %m is either buffer_cast(%bb0) or defined above the loop
1094-
/// %m... : memref_type
1095-
/// ... // uses of %m with potential inplace updates
1096-
/// %new_tensor = bufferization.to_tensor %m : memref_type
1097-
/// ...
1098-
/// scf.yield %new_tensor : tensor_type
1099-
/// }
1100-
/// ```
1101-
///
1102-
/// `%bb0` may have either 0 or 1 use. If it has 1 use it must be exactly a
1103-
/// `%m = buffer_cast %bb0` op that feeds into the yielded
1104-
/// `bufferization.to_tensor` op.
1105-
///
1106-
/// If no aliasing write to the memref `%m`, from which `%new_tensor` is loaded,
1107-
/// occurs between `bufferization.to_tensor` and yield, then the value %0
1108-
/// visible outside of the loop is the last `bufferization.to_tensor`
1109-
/// produced in the loop.
1110-
///
1111-
/// For now, we approximate the absence of aliasing by only supporting the case
1112-
/// when the bufferization.to_tensor is the operation immediately preceding
1113-
/// the yield.
1114-
///
1115-
/// The canonicalization rewrites the pattern as:
1116-
/// ```
1117-
/// // %m is either a buffer_cast or defined above
1118-
/// %m... : memref_type
1119-
/// scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
1120-
/// ... // uses of %m with potential inplace updates
1121-
/// scf.yield %bb0: tensor_type
1122-
/// }
1123-
/// %0 = bufferization.to_tensor %m : memref_type
1124-
/// ```
1125-
///
1126-
/// A later bbArg canonicalization will further rewrite as:
1127-
/// ```
1128-
/// // %m is either a buffer_cast or defined above
1129-
/// %m... : memref_type
1130-
/// scf.for ... { // no iter_args
1131-
/// ... // uses of %m with potential inplace updates
1132-
/// }
1133-
/// %0 = bufferization.to_tensor %m : memref_type
1134-
/// ```
1135-
struct LastTensorLoadCanonicalization : public OpRewritePattern<ForOp> {
1136-
using OpRewritePattern<ForOp>::OpRewritePattern;
1137-
1138-
LogicalResult matchAndRewrite(ForOp forOp,
1139-
PatternRewriter &rewriter) const override {
1140-
assert(std::next(forOp.getRegion().begin()) == forOp.getRegion().end() &&
1141-
"unexpected multiple blocks");
1142-
1143-
Location loc = forOp.getLoc();
1144-
DenseMap<Value, Value> replacements;
1145-
for (BlockArgument bbArg : forOp.getRegionIterArgs()) {
1146-
unsigned idx = bbArg.getArgNumber() - /*numIv=*/1;
1147-
auto yieldOp =
1148-
cast<scf::YieldOp>(forOp.getRegion().front().getTerminator());
1149-
Value yieldVal = yieldOp->getOperand(idx);
1150-
auto tensorLoadOp = yieldVal.getDefiningOp<bufferization::ToTensorOp>();
1151-
bool isTensor = llvm::isa<TensorType>(bbArg.getType());
1152-
1153-
bufferization::ToMemrefOp tensorToMemref;
1154-
// Either bbArg has no use or it has a single buffer_cast use.
1155-
if (bbArg.hasOneUse())
1156-
tensorToMemref =
1157-
dyn_cast<bufferization::ToMemrefOp>(*bbArg.getUsers().begin());
1158-
if (!isTensor || !tensorLoadOp || (!bbArg.use_empty() && !tensorToMemref))
1159-
continue;
1160-
// If tensorToMemref is present, it must feed into the `ToTensorOp`.
1161-
if (tensorToMemref && tensorLoadOp.getMemref() != tensorToMemref)
1162-
continue;
1163-
// TODO: Any aliasing write of tensorLoadOp.memref() nested under `forOp`
1164-
// must be before `ToTensorOp` in the block so that the lastWrite
1165-
// property is not subject to additional side-effects.
1166-
// For now, we only support the case when ToTensorOp appears
1167-
// immediately before the terminator.
1168-
if (tensorLoadOp->getNextNode() != yieldOp)
1169-
continue;
1170-
1171-
// Clone the optional tensorToMemref before forOp.
1172-
if (tensorToMemref) {
1173-
rewriter.setInsertionPoint(forOp);
1174-
rewriter.replaceOpWithNewOp<bufferization::ToMemrefOp>(
1175-
tensorToMemref, tensorToMemref.getMemref().getType(),
1176-
tensorToMemref.getTensor());
1177-
}
1178-
1179-
// Clone the tensorLoad after forOp.
1180-
rewriter.setInsertionPointAfter(forOp);
1181-
Value newTensorLoad = rewriter.create<bufferization::ToTensorOp>(
1182-
loc, tensorLoadOp.getMemref());
1183-
Value forOpResult = forOp.getResult(bbArg.getArgNumber() - /*iv=*/1);
1184-
replacements.insert(std::make_pair(forOpResult, newTensorLoad));
1185-
1186-
// Make the terminator just yield the bbArg, the old tensorLoadOp + the
1187-
// old bbArg (that is now directly yielded) will canonicalize away.
1188-
rewriter.startRootUpdate(yieldOp);
1189-
yieldOp.setOperand(idx, bbArg);
1190-
rewriter.finalizeRootUpdate(yieldOp);
1191-
}
1192-
if (replacements.empty())
1193-
return failure();
1194-
1195-
// We want to replace a subset of the results of `forOp`. rewriter.replaceOp
1196-
// replaces the whole op and erase it unconditionally. This is wrong for
1197-
// `forOp` as it generally contains ops with side effects.
1198-
// Instead, use `rewriter.replaceOpWithIf`.
1199-
SmallVector<Value> newResults;
1200-
newResults.reserve(forOp.getNumResults());
1201-
for (Value v : forOp.getResults()) {
1202-
auto it = replacements.find(v);
1203-
newResults.push_back((it != replacements.end()) ? it->second : v);
1204-
}
1205-
unsigned idx = 0;
1206-
rewriter.replaceOpWithIf(forOp, newResults, [&](OpOperand &op) {
1207-
return op.get() != newResults[idx++];
1208-
});
1209-
return success();
1210-
}
1211-
};
12121084
} // namespace
12131085

12141086
void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
12151087
MLIRContext *context) {
1216-
results.add<ForOpIterArgsFolder, SimplifyTrivialLoops,
1217-
LastTensorLoadCanonicalization, ForOpTensorCastFolder>(context);
1088+
results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1089+
context);
12181090
}
12191091

12201092
std::optional<APInt> ForOp::getConstantStep() {

mlir/test/Dialect/SCF/canonicalize.mlir

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -773,56 +773,6 @@ func.func @remove_empty_parallel_loop(%lb: index, %ub: index, %s: index) {
773773

774774
// -----
775775

776-
func.func private @process(%0 : memref<128x128xf32>)
777-
func.func private @process_tensor(%0 : tensor<128x128xf32>) -> memref<128x128xf32>
778-
779-
// CHECK-LABEL: last_value
780-
// CHECK-SAME: %[[T0:[0-9a-z]*]]: tensor<128x128xf32>
781-
// CHECK-SAME: %[[T1:[0-9a-z]*]]: tensor<128x128xf32>
782-
// CHECK-SAME: %[[T2:[0-9a-z]*]]: tensor<128x128xf32>
783-
// CHECK-SAME: %[[M0:[0-9a-z]*]]: memref<128x128xf32>
784-
func.func @last_value(%t0: tensor<128x128xf32>, %t1: tensor<128x128xf32>,
785-
%t2: tensor<128x128xf32>, %m0: memref<128x128xf32>,
786-
%lb : index, %ub : index, %step : index)
787-
-> (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>)
788-
{
789-
// CHECK-NEXT: %[[M1:.*]] = bufferization.to_memref %[[T1]] : memref<128x128xf32>
790-
// CHECK-NEXT: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args(%[[BBARG_T2:.*]] = %[[T2]]) -> (tensor<128x128xf32>) {
791-
%0:3 = scf.for %arg0 = %lb to %ub step %step iter_args(%arg1 = %t0, %arg2 = %t1, %arg3 = %t2)
792-
-> (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>)
793-
{
794-
%m1 = bufferization.to_memref %arg2 : memref<128x128xf32>
795-
796-
// CHECK-NEXT: call @process(%[[M0]]) : (memref<128x128xf32>) -> ()
797-
func.call @process(%m0) : (memref<128x128xf32>) -> ()
798-
799-
// CHECK-NEXT: call @process(%[[M1]]) : (memref<128x128xf32>) -> ()
800-
func.call @process(%m1) : (memref<128x128xf32>) -> ()
801-
802-
// This does not hoist (fails the check that the bbArg has at most a single use).
803-
// CHECK-NEXT: %[[T:.*]] = func.call @process_tensor(%[[BBARG_T2]]) : (tensor<128x128xf32>) -> memref<128x128xf32>
804-
// CHECK-NEXT: %[[YIELD_T:.*]] = bufferization.to_tensor %[[T:.*]]
805-
%m2 = func.call @process_tensor(%arg3): (tensor<128x128xf32>) -> memref<128x128xf32>
806-
%3 = bufferization.to_tensor %m2 : memref<128x128xf32>
807-
808-
// All this stuff goes away, incrementally
809-
%1 = bufferization.to_tensor %m0 : memref<128x128xf32>
810-
%2 = bufferization.to_tensor %m1 : memref<128x128xf32>
811-
812-
// CHECK-NEXT: scf.yield %[[YIELD_T]] : tensor<128x128xf32>
813-
scf.yield %1, %2, %3 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
814-
815-
// CHECK-NEXT: }
816-
}
817-
818-
// CHECK-NEXT: %[[R0:.*]] = bufferization.to_tensor %[[M0]] : memref<128x128xf32>
819-
// CHECK-NEXT: %[[R1:.*]] = bufferization.to_tensor %[[M1]] : memref<128x128xf32>
820-
// CHECK-NEXT: return %[[R0]], %[[R1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
821-
return %0#0, %0#1, %0#2 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
822-
}
823-
824-
// -----
825-
826776
// CHECK-LABEL: fold_away_iter_with_no_use_and_yielded_input
827777
// CHECK-SAME: %[[A0:[0-9a-z]*]]: i32
828778
func.func @fold_away_iter_with_no_use_and_yielded_input(%arg0 : i32,

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3994,7 +3994,6 @@ cc_library(
39943994
deps = [
39953995
":ArithDialect",
39963996
":ArithUtils",
3997-
":BufferizationDialect",
39983997
":ControlFlowDialect",
39993998
":ControlFlowInterfaces",
40003999
":DestinationStyleOpInterface",

0 commit comments

Comments
 (0)