Skip to content

Commit 72de758

Browse files
[mlir][SCF] Add bufferization hook for scf.foreach_thread and terminator.
`scf.foreach_thread` results alias with the underlying `scf.foreach_thread.parallel_insert_slice` destination operands, and they bufferize to equivalent buffers in the absence of other conflicts. Conflict detection for `scf.foreach_thread.parallel_insert_slice` is similar to conflict detection for `tensor.insert_slice`.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D126769
1 parent 0d21863 commit 72de758

File tree

4 files changed

+437
-7
lines changed

4 files changed

+437
-7
lines changed

mlir/include/mlir/Dialect/SCF/SCFOps.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,12 @@ def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [
448448
let hasCustomAssemblyFormat = 1;
449449
let hasVerifier = 1;
450450

451+
// The default builder does not add a region with an empty body, add our own.
452+
let skipDefaultBuilders = 1;
453+
let builders = [
454+
OpBuilder<(ins)>,
455+
];
456+
451457
// TODO: Add a `PerformConcurrentlyOpInterface` interface for ops that can
452458
// appear inside perform_concurrently.
453459
let extraClassDeclaration = [{

mlir/lib/Dialect/SCF/SCF.cpp

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,10 +1138,11 @@ void ForeachThreadOp::build(mlir::OpBuilder &builder,
11381138
result.addOperands(numThreads);
11391139

11401140
Region *bodyRegion = result.addRegion();
1141-
{
1142-
OpBuilder::InsertionGuard g(builder);
1143-
builder.createBlock(bodyRegion);
1144-
}
1141+
OpBuilder::InsertionGuard g(builder);
1142+
// createBlock sets the IP inside the block.
1143+
// Generally we would guard against that but the default ensureTerminator impl
1144+
// expects it ..
1145+
builder.createBlock(bodyRegion);
11451146
Block &bodyBlock = bodyRegion->front();
11461147
bodyBlock.addArguments(
11471148
SmallVector<Type>(numThreads.size(), builder.getIndexType()),
@@ -1158,18 +1159,21 @@ void ForeachThreadOp::build(
11581159
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
11591160
result.addOperands(numThreads);
11601161

1162+
OpBuilder::InsertionGuard g(builder);
11611163
Region *bodyRegion = result.addRegion();
1162-
bodyRegion->push_back(new Block);
1164+
builder.createBlock(bodyRegion);
11631165
Block &bodyBlock = bodyRegion->front();
11641166
bodyBlock.addArguments(
11651167
SmallVector<Type>(numThreads.size(), builder.getIndexType()),
11661168
SmallVector<Location>(numThreads.size(), result.location));
11671169

11681170
OpBuilder::InsertionGuard guard(builder);
11691171
builder.setInsertionPointToStart(&bodyBlock);
1170-
bodyBuilder(builder, result.location, bodyBlock.getArgument(0));
1172+
bodyBuilder(builder, result.location, bodyBlock.getArguments());
11711173
auto terminator =
1172-
llvm::cast<PerformConcurrentlyOp>(bodyBlock.getTerminator());
1174+
llvm::dyn_cast<PerformConcurrentlyOp>(bodyBlock.getTerminator());
1175+
assert(terminator &&
1176+
"expected bodyBuilder to create PerformConcurrentlyOp terminator");
11731177
result.addTypes(terminator.yieldedTypes());
11741178
}
11751179

@@ -1272,6 +1276,13 @@ void ParallelInsertSliceOp::getCanonicalizationPatterns(
12721276
// PerformConcurrentlyOp
12731277
//===----------------------------------------------------------------------===//
12741278

1279+
// Build a PerformConcurrentlyOp with mixed static and dynamic entries.
1280+
void PerformConcurrentlyOp::build(OpBuilder &b, OperationState &result) {
1281+
OpBuilder::InsertionGuard g(b);
1282+
Region *bodyRegion = result.addRegion();
1283+
b.createBlock(bodyRegion);
1284+
}
1285+
12751286
LogicalResult PerformConcurrentlyOp::verify() {
12761287
// TODO: PerformConcurrentlyOpInterface.
12771288
for (const Operation &op : getRegion().front().getOperations())

mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 289 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
1414
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1515
#include "mlir/Dialect/SCF/SCF.h"
16+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1617
#include "mlir/IR/Dialect.h"
1718
#include "mlir/IR/Operation.h"
1819
#include "mlir/IR/PatternMatch.h"
@@ -812,6 +813,289 @@ struct YieldOpInterface
812813
}
813814
};
814815

816+
using tensor::ExtractSliceOp;
817+
818+
/// Return the destinations that an ForeachThreadOp is inserting into. One per
819+
/// ParallelInsertSliceOp.
820+
static SmallVector<OpOperand *>
821+
getInsertionDest(ForeachThreadOp foreachThreadOp) {
822+
PerformConcurrentlyOp terminator = foreachThreadOp.getTerminator();
823+
SmallVector<OpOperand *> result;
824+
terminator.walk([&](ParallelInsertSliceOp insertOp) {
825+
result.push_back(&insertOp->getOpOperand(1) /*dest*/);
826+
});
827+
return result;
828+
}
829+
830+
/// Bufferization of ForeachThreadOp. This also bufferizes the terminator of the
831+
/// region. There are op interfaces for the terminators (PerformConcurrentlyOp
832+
/// and ParallelInsertSliceOp), but these are only used during analysis. Not
833+
/// for bufferization.
834+
struct ForeachThreadOpInterface
835+
: public BufferizableOpInterface::ExternalModel<ForeachThreadOpInterface,
836+
ForeachThreadOp> {
837+
SmallVector<OpOperand *>
838+
getAliasingOpOperand(Operation *op, OpResult opResult,
839+
const AnalysisState &state) const {
840+
// Get OpOperand (dest) from corresponding ParallelInsertSliceOp.
841+
auto foreachThreadOp = cast<ForeachThreadOp>(op);
842+
return {getInsertionDest(foreachThreadOp)[opResult.getResultNumber()]};
843+
}
844+
845+
bool isMemoryWrite(Operation *op, OpResult opResult,
846+
const AnalysisState &state) const {
847+
// This op is a memory write. Stop lookup here to avoid finding false
848+
// conflicts involving this op and one of the ops in the region. This is
849+
// similar to how scf.if ops are analyzed.
850+
return true;
851+
}
852+
853+
BufferRelation bufferRelation(Operation *op, OpResult opResult,
854+
const AnalysisState &state) const {
855+
return BufferRelation::Equivalent;
856+
}
857+
858+
LogicalResult bufferize(Operation *op, RewriterBase &b,
859+
BufferizationState &state) const {
860+
OpBuilder::InsertionGuard g(b);
861+
auto foreachThreadOp = cast<ForeachThreadOp>(op);
862+
863+
// Gather new results of the ForeachThreadOp.
864+
SmallVector<Value> newResults;
865+
for (OpResult opResult : foreachThreadOp->getOpResults()) {
866+
SmallVector<OpOperand *> insertDestOperands =
867+
state.getAnalysisState().getAliasingOpOperand(opResult);
868+
assert(insertDestOperands.size() == 1 &&
869+
"expected exactly one aliasing OpOperand");
870+
// Insert copies right before the PerformConcurrentlyOp terminator. They
871+
// should not be inside terminator (which would be the default insertion
872+
// point).
873+
Value buffer = *state.getBuffer(b, *insertDestOperands.front(),
874+
/*forceInPlace=*/llvm::None,
875+
/*customCopyInsertionPoint=*/op);
876+
newResults.push_back(buffer);
877+
}
878+
879+
// Create new ForeachThreadOp without any results and drop the automatically
880+
// introduced terminator.
881+
TypeRange newResultTypes;
882+
auto newForeachThreadOp =
883+
b.create<ForeachThreadOp>(foreachThreadOp.getLoc(), newResultTypes,
884+
foreachThreadOp.getNumThreads());
885+
newForeachThreadOp.getBody()->getTerminator()->erase();
886+
887+
// Move over block contents of the old op.
888+
b.mergeBlocks(foreachThreadOp.getBody(), newForeachThreadOp.getBody(),
889+
{newForeachThreadOp.getBody()->getArguments()});
890+
891+
// Bufferize terminator.
892+
auto performConcurrentlyOp = cast<PerformConcurrentlyOp>(
893+
newForeachThreadOp.getBody()->getTerminator());
894+
b.setInsertionPoint(performConcurrentlyOp);
895+
unsigned resultCounter = 0;
896+
WalkResult walkResult =
897+
performConcurrentlyOp.walk([&](ParallelInsertSliceOp insertOp) {
898+
Location loc = insertOp.getLoc();
899+
Type srcType = getMemRefType(
900+
insertOp.getSource().getType().cast<RankedTensorType>(),
901+
state.getOptions());
902+
// ParallelInsertSliceOp bufferizes to a copy.
903+
auto srcMemref = b.create<bufferization::ToMemrefOp>(
904+
loc, srcType, insertOp.getSource());
905+
Value destMemref = newResults[resultCounter++];
906+
Value subview = b.create<memref::SubViewOp>(
907+
loc, destMemref, insertOp.getMixedOffsets(),
908+
insertOp.getMixedSizes(), insertOp.getMixedStrides());
909+
// This memcpy will fold away if everything bufferizes in-place.
910+
if (failed(state.getOptions().createMemCpy(b, insertOp.getLoc(),
911+
srcMemref, subview)))
912+
return WalkResult::interrupt();
913+
b.eraseOp(insertOp);
914+
return WalkResult::advance();
915+
});
916+
if (walkResult.wasInterrupted())
917+
return failure();
918+
919+
// Replace the op.
920+
replaceOpWithBufferizedValues(b, op, newResults);
921+
922+
return success();
923+
}
924+
};
925+
926+
/// Nothing to do for PerformConcurrentlyOp.
927+
struct PerformConcurrentlyOpInterface
928+
: public BufferizableOpInterface::ExternalModel<
929+
PerformConcurrentlyOpInterface, PerformConcurrentlyOp> {
930+
LogicalResult bufferize(Operation *op, RewriterBase &b,
931+
BufferizationState &state) const {
932+
assert(false && "op does not have any tensor OpOperands / OpResults");
933+
return failure();
934+
}
935+
};
936+
937+
/// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair match (i.e.
938+
/// equivalent operand / result and same offset/sizes/strides specification).
939+
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
940+
ExtractSliceOp st,
941+
ParallelInsertSliceOp sti) {
942+
if (!st || !sti)
943+
return false;
944+
if (st != sti &&
945+
!state.areEquivalentBufferizedValues(st.source(), sti.getDest()))
946+
return false;
947+
if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
948+
return false;
949+
return true;
950+
}
951+
952+
/// Return true if `value` is originating from an ExtractSliceOp that matches
953+
/// the given InsertSliceOp.
954+
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
955+
ParallelInsertSliceOp insertOp) {
956+
auto condition = [&](Value val) {
957+
if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
958+
if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
959+
return true;
960+
return false;
961+
};
962+
963+
return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
964+
condition);
965+
}
966+
967+
/// Analysis of ParallelInsertSliceOp.
968+
struct ParallelInsertSliceOpInterface
969+
: public BufferizableOpInterface::ExternalModel<
970+
ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
971+
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
972+
const AnalysisState &state) const {
973+
if (&opOperand != &op->getOpOperand(1) /*dest*/)
974+
return {};
975+
976+
// ParallelInsertSliceOp itself has no results. Tensors are returned via
977+
// the parent op.
978+
auto foreachThreadOp = op->getParentOfType<ForeachThreadOp>();
979+
assert(foreachThreadOp &&
980+
"could not find valid owner of parallel_insert_slice");
981+
982+
// The i-th ParallelInsertSliceOp result is returned via the i-th OpResult
983+
// of the parent ForeachThreadOp.
984+
Block *block = op->getBlock();
985+
unsigned int opIdx = 0;
986+
for (ParallelInsertSliceOp insertOp :
987+
block->getOps<ParallelInsertSliceOp>()) {
988+
if (insertOp.getOperation() == op)
989+
break;
990+
++opIdx;
991+
}
992+
assert(opIdx < foreachThreadOp->getNumResults() &&
993+
"could not find op inside terminator op");
994+
995+
return {foreachThreadOp->getResult(opIdx)};
996+
}
997+
998+
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
999+
const AnalysisState &state) const {
1000+
return true;
1001+
}
1002+
1003+
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
1004+
const AnalysisState &state) const {
1005+
return &opOperand == &op->getOpOperand(1) /*dest*/;
1006+
}
1007+
1008+
BufferRelation bufferRelation(Operation *op, OpResult opResult,
1009+
const AnalysisState &state) const {
1010+
return BufferRelation::Equivalent;
1011+
}
1012+
1013+
LogicalResult bufferize(Operation *op, RewriterBase &b,
1014+
BufferizationState &state) const {
1015+
// Will be bufferized as part of ForeachThreadOp.
1016+
return failure();
1017+
}
1018+
1019+
// TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share
1020+
// the code.
1021+
bool isNotConflicting(Operation *op, OpOperand *uRead,
1022+
OpOperand *uConflictingWrite,
1023+
const AnalysisState &state) const {
1024+
Operation *readingOp = uRead->getOwner();
1025+
Operation *conflictingWritingOp = uConflictingWrite->getOwner();
1026+
1027+
// Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
1028+
// uRead is an InsertSliceOp...
1029+
if (auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(readingOp)) {
1030+
// As an example, consider the following IR.
1031+
//
1032+
// %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
1033+
// %1 = linalg.fill %cst, %0 {inplace= [true] }
1034+
// %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
1035+
// {inplace= [true] }
1036+
1037+
// TODO: Use insertSliceOp.getDestOpOperand etc. when available.
1038+
if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
1039+
hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
1040+
insertSliceOp))
1041+
// Case 1: The main insight is that InsertSliceOp reads only part of
1042+
// the destination tensor. The overwritten area is not read. If
1043+
// uConflictingWrite writes into exactly the memory location that is
1044+
// being read by uRead, this is not a conflict.
1045+
//
1046+
// In the above example:
1047+
// uRead = OpOperand 1 (%t) of tensor.insert_slice
1048+
// uConflictingWrite = OpOperand 1 (%0) of linalg.fill
1049+
//
1050+
// The read of %t does not conflict with the write of the FillOp
1051+
// (same aliases!) because the area that the FillOp operates on is
1052+
// exactly the one that is *not* read via %t.
1053+
return true;
1054+
1055+
if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
1056+
uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
1057+
hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
1058+
// Case 2: The read of the source tensor and the write to the dest
1059+
// tensor via an InsertSliceOp is not a conflict if the read is
1060+
// reading exactly that part of an equivalent tensor that the
1061+
// InsertSliceOp is writing.
1062+
//
1063+
// In the above example:
1064+
// uRead = OpOperand 0 (%1) of tensor.insert_slice
1065+
// uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
1066+
return true;
1067+
}
1068+
1069+
// If uConflictingWrite is an InsertSliceOp...
1070+
if (auto insertSliceOp =
1071+
dyn_cast<ParallelInsertSliceOp>(conflictingWritingOp))
1072+
// As an example, consider the following IR.
1073+
//
1074+
// %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
1075+
// %1 = linalg.fill %cst, %0 {inplace= [true] }
1076+
// %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
1077+
// {inplace= [true] }
1078+
// %3 = vector.transfer_read %1, %cst
1079+
//
1080+
// In the above example:
1081+
// uRead = OpOperand 0 (%1) of vector.transfer_read
1082+
// uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
1083+
// lastWrite = %1
1084+
//
1085+
// This is not a conflict because the InsertSliceOp overwrites the
1086+
// memory segment of %1 with the exact same data. (Effectively, there
1087+
// is no memory write here.)
1088+
if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
1089+
state.areEquivalentBufferizedValues(uRead->get(),
1090+
insertSliceOp.getSource()) &&
1091+
hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
1092+
insertSliceOp))
1093+
return true;
1094+
1095+
return false;
1096+
}
1097+
};
1098+
8151099
} // namespace
8161100
} // namespace scf
8171101
} // namespace mlir
@@ -822,6 +1106,11 @@ void mlir::scf::registerBufferizableOpInterfaceExternalModels(
8221106
ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
8231107
ForOp::attachInterface<ForOpInterface>(*ctx);
8241108
IfOp::attachInterface<IfOpInterface>(*ctx);
1109+
ForeachThreadOp::attachInterface<ForeachThreadOpInterface>(*ctx);
1110+
ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
1111+
*ctx);
1112+
PerformConcurrentlyOp::attachInterface<PerformConcurrentlyOpInterface>(
1113+
*ctx);
8251114
WhileOp::attachInterface<WhileOpInterface>(*ctx);
8261115
YieldOp::attachInterface<YieldOpInterface>(*ctx);
8271116
});

0 commit comments

Comments
 (0)