Skip to content

[MLIR][SCF] Add checks to verify that the pipeliner schedule is correct. #77083

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ struct LoopPipelinerInternal {
/// the Value.
std::pair<Operation *, int64_t> getDefiningOpAndDistance(Value value);

/// Return true if the schedule is valid and false otherwise. A schedule is
/// valid if every definition is scheduled before its uses.
bool verifySchedule();

public:
/// Initialize the information for the given `op`, return true if it
/// satisfies the pre-condition to apply pipelining.
Expand Down Expand Up @@ -156,6 +160,11 @@ bool LoopPipelinerInternal::initializeLoopInfo(
}
}

if (!verifySchedule()) {
LDBG("--invalid schedule: " << op << " -> BAIL");
return false;
}

// Currently, we do not support assigning stages to ops in nested regions. The
// block of all operations assigned a stage should be the single `scf.for`
// body block.
Expand Down Expand Up @@ -194,6 +203,40 @@ bool LoopPipelinerInternal::initializeLoopInfo(
return true;
}

/// Compute unrolled cycles of each op (consumer) and verify that each op is
/// scheduled after its operands (producers) while adjusting for the distance
/// between producer and consumer.
bool LoopPipelinerInternal::verifySchedule() {
int64_t numCylesPerIter = opOrder.size();
// Pre-compute the unrolled cycle of each op.
DenseMap<Operation *, int64_t> unrolledCyles;
for (int64_t cycle = 0; cycle < numCylesPerIter; cycle++) {
Operation *def = opOrder[cycle];
auto it = stages.find(def);
assert(it != stages.end());
int64_t stage = it->second;
unrolledCyles[def] = cycle + stage * numCylesPerIter;
}
for (Operation *consumer : opOrder) {
int64_t consumerCycle = unrolledCyles[consumer];
for (Value operand : consumer->getOperands()) {
auto [producer, distance] = getDefiningOpAndDistance(operand);
if (!producer)
continue;
auto it = unrolledCyles.find(producer);
// Skip producer coming from outside the loop.
if (it == unrolledCyles.end())
continue;
int64_t producerCycle = it->second;
if (consumerCycle < producerCycle - numCylesPerIter * distance) {
consumer->emitError("operation scheduled before its operands");
return false;
}
}
}
return true;
}

/// Clone `op` and call `callback` on the cloned op's operands as well as any
/// operands of nested ops that:
/// 1) aren't defined within the new op or
Expand Down
34 changes: 33 additions & 1 deletion mlir/test/Dialect/SCF/loop-pipelining.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -test-scf-pipelining -split-input-file | FileCheck %s
// RUN: mlir-opt %s -test-scf-pipelining -split-input-file -verify-diagnostics | FileCheck %s
// RUN: mlir-opt %s -test-scf-pipelining=annotate -split-input-file | FileCheck %s --check-prefix ANNOTATE
// RUN: mlir-opt %s -test-scf-pipelining=no-epilogue-peeling -split-input-file | FileCheck %s --check-prefix NOEPILOGUE

Expand Down Expand Up @@ -814,3 +814,35 @@ func.func @yield_constant_loop(%A: memref<?xf32>) -> f32 {
return %r : f32
}

// -----

// Negative test: the schedule below must be rejected. The `memref.store` is
// assigned stage 1 / order 1, while the `arith.addf` that produces the value it
// stores is assigned stage 2 / order 0. The consumer therefore sits in an
// earlier stage than its producer, and the diagnostic is expected on the store.
func.func @invalid_schedule(%A: memref<?xf32>, %result: memref<?xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%cf = arith.constant 1.0 : f32
scf.for %i0 = %c0 to %c4 step %c1 {
%A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
%A1_elem = arith.addf %A_elem, %cf { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 0 } : f32
// expected-error@+1 {{operation scheduled before its operands}}
memref.store %A1_elem, %result[%i0] { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : memref<?xf32>
} { __test_pipelining_loop__ }
return
}

// -----

// Negative test involving a loop-carried value. The diagnostic is expected on
// the `memref.load`, which reads the iter_arg `%idx` at stage 0 / order 0. The
// next value of `%idx` is `%idx1`, produced by the `arith.addi` at
// stage 1 / order 1 and passed back through `scf.yield`.
// NOTE(review): presumably the load is rejected because it still lands before
// the addi even after one iteration of distance is credited; confirm against
// the pipeliner's handling of loop-carried operands.
func.func @invalid_schedule2(%A: memref<?xf32>, %result: memref<?xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%cf = arith.constant 1.0 : f32
%r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%idx = %c0) -> (index) {
// expected-error@+1 {{operation scheduled before its operands}}
%A_elem = memref.load %A[%idx] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 0 } : memref<?xf32>
%idx1 = arith.addi %idx, %c1 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : index
memref.store %A_elem, %result[%idx] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 2 } : memref<?xf32>
scf.yield %idx1 : index
} { __test_pipelining_loop__ }
return
}