Skip to content

Commit 56954a5

Browse files
[MLIR][LoopPipelining] Improve schedule verifier, so it checks also operands of nested operations (#88450)
`verifySchedule` was not looking at the operands of nested operations, which caused incorrect schedule to be allowed in some cases, potentially leading to crash during expansion. There is also minor fix in `cloneAndUpdateOperands` in the pipeline expander that prevents double visit of the cloned op. This one has no functional impact, so no test for it.
1 parent c42a262 commit 56954a5

File tree

2 files changed

+37
-4
lines changed

2 files changed

+37
-4
lines changed

mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,17 @@ bool LoopPipelinerInternal::initializeLoopInfo(
203203
return true;
204204
}
205205

206+
/// Find operands of all the nested operations within `op`.
207+
static SetVector<Value> getNestedOperands(Operation *op) {
208+
SetVector<Value> operands;
209+
op->walk([&](Operation *nestedOp) {
210+
for (Value operand : nestedOp->getOperands()) {
211+
operands.insert(operand);
212+
}
213+
});
214+
return operands;
215+
}
216+
206217
/// Compute unrolled cycles of each op (consumer) and verify that each op is
207218
/// scheduled after its operands (producers) while adjusting for the distance
208219
/// between producer and consumer.
@@ -219,7 +230,7 @@ bool LoopPipelinerInternal::verifySchedule() {
219230
}
220231
for (Operation *consumer : opOrder) {
221232
int64_t consumerCycle = unrolledCyles[consumer];
222-
for (Value operand : consumer->getOperands()) {
233+
for (Value operand : getNestedOperands(consumer)) {
223234
auto [producer, distance] = getDefiningOpAndDistance(operand);
224235
if (!producer)
225236
continue;
@@ -245,9 +256,8 @@ static Operation *
245256
cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
246257
function_ref<void(OpOperand *newOperand)> callback) {
247258
Operation *clone = rewriter.clone(*op);
248-
for (OpOperand &operand : clone->getOpOperands())
249-
callback(&operand);
250-
clone->walk([&](Operation *nested) {
259+
clone->walk<WalkOrder::PreOrder>([&](Operation *nested) {
260+
// 'clone' itself will be visited first.
251261
for (OpOperand &operand : nested->getOpOperands()) {
252262
Operation *def = operand.get().getDefiningOp();
253263
if ((def && !clone->isAncestor(def)) || isa<BlockArgument>(operand.get()))

mlir/test/Dialect/SCF/loop-pipelining.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -846,3 +846,26 @@ func.func @invalid_schedule2(%A: memref<?xf32>, %result: memref<?xf32>) {
846846
} { __test_pipelining_loop__ }
847847
return
848848
}
849+
850+
// -----
851+
852+
func.func @invalid_schedule3(%A: memref<?xf32>, %result: memref<?xf32>, %ext: index) {
853+
%c0 = arith.constant 0 : index
854+
%c1 = arith.constant 1 : index
855+
%c4 = arith.constant 4 : index
856+
%r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%idx = %c0) -> (index) {
857+
%cnd = arith.cmpi slt, %ext, %c4 { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 0 } : index
858+
// expected-error@+1 {{operation scheduled before its operands}}
859+
%idx1 = scf.if %cnd -> (index) {
860+
%idxinc = arith.addi %idx, %c1 : index
861+
scf.yield %idxinc : index
862+
} else {
863+
scf.yield %idx : index
864+
} { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 1 }
865+
%A_elem = memref.load %A[%idx1] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
866+
%idx2 = arith.addi %idx1, %c1 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 3 } : index
867+
memref.store %A_elem, %result[%idx1] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 4 } : memref<?xf32>
868+
scf.yield %idx2 : index
869+
} { __test_pipelining_loop__ }
870+
return
871+
}

0 commit comments

Comments
 (0)