-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][SCF] Handle more cases in pipelining transform #74007
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
-Fix case where an op is scheduled in stage 0 and used with a distance of 1 -Fix case where we don't peel the epilogue and a value not part of the last stage is used outside the loop.
@llvm/pr-subscribers-mlir Author: Thomas Raoux (ThomasRaoux) Changes-Fix case where an op is scheduled in stage 0 and used with a distance of 1 Full diff: https://github.com/llvm/llvm-project/pull/74007.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 5537a8b212c51f7..f25318fe52093ec 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -61,6 +61,11 @@ struct LoopPipelinerInternal {
/// `idx` of `key` in the epilogue.
void setValueMapping(Value key, Value el, int64_t idx);
+ /// Return the defining op of the given value, if the Value is an argument of
+ /// the loop return the associated defining op in the loop and its distance to
+ /// the Value.
+ std::pair<Operation *, int64_t> getDefiningOpAndDistance(Value value);
+
public:
/// Initalize the information for the given `op`, return true if it
/// satisfies the pre-condition to apply pipelining.
@@ -240,11 +245,12 @@ LoopPipelinerInternal::analyzeCrossStageValues() {
unsigned stage = stages[op];
auto analyzeOperand = [&](OpOperand &operand) {
- Operation *def = operand.get().getDefiningOp();
+ auto [def, distance] = getDefiningOpAndDistance(operand.get());
if (!def)
return;
auto defStage = stages.find(def);
- if (defStage == stages.end() || defStage->second == stage)
+ if (defStage == stages.end() || defStage->second == stage ||
+ defStage->second == stage + distance)
return;
assert(stage > defStage->second);
LiverangeInfo &info = crossStageValues[operand.get()];
@@ -261,6 +267,25 @@ LoopPipelinerInternal::analyzeCrossStageValues() {
return crossStageValues;
}
+std::pair<Operation *, int64_t>
+LoopPipelinerInternal::getDefiningOpAndDistance(Value value) {
+ int64_t distance = 0;
+ if (auto arg = dyn_cast<BlockArgument>(value)) {
+ if (arg.getOwner() != forOp.getBody())
+ return {nullptr, 0};
+ // Ignore induction variable.
+ if (arg.getArgNumber() == 0)
+ return {nullptr, 0};
+ distance++;
+ value =
+ forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1);
+ }
+ Operation *def = value.getDefiningOp();
+ if (!def)
+ return {nullptr, 0};
+ return {def, distance};
+}
+
scf::ForOp LoopPipelinerInternal::createKernelLoop(
const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
&crossStageValues,
@@ -366,10 +391,9 @@ LogicalResult LoopPipelinerInternal::createKernel(
rewriter.setInsertionPointAfter(newOp);
continue;
}
- auto arg = dyn_cast<BlockArgument>(operand->get());
+ Value source = operand->get();
+ auto arg = dyn_cast<BlockArgument>(source);
if (arg && arg.getOwner() == forOp.getBody()) {
- // If the value is a loop carried value coming from stage N + 1 remap,
- // it will become a direct use.
Value ret = forOp.getBody()->getTerminator()->getOperand(
arg.getArgNumber() - 1);
Operation *dep = ret.getDefiningOp();
@@ -378,15 +402,19 @@ LogicalResult LoopPipelinerInternal::createKernel(
auto stageDep = stages.find(dep);
if (stageDep == stages.end() || stageDep->second == useStage)
continue;
- assert(stageDep->second == useStage + 1);
- nestedNewOp->setOperand(operand->getOperandNumber(),
- mapping.lookupOrDefault(ret));
- continue;
+ // If the value is a loop carried value coming from stage N + 1 remap,
+ // it will become a direct use.
+ if (stageDep->second == useStage + 1) {
+ nestedNewOp->setOperand(operand->getOperandNumber(),
+ mapping.lookupOrDefault(ret));
+ continue;
+ }
+ source = ret;
}
// For operands defined in a previous stage we need to remap it to use
// the correct region argument. We look for the right version of the
// Value based on the stage where it is used.
- Operation *def = operand->get().getDefiningOp();
+ Operation *def = source.getDefiningOp();
if (!def)
continue;
auto stageDef = stages.find(def);
@@ -418,9 +446,30 @@ LogicalResult LoopPipelinerInternal::createKernel(
// We create a mapping between original values and the associated loop
// returned values that will be needed by the epilogue.
llvm::SmallVector<Value> yieldOperands;
- for (Value retVal : forOp.getBody()->getTerminator()->getOperands()) {
- yieldOperands.push_back(mapping.lookupOrDefault(retVal));
+ for (OpOperand &yielOperand :
+ forOp.getBody()->getTerminator()->getOpOperands()) {
+ Value source = mapping.lookupOrDefault(yielOperand.get());
+ // When we don't peel the epilogue the yield value is used outside the loop
+ // we need to make sure we return the version from numStages - defStage.
+ if (!peelEpilogue &&
+ !forOp.getResult(yielOperand.getOperandNumber()).use_empty()) {
+ auto [def, distance] = getDefiningOpAndDistance(yielOperand.get());
+ if (def) {
+ auto defStage = stages.find(def);
+ if (defStage != stages.end()) {
+ Value pred = predicates[defStage->second];
+ if (pred) {
+ source = rewriter.create<arith::SelectOp>(
+ pred.getLoc(), pred, source,
+ newForOp.getBody()
+ ->getArguments()[yielOperand.getOperandNumber() + 1]);
+ }
+ }
+ }
+ }
+ yieldOperands.push_back(source);
}
+
for (auto &it : crossStageValues) {
int64_t version = maxStage - it.second.lastUseStage + 1;
unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage;
@@ -444,9 +493,11 @@ LogicalResult LoopPipelinerInternal::createKernel(
Operation *def = retVal.value().getDefiningOp();
assert(def && "Only support loop carried dependencies of distance 1");
unsigned defStage = stages[def];
- setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
- newForOp->getResult(retVal.index()),
- maxStage - defStage + 1);
+ if (defStage > 0) {
+ setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
+ newForOp->getResult(retVal.index()),
+ maxStage - defStage + 1);
+ }
}
rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
return success();
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 0309287e409c184..4cd686d2cdb86b6 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -670,4 +670,56 @@ func.func @backedge_mix_order(%A: memref<?xf32>) -> f32 {
scf.yield %A3_elem : f32
} { __test_pipelining_loop__ }
return %r : f32
-}
\ No newline at end of file
+}
+
+// -----
+
+// CHECK-LABEL: @distance_1_use
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// Prologue:
+// CHECK: %[[L0:.+]] = memref.load %{{.*}}[%[[C0]]] : memref<?xf32>
+// CHECK: %[[L1:.+]] = memref.load %{{.*}}[%[[C1]]] : memref<?xf32>
+// CHECK: %[[R:.+]]:5 = scf.for {{.*}} iter_args(%[[IDX0:.+]] = %[[C2]], %[[L2:.+]] = %[[L0]], %[[L3:.+]] = %[[L1]]
+// CHECK: %[[L4:.+]] = memref.load %{{.*}}[%[[IDX0]]] : memref<?xf32>
+// CHECK: %[[IDX1:.+]] = arith.addi %[[IDX0]], %[[C1]] : index
+// CHECK: memref.store %[[L2]]
+// CHECK: scf.yield %[[IDX1]], %[[L3]], %[[L4]]
+func.func @distance_1_use(%A: memref<?xf32>, %result: memref<?xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %cf = arith.constant 1.0 : f32
+ %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%idx = %c0) -> (index) {
+ %A_elem = memref.load %A[%idx] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 0 } : memref<?xf32>
+ %idx1 = arith.addi %idx, %c1 { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 1 } : index
+ memref.store %A_elem, %result[%idx] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+ scf.yield %idx1 : index
+ } { __test_pipelining_loop__ }
+ return
+}
+
+// -----
+
+// NOEPILOGUE-LABEL: stage_0_value_escape(
+func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %cf = arith.constant 1.0 : f32
+// NOEPILOGUE: %[[C3:.+]] = arith.constant 3 : index
+// NOEPILOGUE: %[[A:.+]] = arith.addf
+// NOEPILOGUE: scf.for %[[IV:.+]] = {{.*}} iter_args(%[[ARG:.+]] = %[[A]],
+// NOEPILOGUE: %[[C:.+]] = arith.cmpi slt, %[[IV]], %[[C3]] : index
+// NOEPILOGUE: %[[S:.+]] = arith.select %[[C]], %{{.+}}, %[[ARG]] : f32
+// NOEPILOGUE: scf.yield %[[S]]
+ %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
+ %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 1 } : memref<?xf32>
+ %A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
+ memref.store %A1_elem, %result[%c0] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+ scf.yield %A1_elem : f32
+ } { __test_pipelining_loop__ }
+ memref.store %r, %result[%c1] : memref<?xf32>
+ return
+}
|
@llvm/pr-subscribers-mlir-scf Author: Thomas Raoux (ThomasRaoux) Changes-Fix case where an op is scheduled in stage 0 and used with a distance of 1 Full diff: https://github.com/llvm/llvm-project/pull/74007.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 5537a8b212c51f7..f25318fe52093ec 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -61,6 +61,11 @@ struct LoopPipelinerInternal {
/// `idx` of `key` in the epilogue.
void setValueMapping(Value key, Value el, int64_t idx);
+ /// Return the defining op of the given value, if the Value is an argument of
+ /// the loop return the associated defining op in the loop and its distance to
+ /// the Value.
+ std::pair<Operation *, int64_t> getDefiningOpAndDistance(Value value);
+
public:
/// Initalize the information for the given `op`, return true if it
/// satisfies the pre-condition to apply pipelining.
@@ -240,11 +245,12 @@ LoopPipelinerInternal::analyzeCrossStageValues() {
unsigned stage = stages[op];
auto analyzeOperand = [&](OpOperand &operand) {
- Operation *def = operand.get().getDefiningOp();
+ auto [def, distance] = getDefiningOpAndDistance(operand.get());
if (!def)
return;
auto defStage = stages.find(def);
- if (defStage == stages.end() || defStage->second == stage)
+ if (defStage == stages.end() || defStage->second == stage ||
+ defStage->second == stage + distance)
return;
assert(stage > defStage->second);
LiverangeInfo &info = crossStageValues[operand.get()];
@@ -261,6 +267,25 @@ LoopPipelinerInternal::analyzeCrossStageValues() {
return crossStageValues;
}
+std::pair<Operation *, int64_t>
+LoopPipelinerInternal::getDefiningOpAndDistance(Value value) {
+ int64_t distance = 0;
+ if (auto arg = dyn_cast<BlockArgument>(value)) {
+ if (arg.getOwner() != forOp.getBody())
+ return {nullptr, 0};
+ // Ignore induction variable.
+ if (arg.getArgNumber() == 0)
+ return {nullptr, 0};
+ distance++;
+ value =
+ forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1);
+ }
+ Operation *def = value.getDefiningOp();
+ if (!def)
+ return {nullptr, 0};
+ return {def, distance};
+}
+
scf::ForOp LoopPipelinerInternal::createKernelLoop(
const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
&crossStageValues,
@@ -366,10 +391,9 @@ LogicalResult LoopPipelinerInternal::createKernel(
rewriter.setInsertionPointAfter(newOp);
continue;
}
- auto arg = dyn_cast<BlockArgument>(operand->get());
+ Value source = operand->get();
+ auto arg = dyn_cast<BlockArgument>(source);
if (arg && arg.getOwner() == forOp.getBody()) {
- // If the value is a loop carried value coming from stage N + 1 remap,
- // it will become a direct use.
Value ret = forOp.getBody()->getTerminator()->getOperand(
arg.getArgNumber() - 1);
Operation *dep = ret.getDefiningOp();
@@ -378,15 +402,19 @@ LogicalResult LoopPipelinerInternal::createKernel(
auto stageDep = stages.find(dep);
if (stageDep == stages.end() || stageDep->second == useStage)
continue;
- assert(stageDep->second == useStage + 1);
- nestedNewOp->setOperand(operand->getOperandNumber(),
- mapping.lookupOrDefault(ret));
- continue;
+ // If the value is a loop carried value coming from stage N + 1 remap,
+ // it will become a direct use.
+ if (stageDep->second == useStage + 1) {
+ nestedNewOp->setOperand(operand->getOperandNumber(),
+ mapping.lookupOrDefault(ret));
+ continue;
+ }
+ source = ret;
}
// For operands defined in a previous stage we need to remap it to use
// the correct region argument. We look for the right version of the
// Value based on the stage where it is used.
- Operation *def = operand->get().getDefiningOp();
+ Operation *def = source.getDefiningOp();
if (!def)
continue;
auto stageDef = stages.find(def);
@@ -418,9 +446,30 @@ LogicalResult LoopPipelinerInternal::createKernel(
// We create a mapping between original values and the associated loop
// returned values that will be needed by the epilogue.
llvm::SmallVector<Value> yieldOperands;
- for (Value retVal : forOp.getBody()->getTerminator()->getOperands()) {
- yieldOperands.push_back(mapping.lookupOrDefault(retVal));
+ for (OpOperand &yielOperand :
+ forOp.getBody()->getTerminator()->getOpOperands()) {
+ Value source = mapping.lookupOrDefault(yielOperand.get());
+ // When we don't peel the epilogue the yield value is used outside the loop
+ // we need to make sure we return the version from numStages - defStage.
+ if (!peelEpilogue &&
+ !forOp.getResult(yielOperand.getOperandNumber()).use_empty()) {
+ auto [def, distance] = getDefiningOpAndDistance(yielOperand.get());
+ if (def) {
+ auto defStage = stages.find(def);
+ if (defStage != stages.end()) {
+ Value pred = predicates[defStage->second];
+ if (pred) {
+ source = rewriter.create<arith::SelectOp>(
+ pred.getLoc(), pred, source,
+ newForOp.getBody()
+ ->getArguments()[yielOperand.getOperandNumber() + 1]);
+ }
+ }
+ }
+ }
+ yieldOperands.push_back(source);
}
+
for (auto &it : crossStageValues) {
int64_t version = maxStage - it.second.lastUseStage + 1;
unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage;
@@ -444,9 +493,11 @@ LogicalResult LoopPipelinerInternal::createKernel(
Operation *def = retVal.value().getDefiningOp();
assert(def && "Only support loop carried dependencies of distance 1");
unsigned defStage = stages[def];
- setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
- newForOp->getResult(retVal.index()),
- maxStage - defStage + 1);
+ if (defStage > 0) {
+ setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
+ newForOp->getResult(retVal.index()),
+ maxStage - defStage + 1);
+ }
}
rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
return success();
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 0309287e409c184..4cd686d2cdb86b6 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -670,4 +670,56 @@ func.func @backedge_mix_order(%A: memref<?xf32>) -> f32 {
scf.yield %A3_elem : f32
} { __test_pipelining_loop__ }
return %r : f32
-}
\ No newline at end of file
+}
+
+// -----
+
+// CHECK-LABEL: @distance_1_use
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// Prologue:
+// CHECK: %[[L0:.+]] = memref.load %{{.*}}[%[[C0]]] : memref<?xf32>
+// CHECK: %[[L1:.+]] = memref.load %{{.*}}[%[[C1]]] : memref<?xf32>
+// CHECK: %[[R:.+]]:5 = scf.for {{.*}} iter_args(%[[IDX0:.+]] = %[[C2]], %[[L2:.+]] = %[[L0]], %[[L3:.+]] = %[[L1]]
+// CHECK: %[[L4:.+]] = memref.load %{{.*}}[%[[IDX0]]] : memref<?xf32>
+// CHECK: %[[IDX1:.+]] = arith.addi %[[IDX0]], %[[C1]] : index
+// CHECK: memref.store %[[L2]]
+// CHECK: scf.yield %[[IDX1]], %[[L3]], %[[L4]]
+func.func @distance_1_use(%A: memref<?xf32>, %result: memref<?xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %cf = arith.constant 1.0 : f32
+ %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%idx = %c0) -> (index) {
+ %A_elem = memref.load %A[%idx] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 0 } : memref<?xf32>
+ %idx1 = arith.addi %idx, %c1 { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 1 } : index
+ memref.store %A_elem, %result[%idx] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+ scf.yield %idx1 : index
+ } { __test_pipelining_loop__ }
+ return
+}
+
+// -----
+
+// NOEPILOGUE-LABEL: stage_0_value_escape(
+func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %cf = arith.constant 1.0 : f32
+// NOEPILOGUE: %[[C3:.+]] = arith.constant 3 : index
+// NOEPILOGUE: %[[A:.+]] = arith.addf
+// NOEPILOGUE: scf.for %[[IV:.+]] = {{.*}} iter_args(%[[ARG:.+]] = %[[A]],
+// NOEPILOGUE: %[[C:.+]] = arith.cmpi slt, %[[IV]], %[[C3]] : index
+// NOEPILOGUE: %[[S:.+]] = arith.select %[[C]], %{{.+}}, %[[ARG]] : f32
+// NOEPILOGUE: scf.yield %[[S]]
+ %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
+ %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 1 } : memref<?xf32>
+ %A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
+ memref.store %A1_elem, %result[%c0] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+ scf.yield %A1_elem : f32
+ } { __test_pipelining_loop__ }
+ memref.store %r, %result[%c1] : memref<?xf32>
+ return
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, a few nits.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the review
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Thanks @qedawkins! |
-Fix case where an op is scheduled in stage 0 and used with a distance of 1
-Fix case where we don't peel the epilogue and a value not part of the last stage is used outside the loop.