Skip to content

Commit 2cd1361

Browse files
Make shared outs of scf.forall have read semantics.
Signed-off-by: MaheshRavishankar <[email protected]>
1 parent e167d94 commit 2cd1361

File tree

1 file changed

+5
-23
lines changed

1 file changed

+5
-23
lines changed

mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1186,18 +1186,6 @@ struct YieldOpInterface
11861186
}
11871187
};
11881188

1189-
/// Return `true` if the given loop may have 0 iterations.
1190-
bool mayHaveZeroIterations(scf::ForallOp forallOp) {
1191-
for (auto [lb, ub] : llvm::zip(forallOp.getMixedLowerBound(),
1192-
forallOp.getMixedUpperBound())) {
1193-
std::optional<int64_t> lbConst = getConstantIntValue(lb);
1194-
std::optional<int64_t> ubConst = getConstantIntValue(ub);
1195-
if (!lbConst.has_value() || !ubConst.has_value() || *lbConst >= *ubConst)
1196-
return true;
1197-
}
1198-
return false;
1199-
}
1200-
12011189
/// Bufferization of ForallOp. This also bufferizes the terminator of the
12021190
/// region. There are op interfaces for the terminators (InParallelOp
12031191
/// and ParallelInsertSliceOp), but these are only used during analysis. Not
@@ -1207,17 +1195,11 @@ struct ForallOpInterface
12071195
ForallOp> {
12081196
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
12091197
const AnalysisState &state) const {
1210-
auto forallOp = cast<ForallOp>(op);
1211-
1212-
// If the loop has zero iterations, the results of the op are their
1213-
// corresponding shared_outs, meaning that the shared_outs bufferize to a
1214-
// read.
1215-
if (mayHaveZeroIterations(forallOp))
1216-
return true;
1217-
1218-
// scf::ForallOp alone doesn't bufferize to a memory read, one of the
1219-
// uses of its matching bbArg may.
1220-
return state.isValueRead(forallOp.getTiedBlockArgument(&opOperand));
1198+
// All tensor operands to `scf.forall` are `shared_outs` and all
1199+
// shared outs are assumed to be read by the loop. This does not
1200+
// account for the case where the entire value is over-written,
1201+
// but being conservative here.
1202+
return true;
12211203
}
12221204

12231205
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,

0 commit comments

Comments
 (0)