Skip to content

Commit 5b4a14f

Browse files
committed
[mlir][spirv] Allow yielding values from selection regions
There are cases in SPIR-V shaders where values need to be yielded from the selection region to make valid MLIR. For example (part of the SPIR-V shader decompiled to GLSL): ``` bool _115 if (_107) { // ... float _200 = fma(...); // ... _115 = _200 < _174; } else { _115 = _107; } bool _123; if (_115) { // ... float _213 = fma(...); // ... _123 = _213 < _174; } else { _123 = _115; } ```` This patch extends `mlir.selection` so it can return values. `mlir.merge` is used as a "yield" operation. This allows to maintain a compatibility with code that does not yield any values, as well as, to maintain an assumption that `mlir.merge` is the only operation in the merge block of the selection region.
1 parent 0ec9498 commit 5b4a14f

File tree

6 files changed

+224
-11
lines changed

6 files changed

+224
-11
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -352,22 +352,30 @@ def SPIRV_LoopOp : SPIRV_Op<"mlir.loop", [InFunctionScope]> {
352352
// -----
353353

354354
def SPIRV_MergeOp : SPIRV_Op<"mlir.merge", [
355-
Pure, Terminator, ParentOneOf<["SelectionOp", "LoopOp"]>]> {
355+
Pure, Terminator, ParentOneOf<["SelectionOp", "LoopOp"]>, ReturnLike]> {
356356
let summary = "A special terminator for merging a structured selection/loop.";
357357

358358
let description = [{
359359
We use `spirv.mlir.selection`/`spirv.mlir.loop` for modelling structured selection/loop.
360360
This op is a terminator used inside their regions to mean jumping to the
361361
merge point, which is the next op following the `spirv.mlir.selection` or
362362
`spirv.mlir.loop` op. This op does not have a corresponding instruction in the
363-
SPIR-V binary format; it's solely for structural purpose.
363+
SPIR-V binary format; it's solely for structural purpose. The instruction is also
364+
used to yield values outside the selection/loop region.
364365
}];
365366

366-
let arguments = (ins);
367+
let arguments = (ins Variadic<AnyType>:$operands);
367368

368369
let results = (outs);
369370

370-
let assemblyFormat = "attr-dict";
371+
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
372+
373+
let builders = [
374+
OpBuilder<(ins),
375+
[{
376+
build($_builder, $_state, ValueRange());
377+
}]>
378+
];
371379

372380
let hasOpcode = 0;
373381

@@ -471,7 +479,7 @@ def SPIRV_SelectionOp : SPIRV_Op<"mlir.selection", [InFunctionScope]> {
471479
SPIRV_SelectionControlAttr:$selection_control
472480
);
473481

474-
let results = (outs);
482+
let results = (outs Variadic<AnyType>:$results);
475483

476484
let regions = (region AnyRegion:$body);
477485

@@ -494,6 +502,13 @@ def SPIRV_SelectionOp : SPIRV_Op<"mlir.selection", [InFunctionScope]> {
494502
OpBuilder &builder);
495503
}];
496504

505+
let builders = [
506+
OpBuilder<(ins "spirv::SelectionControl":$selectionControl),
507+
[{
508+
build($_builder, $_state, TypeRange(), selectionControl);
509+
}]>
510+
];
511+
497512
let hasOpcode = 0;
498513

499514
let autogenSerialization = 0;

mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,13 +452,22 @@ ParseResult SelectionOp::parse(OpAsmParser &parser, OperationState &result) {
452452
if (parseControlAttribute<spirv::SelectionControlAttr,
453453
spirv::SelectionControl>(parser, result))
454454
return failure();
455+
456+
if (succeeded(parser.parseOptionalArrow()))
457+
if (parser.parseTypeList(result.types))
458+
return failure();
459+
455460
return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
456461
}
457462

458463
void SelectionOp::print(OpAsmPrinter &printer) {
459464
auto control = getSelectionControl();
460465
if (control != spirv::SelectionControl::None)
461466
printer << " control(" << spirv::stringifySelectionControl(control) << ")";
467+
if (getNumResults() > 0) {
468+
printer << " -> ";
469+
printer << getResultTypes();
470+
}
462471
printer << ' ';
463472
printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
464473
/*printBlockTerminators=*/true);

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 91 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1992,6 +1992,29 @@ LogicalResult ControlFlowStructurizer::structurize() {
19921992
ArrayRef<Value>(blockArgs));
19931993
}
19941994

1995+
// Values defined inside the selection region that need to be yield outside
1996+
// the region.
1997+
SmallVector<Value> valuesToYield;
1998+
// Outside uses of values sank into the selection region. Those uses will be
1999+
// replaced with values returned by the SelectionOp.
2000+
SmallVector<Value> outsideUses;
2001+
2002+
// Move block arguments of the original block (`mergeBlock`) into the merge
2003+
// block inside the selection (`body.back()`). Values produced by block
2004+
// arguments will be yielded by the selection region. We do not update uses or
2005+
// erase original block arguments yet. It will be done later in the code.
2006+
if (!isLoop) {
2007+
for (BlockArgument blockArg : mergeBlock->getArguments()) {
2008+
// Create new block arguments in the last block ("merge block") of the
2009+
// selection region. We create one argument for each argument in
2010+
// `mergeBlock`. This new value will need to be yielded, and the original
2011+
// value replaced, so add them to appropriate vectors.
2012+
body.back().addArgument(blockArg.getType(), blockArg.getLoc());
2013+
valuesToYield.push_back(body.back().getArguments().back());
2014+
outsideUses.push_back(blockArg);
2015+
}
2016+
}
2017+
19952018
// All the blocks cloned into the SelectionOp/LoopOp's region can now be
19962019
// cleaned up.
19972020
LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n");
@@ -2000,17 +2023,79 @@ LogicalResult ControlFlowStructurizer::structurize() {
20002023
for (auto *block : constructBlocks)
20012024
block->dropAllReferences();
20022025

2026+
// All internal uses should be removed from original blocks by now, so
2027+
// whatever is left is an outside use and will need to be yielded from
2028+
// the newly created selection region.
2029+
if (!isLoop) {
2030+
for (Block *block : constructBlocks) {
2031+
for (Operation &op : *block) {
2032+
if (!op.use_empty())
2033+
for (Value result : op.getResults()) {
2034+
valuesToYield.push_back(mapper.lookupOrNull(result));
2035+
outsideUses.push_back(result);
2036+
}
2037+
}
2038+
for (BlockArgument &arg : block->getArguments()) {
2039+
if (!arg.use_empty()) {
2040+
valuesToYield.push_back(mapper.lookupOrNull(arg));
2041+
outsideUses.push_back(arg);
2042+
}
2043+
}
2044+
}
2045+
}
2046+
2047+
assert(valuesToYield.size() == outsideUses.size());
2048+
2049+
// If we need to yield any values from the selection region we will take
2050+
// care of it here.
2051+
if (!isLoop && !valuesToYield.empty()) {
2052+
LLVM_DEBUG(logger.startLine()
2053+
<< "[cf] yielding values from the selection region\n");
2054+
2055+
// Update `mlir.merge` with values to be yield.
2056+
auto mergeOps = body.back().getOps<spirv::MergeOp>();
2057+
Operation *merge = llvm::getSingleElement(mergeOps);
2058+
assert(merge);
2059+
merge->setOperands(valuesToYield);
2060+
2061+
// MLIR does not allow changing the number of results of an operation, so
2062+
// we create a new SelectionOp with required list of results and move
2063+
// the region from the initial SelectionOp. The initial operation is then
2064+
// removed. Since we move the region to the new op all links between blocks
2065+
// and remapping we have previously done should be preserved.
2066+
builder.setInsertionPoint(&mergeBlock->front());
2067+
auto selectionOp = builder.create<spirv::SelectionOp>(
2068+
location, TypeRange(outsideUses),
2069+
static_cast<spirv::SelectionControl>(control));
2070+
selectionOp->getRegion(0).takeBody(body);
2071+
2072+
// Remove initial op and swap the pointer to the newly created one.
2073+
op->erase();
2074+
op = selectionOp;
2075+
2076+
// Update all outside uses to use results of the SelectionOp and remove
2077+
// block arguments from the original merge block.
2078+
for (unsigned i = 0, e = outsideUses.size(); i != e; ++i)
2079+
outsideUses[i].replaceAllUsesWith(selectionOp.getResult(i));
2080+
for (unsigned i = 0, e = mergeBlock->getNumArguments(); i != e; ++i)
2081+
mergeBlock->eraseArgument(i);
2082+
}
2083+
20032084
// Check that whether some op in the to-be-erased blocks still has uses. Those
20042085
// uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
20052086
// region. We cannot handle such cases given that once a value is sinked into
2006-
// the SelectionOp/LoopOp's region, there is no escape for it:
2007-
// SelectionOp/LooOp does not support yield values right now.
2087+
// the SelectionOp/LoopOp's region, there is no escape for it.
20082088
for (auto *block : constructBlocks) {
20092089
for (Operation &op : *block)
20102090
if (!op.use_empty())
2011-
return op.emitOpError(
2012-
"failed control flow structurization: it has uses outside of the "
2013-
"enclosing selection/loop construct");
2091+
return op.emitOpError("failed control flow structurization: value has "
2092+
"uses outside of the "
2093+
"enclosing selection/loop construct");
2094+
for (BlockArgument &arg : block->getArguments())
2095+
if (!arg.use_empty())
2096+
return emitError(arg.getLoc(), "failed control flow structurization: "
2097+
"block argument has uses outside of the "
2098+
"enclosing selection/loop construct");
20142099
}
20152100

20162101
// Then erase all old blocks.
@@ -2236,7 +2321,7 @@ LogicalResult spirv::Deserializer::structurizeControlFlow() {
22362321

22372322
auto *mergeBlock = mergeInfo.mergeBlock;
22382323
assert(mergeBlock && "merge block cannot be nullptr");
2239-
if (!mergeBlock->args_empty())
2324+
if (mergeInfo.continueBlock && !mergeBlock->args_empty())
22402325
return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
22412326
LLVM_DEBUG({
22422327
logger.startLine() << "[cf] merge block " << mergeBlock << ":\n";

mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,14 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
447447
auto mergeID = getBlockID(mergeBlock);
448448
auto loc = selectionOp.getLoc();
449449

450+
// Before we do anything wire yielded values with the result of the selection
451+
// operation. The selection op is being flatten so we do not have to worry
452+
// about values being defined inside a region and used outside it anymore.
453+
auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back());
454+
assert(selectionOp.getNumResults() == mergeOp.getNumOperands());
455+
for (unsigned i = 0, e = selectionOp.getNumResults(); i != e; ++i)
456+
selectionOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i));
457+
450458
// This SelectionOp is in some MLIR block with preceding and following ops. In
451459
// the binary format, it should reside in separate SPIR-V blocks from its
452460
// preceding and following ops. So we need to emit unconditional branches to
@@ -483,6 +491,12 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
483491
// instruction to start a new SPIR-V block for ops following this SelectionOp.
484492
// The block should use the <id> for the merge block.
485493
encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
494+
495+
// We do not process the mergeBlock but we still need to generate phi
496+
// functions from its block arguments.
497+
if (failed(emitPhiForBlockArguments(mergeBlock)))
498+
return failure();
499+
486500
LLVM_DEBUG(llvm::dbgs() << "done merge ");
487501
LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
488502
LLVM_DEBUG(llvm::dbgs() << "\n");

mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -765,6 +765,44 @@ func.func @missing_entry_block() -> () {
765765

766766
// -----
767767

768+
func.func @selection_yield(%cond: i1) -> () {
769+
%zero = spirv.Constant 0: i32
770+
%var1 = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
771+
%var2 = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
772+
773+
// CHECK: {{%.*}}:2 = spirv.mlir.selection -> i32, i32 {
774+
%yield:2 = spirv.mlir.selection -> i32, i32 {
775+
// CHECK-NEXT: spirv.BranchConditional {{%.*}}, ^bb1, ^bb2
776+
spirv.BranchConditional %cond, ^then, ^else
777+
778+
// CHECK: ^bb1
779+
^then:
780+
%one = spirv.Constant 1: i32
781+
%three = spirv.Constant 3: i32
782+
// CHECK: spirv.Branch ^bb3({{%.*}}, {{%.*}} : i32, i32)
783+
spirv.Branch ^merge(%one, %three : i32, i32)
784+
785+
// CHECK: ^bb2
786+
^else:
787+
%two = spirv.Constant 2: i32
788+
%four = spirv.Constant 4 : i32
789+
// CHECK: spirv.Branch ^bb3({{%.*}}, {{%.*}} : i32, i32)
790+
spirv.Branch ^merge(%two, %four : i32, i32)
791+
792+
// CHECK: ^bb3({{%.*}}: i32, {{%.*}}: i32)
793+
^merge(%merged_1_2: i32, %merged_3_4: i32):
794+
// CHECK-NEXT: spirv.mlir.merge {{%.*}}, {{%.*}} : i32, i32
795+
spirv.mlir.merge %merged_1_2, %merged_3_4 : i32, i32
796+
}
797+
798+
spirv.Store "Function" %var1, %yield#0 : i32
799+
spirv.Store "Function" %var2, %yield#1 : i32
800+
801+
spirv.Return
802+
}
803+
804+
// -----
805+
768806
//===----------------------------------------------------------------------===//
769807
// spirv.Unreachable
770808
//===----------------------------------------------------------------------===//

mlir/test/Target/SPIRV/selection.mlir

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,55 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
150150
spirv.Return
151151
}
152152
}
153+
154+
// -----
155+
156+
// Selection yielding values
157+
158+
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
159+
// CHECK-LABEL: @selection_yield
160+
spirv.func @selection_yield(%cond: i1) -> () "None" {
161+
// CHECK-NEXT: spirv.Constant 0
162+
// CHECK-NEXT: spirv.Variable
163+
%zero = spirv.Constant 0 : i32
164+
%var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
165+
166+
// CHECK: spirv.Branch ^[[BB:.+]]
167+
// CHECK-NEXT: ^[[BB]]:
168+
// CHECK-NEXT: spirv.mlir.selection -> i32
169+
%yield = spirv.mlir.selection -> i32 {
170+
// CHECK-NEXT: spirv.BranchConditional %{{.*}} [5, 10], ^[[THEN:.+]], ^[[ELSE:.+]]
171+
spirv.BranchConditional %cond [5, 10], ^then, ^else
172+
173+
// CHECK-NEXT: ^[[THEN]]:
174+
^then:
175+
// CHECK-NEXT: spirv.Constant 1
176+
%one = spirv.Constant 1: i32
177+
178+
// CHECK-NEXT: spirv.Branch ^[[MERGE:.+]]({{%.*}} : i32)
179+
spirv.Branch ^merge(%one : i32)
180+
181+
// CHECK-NEXT: ^[[ELSE]]:
182+
^else:
183+
// CHECK-NEXT: spirv.Constant 2
184+
%two = spirv.Constant 2: i32
185+
// CHECK-NEXT: spirv.Branch ^[[MERGE]]({{%.*}} : i32)
186+
spirv.Branch ^merge(%two : i32)
187+
188+
// CHECK-NEXT: ^[[MERGE]]({{%.*}}: i32):
189+
^merge(%merged: i32):
190+
// CHECK-NEXT: spirv.mlir.merge {{%.*}} : i32
191+
spirv.mlir.merge %merged : i32
192+
}
193+
194+
spirv.Store "Function" %var, %yield : i32
195+
196+
spirv.Return
197+
}
198+
199+
spirv.func @main() -> () "None" {
200+
spirv.Return
201+
}
202+
spirv.EntryPoint "GLCompute" @main
203+
spirv.ExecutionMode @main "LocalSize", 1, 1, 1
204+
}

0 commit comments

Comments
 (0)