Skip to content

Commit c933bd8

Browse files
authored
[MLIR][SCF] Add checks to verify that the pipeliner schedule is correct. (#77083)
Add a check to validate that the schedule passed to the pipeliner transformation is valid and won't cause the pipeliner to break SSA. It checks that each operation in the loop is scheduled after its operands.
1 parent 1220c9b commit c933bd8

File tree

2 files changed

+76
-1
lines changed

2 files changed

+76
-1
lines changed

mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ struct LoopPipelinerInternal {
6767
/// the Value.
6868
std::pair<Operation *, int64_t> getDefiningOpAndDistance(Value value);
6969

70+
/// Return true if the schedule is possible and return false otherwise. A
71+
/// schedule is correct if all definitions are scheduled before uses.
72+
bool verifySchedule();
73+
7074
public:
7175
/// Initialize the information for the given `op`, return true if it
7276
/// satisfies the pre-condition to apply pipelining.
@@ -156,6 +160,11 @@ bool LoopPipelinerInternal::initializeLoopInfo(
156160
}
157161
}
158162

163+
if (!verifySchedule()) {
164+
LDBG("--invalid schedule: " << op << " -> BAIL");
165+
return false;
166+
}
167+
159168
// Currently, we do not support assigning stages to ops in nested regions. The
160169
// block of all operations assigned a stage should be the single `scf.for`
161170
// body block.
@@ -194,6 +203,40 @@ bool LoopPipelinerInternal::initializeLoopInfo(
194203
return true;
195204
}
196205

206+
/// Compute unrolled cycles of each op (consumer) and verify that each op is
207+
/// scheduled after its operands (producers) while adjusting for the distance
208+
/// between producer and consumer.
///
/// Returns true when no violation is found. On the first violation an error
/// is emitted on the offending consumer op and false is returned.
209+
bool LoopPipelinerInternal::verifySchedule() {
210+
// One "cycle" per entry of `opOrder`; an iteration spans this many cycles.
int64_t numCylesPerIter = opOrder.size();
211+
// Pre-compute the unrolled cycle of each op.
212+
DenseMap<Operation *, int64_t> unrolledCyles;
213+
for (int64_t cycle = 0; cycle < numCylesPerIter; cycle++) {
214+
Operation *def = opOrder[cycle];
215+
auto it = stages.find(def);
216+
// Every op listed in `opOrder` is expected to have an entry in `stages`.
assert(it != stages.end());
217+
int64_t stage = it->second;
218+
// Position within the iteration, shifted by one full iteration per stage.
unrolledCyles[def] = cycle + stage * numCylesPerIter;
219+
}
220+
for (Operation *consumer : opOrder) {
221+
int64_t consumerCycle = unrolledCyles[consumer];
222+
for (Value operand : consumer->getOperands()) {
223+
auto [producer, distance] = getDefiningOpAndDistance(operand);
224+
// No defining op was returned for this operand: nothing to check.
if (!producer)
225+
continue;
226+
auto it = unrolledCyles.find(producer);
227+
// Skip producer coming from outside the loop.
228+
if (it == unrolledCyles.end())
229+
continue;
230+
int64_t producerCycle = it->second;
231+
// The producer's cycle is moved earlier by `distance` whole iterations
// before being compared against the consumer's cycle.
if (consumerCycle < producerCycle - numCylesPerIter * distance) {
232+
consumer->emitError("operation scheduled before its operands");
233+
return false;
234+
}
235+
}
236+
}
237+
return true;
238+
}
239+
197240
/// Clone `op` and call `callback` on the cloned op's operands as well as any
198241
/// operands of nested ops that:
199242
/// 1) aren't defined within the new op or

mlir/test/Dialect/SCF/loop-pipelining.mlir

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -test-scf-pipelining -split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s -test-scf-pipelining -split-input-file -verify-diagnostics | FileCheck %s
22
// RUN: mlir-opt %s -test-scf-pipelining=annotate -split-input-file | FileCheck %s --check-prefix ANNOTATE
33
// RUN: mlir-opt %s -test-scf-pipelining=no-epilogue-peeling -split-input-file | FileCheck %s --check-prefix NOEPILOGUE
44

@@ -814,3 +814,35 @@ func.func @yield_constant_loop(%A: memref<?xf32>) -> f32 {
814814
return %r : f32
815815
}
816816

817+
// -----
818+
819+
// Negative test: the store is annotated with stage 1 but consumes %A1_elem,
// whose defining arith.addf is annotated with stage 2. The diagnostic is
// therefore expected on the store.
func.func @invalid_schedule(%A: memref<?xf32>, %result: memref<?xf32>) {
820+
%c0 = arith.constant 0 : index
821+
%c1 = arith.constant 1 : index
822+
%c4 = arith.constant 4 : index
823+
%cf = arith.constant 1.0 : f32
824+
scf.for %i0 = %c0 to %c4 step %c1 {
825+
%A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
826+
%A1_elem = arith.addf %A_elem, %cf { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 0 } : f32
827+
// expected-error@+1 {{operation scheduled before its operands}}
828+
memref.store %A1_elem, %result[%i0] { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : memref<?xf32>
829+
} { __test_pipelining_loop__ }
830+
return
831+
}
832+
833+
// -----
834+
835+
// Negative test with a loop-carried value: the load (stage 0, order 0) reads
// the iter_arg %idx, and the value yielded for it is %idx1, computed by an
// arith.addi annotated with stage 1. The diagnostic is expected on the load.
// NOTE(review): presumably the iter_arg counts as one iteration of distance
// from %idx1, which is not enough to place the producer first -- confirm
// against getDefiningOpAndDistance.
func.func @invalid_schedule2(%A: memref<?xf32>, %result: memref<?xf32>) {
836+
%c0 = arith.constant 0 : index
837+
%c1 = arith.constant 1 : index
838+
%c4 = arith.constant 4 : index
839+
%cf = arith.constant 1.0 : f32
840+
%r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%idx = %c0) -> (index) {
841+
// expected-error@+1 {{operation scheduled before its operands}}
842+
%A_elem = memref.load %A[%idx] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 0 } : memref<?xf32>
843+
%idx1 = arith.addi %idx, %c1 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : index
844+
memref.store %A_elem, %result[%idx] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 2 } : memref<?xf32>
845+
scf.yield %idx1 : index
846+
} { __test_pipelining_loop__ }
847+
return
848+
}

0 commit comments

Comments
 (0)