Skip to content

Commit 57fe7fd

Browse files
[mlir][Linalg] Add support for scf::ForOp in comprehensive bufferization (7/n)
scf::ForOp bufferization analysis proceeds just like for any other op (including FuncOp) at its boundaries; i.e. if: 1. The tensor operand is inplaceable. 2. The matching result has no subsequent read (i.e. all reads dominate the scf::ForOp). 3. In and does not create a RAW interference. then it can bufferize inplace. Still there are a few differences: 1. bbArgs for an scf::ForOp are always considered inplaceable when seen from ops inside the body. This is because a) either the matching tensor operand is not inplaceable and an alloc will be inserted (which makes bbArg itself inplaceable); or b) the tensor operand and bbArg are both already inplaceable. 2. Bufferization within the scf::ForOp body has implications to the outside world : the scf.yield terminator may well ping-pong values of the same type. This muddies the water for alias analysis and is not supported atm. Such cases result in a pass failure. Differential revision: https://reviews.llvm.org/D104490
1 parent c74aea4 commit 57fe7fd

File tree

5 files changed

+436
-73
lines changed

5 files changed

+436
-73
lines changed

mlir/include/mlir/Dialect/SCF/SCFOps.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,8 @@ def ForOp : SCF_Op<"for",
261261
return getOperation()->getNumOperands() - getNumControlOperands();
262262
}
263263
/// Get the region iter arg that corresponds to an OpOperand.
264+
/// This helper prevents internal op implementation detail leakage to
265+
/// clients by hiding the operand / block argument mapping.
264266
BlockArgument getRegionIterArgForOpOperand(OpOperand &opOperand) {
265267
assert(opOperand.getOperandNumber() >= getNumControlOperands() &&
266268
"expected an iter args operand");
@@ -270,6 +272,8 @@ def ForOp : SCF_Op<"for",
270272
opOperand.getOperandNumber() - getNumControlOperands()];
271273
}
272274
/// Get the OpOperand& that corresponds to a region iter arg.
275+
/// This helper prevents internal op implementation detail leakage to
276+
/// clients by hiding the operand / block argument mapping.
273277
OpOperand &getOpOperandForRegionIterArg(BlockArgument bbArg) {
274278
assert(bbArg.getArgNumber() >= getNumInductionVars() &&
275279
"expected a bbArg that is not an induction variable");
@@ -278,6 +282,27 @@ def ForOp : SCF_Op<"for",
278282
return getOperation()->getOpOperand(
279283
getNumControlOperands() + bbArg.getArgNumber() - getNumInductionVars());
280284
}
285+
/// Get the OpResult that corresponds to an OpOperand.
286+
/// Assert that opOperand is an iterArg.
287+
/// This helper prevents internal op implementation detail leakage to
288+
/// clients by hiding the operand / block argument mapping.
289+
OpResult getResultForOpOperand(OpOperand &opOperand) {
290+
assert(opOperand.getOperandNumber() >= getNumControlOperands() &&
291+
"expected an iter args operand");
292+
assert(opOperand.getOwner() == getOperation() &&
293+
"opOperand does not belong to this scf::ForOp operation");
294+
return getOperation()->getResult(
295+
opOperand.getOperandNumber() - getNumControlOperands());
296+
}
297+
/// Get the OpOperand& that corresponds to an OpResultOpOperand.
298+
/// This helper prevents internal op implementation detail leakage to
299+
/// clients by hiding the operand / block argument mapping.
300+
OpOperand &getOpOperandForResult(OpResult opResult) {
301+
assert(opResult.getDefiningOp() == getOperation() &&
302+
"opResult does not belong to the scf::ForOp operation");
303+
return getOperation()->getOpOperand(
304+
getNumControlOperands() + opResult.getResultNumber());
305+
}
281306

282307
/// Return operands used when entering the region at 'index'. These operands
283308
/// correspond to the loop iterator operands, i.e., those exclusing the

0 commit comments

Comments
 (0)