Skip to content

Commit 2a90631

Browse files
authored
[mlir][spirv] Allow yielding values from selection regions (#133702)
There are cases in SPIR-V shaders where values need to be yielded from the selection region to make valid MLIR. For example (part of the SPIR-V shader decompiled to GLSL): ``` bool _115 if (_107) { // ... float _200 = fma(...); // ... _115 = _200 < _174; } else { _115 = _107; } bool _123; if (_115) { // ... float _213 = fma(...); // ... _123 = _213 < _174; } else { _123 = _115; } ```` This patch extends `mlir.selection` so it can return values. `mlir.merge` is used as a "yield" operation. This allows to maintain a compatibility with code that does not yield any values, as well as, to maintain an assumption that `mlir.merge` is the only operation in the merge block of the selection region.
1 parent 581f8bc commit 2a90631

File tree

7 files changed

+270
-11
lines changed

7 files changed

+270
-11
lines changed

mlir/docs/Dialects/SPIR-V.md

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,7 @@ MLIR system.
528528
We introduce a `spirv.mlir.selection` and `spirv.mlir.loop` op for structured selections and
529529
loops, respectively. The merge targets are the next ops following them. Inside
530530
their regions, a special terminator, `spirv.mlir.merge` is introduced for branching to
531-
the merge target.
531+
the merge target and yielding values.
532532

533533
### Selection
534534

@@ -603,7 +603,43 @@ func.func @selection(%cond: i1) -> () {
603603
604604
// ...
605605
}
606+
```
607+
608+
The selection can return values by yielding them with `spirv.mlir.merge`. This
609+
mechanism allows values defined within the selection region to be used outside of it.
610+
Without this, values that were sunk into the selection region, but used outside, would
611+
not be able to escape it.
612+
613+
For example
614+
615+
```mlir
616+
func.func @selection(%cond: i1) -> () {
617+
%zero = spirv.Constant 0: i32
618+
%var1 = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
619+
%var2 = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
620+
621+
%yield:2 = spirv.mlir.selection -> i32, i32 {
622+
spirv.BranchConditional %cond, ^then, ^else
606623
624+
^then:
625+
%one = spirv.Constant 1: i32
626+
%three = spirv.Constant 3: i32
627+
spirv.Branch ^merge(%one, %three : i32, i32)
628+
629+
^else:
630+
%two = spirv.Constant 2: i32
631+
%four = spirv.Constant 4 : i32
632+
spirv.Branch ^merge(%two, %four : i32, i32)
633+
634+
^merge(%merged_1_2: i32, %merged_3_4: i32):
635+
spirv.mlir.merge %merged_1_2, %merged_3_4 : i32, i32
636+
}
637+
638+
spirv.Store "Function" %var1, %yield#0 : i32
639+
spirv.Store "Function" %var2, %yield#1 : i32
640+
641+
spirv.Return
642+
}
607643
```
608644

609645
### Loop

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ def SPIRV_LoopOp : SPIRV_Op<"mlir.loop", [InFunctionScope]> {
352352
// -----
353353

354354
def SPIRV_MergeOp : SPIRV_Op<"mlir.merge", [
355-
Pure, Terminator, ParentOneOf<["SelectionOp", "LoopOp"]>]> {
355+
Pure, Terminator, ParentOneOf<["SelectionOp", "LoopOp"]>, ReturnLike]> {
356356
let summary = "A special terminator for merging a structured selection/loop.";
357357

358358
let description = [{
@@ -361,13 +361,23 @@ def SPIRV_MergeOp : SPIRV_Op<"mlir.merge", [
361361
merge point, which is the next op following the `spirv.mlir.selection` or
362362
`spirv.mlir.loop` op. This op does not have a corresponding instruction in the
363363
SPIR-V binary format; it's solely for structural purpose.
364+
365+
The instruction is also used to yield values from inside the selection/loop region
366+
to the outside, as values that were sunk into the region cannot otherwise escape it.
364367
}];
365368

366-
let arguments = (ins);
369+
let arguments = (ins Variadic<AnyType>:$operands);
367370

368371
let results = (outs);
369372

370-
let assemblyFormat = "attr-dict";
373+
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
374+
375+
let builders = [
376+
OpBuilder<(ins),
377+
[{
378+
build($_builder, $_state, ValueRange());
379+
}]>
380+
];
371381

372382
let hasOpcode = 0;
373383

@@ -465,13 +475,17 @@ def SPIRV_SelectionOp : SPIRV_Op<"mlir.selection", [InFunctionScope]> {
465475
header block, and one selection merge. The selection header block should be
466476
the first block. The selection merge block should be the last block.
467477
The merge block should only contain a `spirv.mlir.merge` op.
478+
479+
Values defined inside the selection regions cannot be directly used
480+
outside of them; however, the selection region can yield values. These values are
481+
yielded using a `spirv.mlir.merge` op and returned as a result of the selection op.
468482
}];
469483

470484
let arguments = (ins
471485
SPIRV_SelectionControlAttr:$selection_control
472486
);
473487

474-
let results = (outs);
488+
let results = (outs Variadic<AnyType>:$results);
475489

476490
let regions = (region AnyRegion:$body);
477491

@@ -494,6 +508,13 @@ def SPIRV_SelectionOp : SPIRV_Op<"mlir.selection", [InFunctionScope]> {
494508
OpBuilder &builder);
495509
}];
496510

511+
let builders = [
512+
OpBuilder<(ins "spirv::SelectionControl":$selectionControl),
513+
[{
514+
build($_builder, $_state, TypeRange(), selectionControl);
515+
}]>
516+
];
517+
497518
let hasOpcode = 0;
498519

499520
let autogenSerialization = 0;

mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,13 +452,22 @@ ParseResult SelectionOp::parse(OpAsmParser &parser, OperationState &result) {
452452
if (parseControlAttribute<spirv::SelectionControlAttr,
453453
spirv::SelectionControl>(parser, result))
454454
return failure();
455+
456+
if (succeeded(parser.parseOptionalArrow()))
457+
if (parser.parseTypeList(result.types))
458+
return failure();
459+
455460
return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
456461
}
457462

458463
void SelectionOp::print(OpAsmPrinter &printer) {
459464
auto control = getSelectionControl();
460465
if (control != spirv::SelectionControl::None)
461466
printer << " control(" << spirv::stringifySelectionControl(control) << ")";
467+
if (getNumResults() > 0) {
468+
printer << " -> ";
469+
printer << getResultTypes();
470+
}
462471
printer << ' ';
463472
printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
464473
/*printBlockTerminators=*/true);

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 91 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1992,6 +1992,29 @@ LogicalResult ControlFlowStructurizer::structurize() {
19921992
ArrayRef<Value>(blockArgs));
19931993
}
19941994

1995+
// Values defined inside the selection region that need to be yielded outside
1996+
// the region.
1997+
SmallVector<Value> valuesToYield;
1998+
// Outside uses of values that were sunk into the selection region. Those uses
1999+
// will be replaced with values returned by the SelectionOp.
2000+
SmallVector<Value> outsideUses;
2001+
2002+
// Move block arguments of the original block (`mergeBlock`) into the merge
2003+
// block inside the selection (`body.back()`). Values produced by block
2004+
// arguments will be yielded by the selection region. We do not update uses or
2005+
// erase original block arguments yet. It will be done later in the code.
2006+
if (!isLoop) {
2007+
for (BlockArgument blockArg : mergeBlock->getArguments()) {
2008+
// Create new block arguments in the last block ("merge block") of the
2009+
// selection region. We create one argument for each argument in
2010+
// `mergeBlock`. This new value will need to be yielded, and the original
2011+
// value replaced, so add them to appropriate vectors.
2012+
body.back().addArgument(blockArg.getType(), blockArg.getLoc());
2013+
valuesToYield.push_back(body.back().getArguments().back());
2014+
outsideUses.push_back(blockArg);
2015+
}
2016+
}
2017+
19952018
// All the blocks cloned into the SelectionOp/LoopOp's region can now be
19962019
// cleaned up.
19972020
LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n");
@@ -2000,17 +2023,79 @@ LogicalResult ControlFlowStructurizer::structurize() {
20002023
for (auto *block : constructBlocks)
20012024
block->dropAllReferences();
20022025

2026+
// All internal uses should be removed from original blocks by now, so
2027+
// whatever is left is an outside use and will need to be yielded from
2028+
// the newly created selection region.
2029+
if (!isLoop) {
2030+
for (Block *block : constructBlocks) {
2031+
for (Operation &op : *block) {
2032+
if (!op.use_empty())
2033+
for (Value result : op.getResults()) {
2034+
valuesToYield.push_back(mapper.lookupOrNull(result));
2035+
outsideUses.push_back(result);
2036+
}
2037+
}
2038+
for (BlockArgument &arg : block->getArguments()) {
2039+
if (!arg.use_empty()) {
2040+
valuesToYield.push_back(mapper.lookupOrNull(arg));
2041+
outsideUses.push_back(arg);
2042+
}
2043+
}
2044+
}
2045+
}
2046+
2047+
assert(valuesToYield.size() == outsideUses.size());
2048+
2049+
// If we need to yield any values from the selection region we will take
2050+
// care of it here.
2051+
if (!isLoop && !valuesToYield.empty()) {
2052+
LLVM_DEBUG(logger.startLine()
2053+
<< "[cf] yielding values from the selection region\n");
2054+
2055+
// Update `mlir.merge` with values to be yield.
2056+
auto mergeOps = body.back().getOps<spirv::MergeOp>();
2057+
Operation *merge = llvm::getSingleElement(mergeOps);
2058+
assert(merge);
2059+
merge->setOperands(valuesToYield);
2060+
2061+
// MLIR does not allow changing the number of results of an operation, so
2062+
// we create a new SelectionOp with required list of results and move
2063+
// the region from the initial SelectionOp. The initial operation is then
2064+
// removed. Since we move the region to the new op all links between blocks
2065+
// and remapping we have previously done should be preserved.
2066+
builder.setInsertionPoint(&mergeBlock->front());
2067+
auto selectionOp = builder.create<spirv::SelectionOp>(
2068+
location, TypeRange(outsideUses),
2069+
static_cast<spirv::SelectionControl>(control));
2070+
selectionOp->getRegion(0).takeBody(body);
2071+
2072+
// Remove initial op and swap the pointer to the newly created one.
2073+
op->erase();
2074+
op = selectionOp;
2075+
2076+
// Update all outside uses to use results of the SelectionOp and remove
2077+
// block arguments from the original merge block.
2078+
for (unsigned i = 0, e = outsideUses.size(); i != e; ++i)
2079+
outsideUses[i].replaceAllUsesWith(selectionOp.getResult(i));
2080+
for (unsigned i = 0, e = mergeBlock->getNumArguments(); i != e; ++i)
2081+
mergeBlock->eraseArgument(i);
2082+
}
2083+
20032084
// Check that whether some op in the to-be-erased blocks still has uses. Those
20042085
// uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
20052086
// region. We cannot handle such cases given that once a value is sinked into
2006-
// the SelectionOp/LoopOp's region, there is no escape for it:
2007-
// SelectionOp/LooOp does not support yield values right now.
2087+
// the SelectionOp/LoopOp's region, there is no escape for it.
20082088
for (auto *block : constructBlocks) {
20092089
for (Operation &op : *block)
20102090
if (!op.use_empty())
2011-
return op.emitOpError(
2012-
"failed control flow structurization: it has uses outside of the "
2013-
"enclosing selection/loop construct");
2091+
return op.emitOpError("failed control flow structurization: value has "
2092+
"uses outside of the "
2093+
"enclosing selection/loop construct");
2094+
for (BlockArgument &arg : block->getArguments())
2095+
if (!arg.use_empty())
2096+
return emitError(arg.getLoc(), "failed control flow structurization: "
2097+
"block argument has uses outside of the "
2098+
"enclosing selection/loop construct");
20142099
}
20152100

20162101
// Then erase all old blocks.
@@ -2236,7 +2321,7 @@ LogicalResult spirv::Deserializer::structurizeControlFlow() {
22362321

22372322
auto *mergeBlock = mergeInfo.mergeBlock;
22382323
assert(mergeBlock && "merge block cannot be nullptr");
2239-
if (!mergeBlock->args_empty())
2324+
if (mergeInfo.continueBlock && !mergeBlock->args_empty())
22402325
return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
22412326
LLVM_DEBUG({
22422327
logger.startLine() << "[cf] merge block " << mergeBlock << ":\n";

mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,15 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
447447
auto mergeID = getBlockID(mergeBlock);
448448
auto loc = selectionOp.getLoc();
449449

450+
// Before we do anything replace results of the selection operation with
451+
// values yielded (with `mlir.merge`) from inside the region. The selection op
452+
// is being flattened so we do not have to worry about values being defined
453+
// inside a region and used outside it anymore.
454+
auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back());
455+
assert(selectionOp.getNumResults() == mergeOp.getNumOperands());
456+
for (unsigned i = 0, e = selectionOp.getNumResults(); i != e; ++i)
457+
selectionOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i));
458+
450459
// This SelectionOp is in some MLIR block with preceding and following ops. In
451460
// the binary format, it should reside in separate SPIR-V blocks from its
452461
// preceding and following ops. So we need to emit unconditional branches to
@@ -483,6 +492,12 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
483492
// instruction to start a new SPIR-V block for ops following this SelectionOp.
484493
// The block should use the <id> for the merge block.
485494
encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
495+
496+
// We do not process the mergeBlock but we still need to generate phi
497+
// functions from its block arguments.
498+
if (failed(emitPhiForBlockArguments(mergeBlock)))
499+
return failure();
500+
486501
LLVM_DEBUG(llvm::dbgs() << "done merge ");
487502
LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
488503
LLVM_DEBUG(llvm::dbgs() << "\n");

mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -765,6 +765,46 @@ func.func @missing_entry_block() -> () {
765765

766766
// -----
767767

768+
func.func @selection_yield(%cond: i1) -> () {
769+
%zero = spirv.Constant 0: i32
770+
%var1 = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
771+
%var2 = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
772+
773+
// CHECK: {{%.*}}:2 = spirv.mlir.selection -> i32, i32 {
774+
%yield:2 = spirv.mlir.selection -> i32, i32 {
775+
// CHECK-NEXT: spirv.BranchConditional {{%.*}}, ^bb1, ^bb2
776+
spirv.BranchConditional %cond, ^then, ^else
777+
778+
// CHECK: ^bb1
779+
^then:
780+
%one = spirv.Constant 1: i32
781+
%three = spirv.Constant 3: i32
782+
// CHECK: spirv.Branch ^bb3({{%.*}}, {{%.*}} : i32, i32)
783+
spirv.Branch ^merge(%one, %three : i32, i32)
784+
785+
// CHECK: ^bb2
786+
^else:
787+
%two = spirv.Constant 2: i32
788+
%four = spirv.Constant 4 : i32
789+
// CHECK: spirv.Branch ^bb3({{%.*}}, {{%.*}} : i32, i32)
790+
spirv.Branch ^merge(%two, %four : i32, i32)
791+
792+
// CHECK: ^bb3({{%.*}}: i32, {{%.*}}: i32)
793+
^merge(%merged_1_2: i32, %merged_3_4: i32):
794+
// CHECK-NEXT: spirv.mlir.merge {{%.*}}, {{%.*}} : i32, i32
795+
spirv.mlir.merge %merged_1_2, %merged_3_4 : i32, i32
796+
}
797+
798+
// CHECK: spirv.Store "Function" {{%.*}}, {{%.*}}#0 : i32
799+
spirv.Store "Function" %var1, %yield#0 : i32
800+
// CHECK: spirv.Store "Function" {{%.*}}, {{%.*}}#1 : i32
801+
spirv.Store "Function" %var2, %yield#1 : i32
802+
803+
spirv.Return
804+
}
805+
806+
// -----
807+
768808
//===----------------------------------------------------------------------===//
769809
// spirv.Unreachable
770810
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)