-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][spirv] Allow yielding values from selection regions #133702
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir Author: Igor Wodiany (IgWod-IMG) ChangesThere are cases in SPIR-V shaders where values need to be yielded from the selection region to make valid MLIR. For example (part of the SPIR-V shader decompiled to GLSL):
This patch extends Full diff: https://github.com/llvm/llvm-project/pull/133702.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
index cb7d27e8d4b9a..5a0db18b2e4b0 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
@@ -352,7 +352,7 @@ def SPIRV_LoopOp : SPIRV_Op<"mlir.loop", [InFunctionScope]> {
// -----
def SPIRV_MergeOp : SPIRV_Op<"mlir.merge", [
- Pure, Terminator, ParentOneOf<["SelectionOp", "LoopOp"]>]> {
+ Pure, Terminator, ParentOneOf<["SelectionOp", "LoopOp"]>, ReturnLike]> {
let summary = "A special terminator for merging a structured selection/loop.";
let description = [{
@@ -360,14 +360,15 @@ def SPIRV_MergeOp : SPIRV_Op<"mlir.merge", [
This op is a terminator used inside their regions to mean jumping to the
merge point, which is the next op following the `spirv.mlir.selection` or
`spirv.mlir.loop` op. This op does not have a corresponding instruction in the
- SPIR-V binary format; it's solely for structural purpose.
+ SPIR-V binary format; it's solely for structural purpose. The instruction is also
+ used to yield values outside the selection/loop region.
}];
- let arguments = (ins);
+ let arguments = (ins Variadic<AnyType>:$operands);
let results = (outs);
- let assemblyFormat = "attr-dict";
+ let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
let hasOpcode = 0;
@@ -471,7 +472,7 @@ def SPIRV_SelectionOp : SPIRV_Op<"mlir.selection", [InFunctionScope]> {
SPIRV_SelectionControlAttr:$selection_control
);
- let results = (outs);
+ let results = (outs Variadic<AnyType>:$results);
let regions = (region AnyRegion:$body);
diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index 31d8cd2206148..1b8c548852619 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -228,10 +228,10 @@ struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> {
// Create `spirv.selection` operation, selection header block and merge
// block.
auto selectionOp =
- rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
+ rewriter.create<spirv::SelectionOp>(loc, TypeRange(), spirv::SelectionControl::None);
auto *mergeBlock = rewriter.createBlock(&selectionOp.getBody(),
selectionOp.getBody().end());
- rewriter.create<spirv::MergeOp>(loc);
+ rewriter.create<spirv::MergeOp>(loc, ValueRange());
OpBuilder::InsertionGuard guard(rewriter);
auto *selectionHeaderBlock =
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index 2959d67b366b9..6bcec6e1a29fd 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -383,7 +383,7 @@ void LoopOp::addEntryAndMergeBlock(OpBuilder &builder) {
builder.createBlock(&getBody());
// Add a spirv.mlir.merge op into the merge block.
- builder.create<spirv::MergeOp>(getLoc());
+ builder.create<spirv::MergeOp>(getLoc(), ValueRange());
}
//===----------------------------------------------------------------------===//
@@ -452,6 +452,11 @@ ParseResult SelectionOp::parse(OpAsmParser &parser, OperationState &result) {
if (parseControlAttribute<spirv::SelectionControlAttr,
spirv::SelectionControl>(parser, result))
return failure();
+
+ if (succeeded(parser.parseOptionalArrow()))
+ if (parser.parseTypeList(result.types))
+ return failure();
+
return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
}
@@ -459,6 +464,10 @@ void SelectionOp::print(OpAsmPrinter &printer) {
auto control = getSelectionControl();
if (control != spirv::SelectionControl::None)
printer << " control(" << spirv::stringifySelectionControl(control) << ")";
+ if(getNumResults() > 0) {
+ printer << " -> ";
+ printer << getResultTypes();
+ }
printer << ' ';
printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);
@@ -526,7 +535,7 @@ void SelectionOp::addMergeBlock(OpBuilder &builder) {
builder.createBlock(&getBody());
// Add a spirv.mlir.merge op into the merge block.
- builder.create<spirv::MergeOp>(getLoc());
+ builder.create<spirv::MergeOp>(getLoc(), ValueRange());
}
SelectionOp
@@ -534,7 +543,7 @@ SelectionOp::createIfThen(Location loc, Value condition,
function_ref<void(OpBuilder &builder)> thenBody,
OpBuilder &builder) {
auto selectionOp =
- builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
+ builder.create<spirv::SelectionOp>(loc, TypeRange(), spirv::SelectionControl::None);
selectionOp.addMergeBlock(builder);
Block *mergeBlock = selectionOp.getMergeBlock();
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 273817d53d308..7e4b3c3fc2d56 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1843,7 +1843,7 @@ ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
OpBuilder builder(&mergeBlock->front());
auto control = static_cast<spirv::SelectionControl>(selectionControl);
- auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
+ auto selectionOp = builder.create<spirv::SelectionOp>(location, TypeRange(), control);
selectionOp.addMergeBlock(builder);
return selectionOp;
@@ -1992,6 +1992,29 @@ LogicalResult ControlFlowStructurizer::structurize() {
ArrayRef<Value>(blockArgs));
}
+ // Values defined inside the selection region that need to be yield outside
+ // the region.
+ SmallVector<Value> valuesToYield;
+ // Outside uses of values sank into the selection region. Those uses will be
+ // replaced with values returned by the SelectionOp.
+ SmallVector<Value> outsideUses;
+
+ // Move block arguments of the original block (`mergeBlock`) into the merge
+ // block inside the selection (`body.back()`). Values produced by block arguments
+ // will be yielded by the selection region. We do not update uses or erase original
+ // block arguments yet. It will be done later in the code.
+ if (!isLoop) {
+ for (BlockArgument blockArg : mergeBlock->getArguments()) {
+ // Create new block arguments in the last block ("merge block") of the
+ // selection region. We create one argument for each argument in `mergeBlock`.
+ // This new value will need to be yielded, and the original value replaced, so
+ // add them to appropriate vectors.
+ body.back().addArgument(blockArg.getType(), blockArg.getLoc());
+ valuesToYield.push_back(body.back().getArguments().back());
+ outsideUses.push_back(blockArg);
+ }
+ }
+
// All the blocks cloned into the SelectionOp/LoopOp's region can now be
// cleaned up.
LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n");
@@ -2000,16 +2023,75 @@ LogicalResult ControlFlowStructurizer::structurize() {
for (auto *block : constructBlocks)
block->dropAllReferences();
+ // All internal uses should be removed from original blocks by now, so
+ // whatever is left is an outside use and will need to be yielded from
+ // the newly created selection region.
+ if(!isLoop) {
+ for (Block *block : constructBlocks) {
+ for (Operation &op : *block) {
+ if (!op.use_empty())
+ for (Value result : op.getResults()) {
+ valuesToYield.push_back(mapper.lookupOrNull(result));
+ outsideUses.push_back(result);
+ }
+ }
+ for (BlockArgument &arg : block->getArguments()) {
+ if (!arg.use_empty()) {
+ valuesToYield.push_back(mapper.lookupOrNull(arg));
+ outsideUses.push_back(arg);
+ }
+ }
+ }
+ }
+
+ assert(valuesToYield.size() == outsideUses.size());
+
+ // If we need to yield any values from the selection region we will take
+ // care of it here.
+ if(!isLoop && !valuesToYield.empty()) {
+ LLVM_DEBUG(logger.startLine() << "[cf] yielding values from the selection region\n");
+
+ // Update `mlir.merge` with values to be yield.
+ auto mergeOps = body.back().getOps<spirv::MergeOp>();
+ assert(std::next(mergeOps.begin()) == mergeOps.end());
+ Operation *merge = *mergeOps.begin();
+ merge->setOperands(valuesToYield);
+
+ // MLIR does not allow changing the number of results of an operation, so
+ // we create a new SelectionOp with required list of results and move
+ // the region from the initial SelectionOp. The initial operation is then
+ // removed. Since we move the region to the new op all links between blocks
+ // and remapping we have previously done should be preserved.
+ builder.setInsertionPoint(&mergeBlock->front());
+ auto selectionOp = builder.create<spirv::SelectionOp>(location, TypeRange(outsideUses), static_cast<spirv::SelectionControl>(control));
+ selectionOp->getRegion(0).takeBody(body);
+
+ // Remove initial op and swap the pointer to the newly created one.
+ op->erase();
+ op = selectionOp;
+
+ // Update all outside uses to use results of the SelectionOp and remove
+ // block arguments from the original merge block.
+ for (size_t i = 0; i < outsideUses.size(); i++)
+ outsideUses[i].replaceAllUsesWith(selectionOp.getResult(i));
+ for (size_t i = 0; i < mergeBlock->getNumArguments(); ++i)
+ mergeBlock->eraseArgument(i);
+ }
+
// Check that whether some op in the to-be-erased blocks still has uses. Those
// uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
// region. We cannot handle such cases given that once a value is sinked into
- // the SelectionOp/LoopOp's region, there is no escape for it:
- // SelectionOp/LooOp does not support yield values right now.
+ // the SelectionOp/LoopOp's region, there is no escape for it.
for (auto *block : constructBlocks) {
for (Operation &op : *block)
if (!op.use_empty())
return op.emitOpError(
- "failed control flow structurization: it has uses outside of the "
+ "failed control flow structurization: value has uses outside of the "
+ "enclosing selection/loop construct");
+ for (BlockArgument &arg : block->getArguments())
+ if (!arg.use_empty())
+ return emitError(arg.getLoc(),
+ "failed control flow structurization: block argument has uses outside of the "
"enclosing selection/loop construct");
}
@@ -2236,7 +2318,7 @@ LogicalResult spirv::Deserializer::structurizeControlFlow() {
auto *mergeBlock = mergeInfo.mergeBlock;
assert(mergeBlock && "merge block cannot be nullptr");
- if (!mergeBlock->args_empty())
+ if (mergeInfo.continueBlock != nullptr && !mergeBlock->args_empty())
return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
LLVM_DEBUG({
logger.startLine() << "[cf] merge block " << mergeBlock << ":\n";
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 4c15523a05fa8..918396adb0c88 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -447,6 +447,14 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
auto mergeID = getBlockID(mergeBlock);
auto loc = selectionOp.getLoc();
+ // Before we do anything wire yielded values with the result of the selection
+ // operation. The selection op is being flatten so we do not have to worry
+ // about values being defined inside a region and used outside it anymore.
+ auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back());
+ assert (selectionOp.getNumResults() == mergeOp.getNumOperands());
+ for (unsigned i = 0; i < selectionOp.getNumResults(); ++i)
+ selectionOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i));
+
// This SelectionOp is in some MLIR block with preceding and following ops. In
// the binary format, it should reside in separate SPIR-V blocks from its
// preceding and following ops. So we need to emit unconditional branches to
@@ -483,6 +491,12 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
// instruction to start a new SPIR-V block for ops following this SelectionOp.
// The block should use the <id> for the merge block.
encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
+
+ // We do not process the mergeBlock but we still need to generate phi functions
+ // from its block arguments.
+ if (failed(emitPhiForBlockArguments(mergeBlock)))
+ return failure();
+
LLVM_DEBUG(llvm::dbgs() << "done merge ");
LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
LLVM_DEBUG(llvm::dbgs() << "\n");
diff --git a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
index 188a55d755fd2..79c6b50fec73f 100644
--- a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
@@ -765,6 +765,44 @@ func.func @missing_entry_block() -> () {
// -----
+func.func @selection_yield(%cond: i1) -> () {
+ %zero = spirv.Constant 0: i32
+ %var1 = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+ %var2 = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+
+ // CHECK: {{%.*}}:2 = spirv.mlir.selection -> i32, i32 {
+ %yield:2 = spirv.mlir.selection -> i32, i32 {
+ // CHECK-NEXT: spirv.BranchConditional {{%.*}}, ^bb1, ^bb2
+ spirv.BranchConditional %cond, ^then, ^else
+
+ // CHECK: ^bb1
+ ^then:
+ %one = spirv.Constant 1: i32
+ %three = spirv.Constant 3: i32
+ // CHECK: spirv.Branch ^bb3({{%.*}}, {{%.*}} : i32, i32)
+ spirv.Branch ^merge(%one, %three : i32, i32)
+
+ // CHECK: ^bb2
+ ^else:
+ %two = spirv.Constant 2: i32
+ %four = spirv.Constant 4 : i32
+ // CHECK: spirv.Branch ^bb3({{%.*}}, {{%.*}} : i32, i32)
+ spirv.Branch ^merge(%two, %four : i32, i32)
+
+ // CHECK: ^bb3({{%.*}}: i32, {{%.*}}: i32)
+ ^merge(%merged_1_2: i32, %merged_3_4: i32):
+ // CHECK-NEXT: spirv.mlir.merge {{%.*}}, {{%.*}} : i32, i32
+ spirv.mlir.merge %merged_1_2, %merged_3_4 : i32, i32
+ }
+
+ spirv.Store "Function" %var1, %yield#0 : i32
+ spirv.Store "Function" %var2, %yield#1 : i32
+
+ spirv.Return
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.Unreachable
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/selection.mlir b/mlir/test/Target/SPIRV/selection.mlir
index 24abb12998d06..5c9b9cdf89f8a 100644
--- a/mlir/test/Target/SPIRV/selection.mlir
+++ b/mlir/test/Target/SPIRV/selection.mlir
@@ -150,3 +150,55 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
spirv.Return
}
}
+
+// -----
+
+// Selection yielding values
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+// CHECK-LABEL: @selection_yield
+ spirv.func @selection_yield(%cond: i1) -> () "None" {
+// CHECK-NEXT: spirv.Constant 0
+// CHECK-NEXT: spirv.Variable
+ %zero = spirv.Constant 0 : i32
+ %var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+
+// CHECK: spirv.Branch ^[[BB:.+]]
+// CHECK-NEXT: ^[[BB]]:
+// CHECK-NEXT: spirv.mlir.selection -> i32
+ %yield = spirv.mlir.selection -> i32 {
+// CHECK-NEXT: spirv.BranchConditional %{{.*}} [5, 10], ^[[THEN:.+]], ^[[ELSE:.+]]
+ spirv.BranchConditional %cond [5, 10], ^then, ^else
+
+// CHECK-NEXT: ^[[THEN]]:
+ ^then:
+// CHECK-NEXT: spirv.Constant 1
+ %one = spirv.Constant 1: i32
+
+// CHECK-NEXT: spirv.Branch ^[[MERGE:.+]]({{%.*}} : i32)
+ spirv.Branch ^merge(%one : i32)
+
+// CHECK-NEXT: ^[[ELSE]]:
+ ^else:
+// CHECK-NEXT: spirv.Constant 2
+ %two = spirv.Constant 2: i32
+// CHECK-NEXT: spirv.Branch ^[[MERGE]]({{%.*}} : i32)
+ spirv.Branch ^merge(%two : i32)
+
+// CHECK-NEXT: ^[[MERGE]]({{%.*}}: i32):
+ ^merge(%merged: i32):
+// CHECK-NEXT: spirv.mlir.merge {{%.*}} : i32
+ spirv.mlir.merge %merged : i32
+ }
+
+ spirv.Store "Function" %var, %yield : i32
+
+ spirv.Return
+ }
+
+ spirv.func @main() -> () "None" {
+ spirv.Return
+ }
+ spirv.EntryPoint "GLCompute" @main
+ spirv.ExecutionMode @main "LocalSize", 1, 1, 1
+}
|
d9b1b1f
to
e43db56
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks like a nice addition. I haven't had time to check the logic, so just some cosmetic issues for now
e43db56
to
83667bc
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@IgWod-IMG does this require deserialization changes to make sure we can roundtrip?
Do you mean serialization? Most of the work is done in the deserializer. I have also updated serialization in |
83667bc
to
5b4a14f
Compare
Am I understanding correctly if I say that this is a way to be able to return SPIR-V "results" (something like an SSA value?) directly from selection regions, rather than having to store them into an OpVariable? In other words, the problem it's solving is that values inside the region can't escape its scope currently? |
Yes, that's correct. |
Thanks. Then this seems like a reasonable way to solve the problem. I will have a closer look at the patch. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My comments may be a little nitpicky but hopefully they make sense.
There are cases in SPIR-V shaders where values need to be yielded from the selection region to make valid MLIR. For example (part of the SPIR-V shader decompiled to GLSL): ``` bool _115 if (_107) { // ... float _200 = fma(...); // ... _115 = _200 < _174; } else { _115 = _107; } bool _123; if (_115) { // ... float _213 = fma(...); // ... _123 = _213 < _174; } else { _123 = _115; } ```` This patch extends `mlir.selection` so it can return values. `mlir.merge` is used as a "yield" operation. This allows to maintain a compatibility with code that does not yield any values, as well as, to maintain an assumption that `mlir.merge` is the only operation in the merge block of the selection region.
5b4a14f
to
1455e12
Compare
All feedback is addressed now. Please let me know if there are any more comments, and if not, I'll merge it once approved and passes checks. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great now, thanks a lot for this!
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/140/builds/20340 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/116/builds/12043 Here is the relevant piece of the build log for the reference
|
It seems this breaks on GCC 7 systems. |
Apologies for causing build failures. I have opened #134087 that hopefully addresses the issues. |
There are cases in SPIR-V shaders where values need to be yielded from the selection region to make valid MLIR. For example (part of the SPIR-V shader decompiled to GLSL):
This patch extends
mlir.selection
so it can return values.mlir.merge
is used as a "yield" operation. This allows to maintain a compatibility with code that does not yield any values, as well as, to maintain an assumption thatmlir.merge
is the only operation in the merge block of the selection region.