Skip to content

Commit e43db56

Browse files
committed
[mlir][spirv] Allow yielding values from selection regions
There are cases in SPIR-V shaders where values need to be yielded from the selection region to make valid MLIR. For example (part of the SPIR-V shader decompiled to GLSL): ``` bool _115 if (_107) { // ... float _200 = fma(...); // ... _115 = _200 < _174; } else { _115 = _107; } bool _123; if (_115) { // ... float _213 = fma(...); // ... _123 = _213 < _174; } else { _123 = _115; } ```` This patch extends `mlir.selection` so it can return values. `mlir.merge` is used as a "yield" operation. This allows to maintain a compatibility with code that does not yield any values, as well as, to maintain an assumption that `mlir.merge` is the only operation in the merge block of the selection region.
1 parent 0ec9498 commit e43db56

File tree

7 files changed

+219
-19
lines changed

7 files changed

+219
-19
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -352,22 +352,23 @@ def SPIRV_LoopOp : SPIRV_Op<"mlir.loop", [InFunctionScope]> {
352352
// -----
353353

354354
def SPIRV_MergeOp : SPIRV_Op<"mlir.merge", [
355-
Pure, Terminator, ParentOneOf<["SelectionOp", "LoopOp"]>]> {
355+
Pure, Terminator, ParentOneOf<["SelectionOp", "LoopOp"]>, ReturnLike]> {
356356
let summary = "A special terminator for merging a structured selection/loop.";
357357

358358
let description = [{
359359
We use `spirv.mlir.selection`/`spirv.mlir.loop` for modelling structured selection/loop.
360360
This op is a terminator used inside their regions to mean jumping to the
361361
merge point, which is the next op following the `spirv.mlir.selection` or
362362
`spirv.mlir.loop` op. This op does not have a corresponding instruction in the
363-
SPIR-V binary format; it's solely for structural purpose.
363+
SPIR-V binary format; it's solely for structural purpose. The instruction is also
364+
used to yield values outside the selection/loop region.
364365
}];
365366

366-
let arguments = (ins);
367+
let arguments = (ins Variadic<AnyType>:$operands);
367368

368369
let results = (outs);
369370

370-
let assemblyFormat = "attr-dict";
371+
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
371372

372373
let hasOpcode = 0;
373374

@@ -471,7 +472,7 @@ def SPIRV_SelectionOp : SPIRV_Op<"mlir.selection", [InFunctionScope]> {
471472
SPIRV_SelectionControlAttr:$selection_control
472473
);
473474

474-
let results = (outs);
475+
let results = (outs Variadic<AnyType>:$results);
475476

476477
let regions = (region AnyRegion:$body);
477478

mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -227,11 +227,11 @@ struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> {
227227

228228
// Create `spirv.selection` operation, selection header block and merge
229229
// block.
230-
auto selectionOp =
231-
rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
230+
auto selectionOp = rewriter.create<spirv::SelectionOp>(
231+
loc, TypeRange(), spirv::SelectionControl::None);
232232
auto *mergeBlock = rewriter.createBlock(&selectionOp.getBody(),
233233
selectionOp.getBody().end());
234-
rewriter.create<spirv::MergeOp>(loc);
234+
rewriter.create<spirv::MergeOp>(loc, ValueRange());
235235

236236
OpBuilder::InsertionGuard guard(rewriter);
237237
auto *selectionHeaderBlock =

mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,7 @@ void LoopOp::addEntryAndMergeBlock(OpBuilder &builder) {
383383
builder.createBlock(&getBody());
384384

385385
// Add a spirv.mlir.merge op into the merge block.
386-
builder.create<spirv::MergeOp>(getLoc());
386+
builder.create<spirv::MergeOp>(getLoc(), ValueRange());
387387
}
388388

389389
//===----------------------------------------------------------------------===//
@@ -452,13 +452,22 @@ ParseResult SelectionOp::parse(OpAsmParser &parser, OperationState &result) {
452452
if (parseControlAttribute<spirv::SelectionControlAttr,
453453
spirv::SelectionControl>(parser, result))
454454
return failure();
455+
456+
if (succeeded(parser.parseOptionalArrow()))
457+
if (parser.parseTypeList(result.types))
458+
return failure();
459+
455460
return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
456461
}
457462

458463
void SelectionOp::print(OpAsmPrinter &printer) {
459464
auto control = getSelectionControl();
460465
if (control != spirv::SelectionControl::None)
461466
printer << " control(" << spirv::stringifySelectionControl(control) << ")";
467+
if (getNumResults() > 0) {
468+
printer << " -> ";
469+
printer << getResultTypes();
470+
}
462471
printer << ' ';
463472
printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
464473
/*printBlockTerminators=*/true);
@@ -526,15 +535,15 @@ void SelectionOp::addMergeBlock(OpBuilder &builder) {
526535
builder.createBlock(&getBody());
527536

528537
// Add a spirv.mlir.merge op into the merge block.
529-
builder.create<spirv::MergeOp>(getLoc());
538+
builder.create<spirv::MergeOp>(getLoc(), ValueRange());
530539
}
531540

532541
SelectionOp
533542
SelectionOp::createIfThen(Location loc, Value condition,
534543
function_ref<void(OpBuilder &builder)> thenBody,
535544
OpBuilder &builder) {
536-
auto selectionOp =
537-
builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
545+
auto selectionOp = builder.create<spirv::SelectionOp>(
546+
loc, TypeRange(), spirv::SelectionControl::None);
538547

539548
selectionOp.addMergeBlock(builder);
540549
Block *mergeBlock = selectionOp.getMergeBlock();

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 93 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1843,7 +1843,8 @@ ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
18431843
OpBuilder builder(&mergeBlock->front());
18441844

18451845
auto control = static_cast<spirv::SelectionControl>(selectionControl);
1846-
auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
1846+
auto selectionOp =
1847+
builder.create<spirv::SelectionOp>(location, TypeRange(), control);
18471848
selectionOp.addMergeBlock(builder);
18481849

18491850
return selectionOp;
@@ -1992,6 +1993,29 @@ LogicalResult ControlFlowStructurizer::structurize() {
19921993
ArrayRef<Value>(blockArgs));
19931994
}
19941995

1996+
// Values defined inside the selection region that need to be yield outside
1997+
// the region.
1998+
SmallVector<Value> valuesToYield;
1999+
// Outside uses of values sank into the selection region. Those uses will be
2000+
// replaced with values returned by the SelectionOp.
2001+
SmallVector<Value> outsideUses;
2002+
2003+
// Move block arguments of the original block (`mergeBlock`) into the merge
2004+
// block inside the selection (`body.back()`). Values produced by block
2005+
// arguments will be yielded by the selection region. We do not update uses or
2006+
// erase original block arguments yet. It will be done later in the code.
2007+
if (!isLoop) {
2008+
for (BlockArgument blockArg : mergeBlock->getArguments()) {
2009+
// Create new block arguments in the last block ("merge block") of the
2010+
// selection region. We create one argument for each argument in
2011+
// `mergeBlock`. This new value will need to be yielded, and the original
2012+
// value replaced, so add them to appropriate vectors.
2013+
body.back().addArgument(blockArg.getType(), blockArg.getLoc());
2014+
valuesToYield.push_back(body.back().getArguments().back());
2015+
outsideUses.push_back(blockArg);
2016+
}
2017+
}
2018+
19952019
// All the blocks cloned into the SelectionOp/LoopOp's region can now be
19962020
// cleaned up.
19972021
LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n");
@@ -2000,17 +2024,79 @@ LogicalResult ControlFlowStructurizer::structurize() {
20002024
for (auto *block : constructBlocks)
20012025
block->dropAllReferences();
20022026

2027+
// All internal uses should be removed from original blocks by now, so
2028+
// whatever is left is an outside use and will need to be yielded from
2029+
// the newly created selection region.
2030+
if (!isLoop) {
2031+
for (Block *block : constructBlocks) {
2032+
for (Operation &op : *block) {
2033+
if (!op.use_empty())
2034+
for (Value result : op.getResults()) {
2035+
valuesToYield.push_back(mapper.lookupOrNull(result));
2036+
outsideUses.push_back(result);
2037+
}
2038+
}
2039+
for (BlockArgument &arg : block->getArguments()) {
2040+
if (!arg.use_empty()) {
2041+
valuesToYield.push_back(mapper.lookupOrNull(arg));
2042+
outsideUses.push_back(arg);
2043+
}
2044+
}
2045+
}
2046+
}
2047+
2048+
assert(valuesToYield.size() == outsideUses.size());
2049+
2050+
// If we need to yield any values from the selection region we will take
2051+
// care of it here.
2052+
if (!isLoop && !valuesToYield.empty()) {
2053+
LLVM_DEBUG(logger.startLine()
2054+
<< "[cf] yielding values from the selection region\n");
2055+
2056+
// Update `mlir.merge` with values to be yield.
2057+
auto mergeOps = body.back().getOps<spirv::MergeOp>();
2058+
assert(std::next(mergeOps.begin()) == mergeOps.end());
2059+
Operation *merge = *mergeOps.begin();
2060+
merge->setOperands(valuesToYield);
2061+
2062+
// MLIR does not allow changing the number of results of an operation, so
2063+
// we create a new SelectionOp with required list of results and move
2064+
// the region from the initial SelectionOp. The initial operation is then
2065+
// removed. Since we move the region to the new op all links between blocks
2066+
// and remapping we have previously done should be preserved.
2067+
builder.setInsertionPoint(&mergeBlock->front());
2068+
auto selectionOp = builder.create<spirv::SelectionOp>(
2069+
location, TypeRange(outsideUses),
2070+
static_cast<spirv::SelectionControl>(control));
2071+
selectionOp->getRegion(0).takeBody(body);
2072+
2073+
// Remove initial op and swap the pointer to the newly created one.
2074+
op->erase();
2075+
op = selectionOp;
2076+
2077+
// Update all outside uses to use results of the SelectionOp and remove
2078+
// block arguments from the original merge block.
2079+
for (size_t i = 0; i < outsideUses.size(); i++)
2080+
outsideUses[i].replaceAllUsesWith(selectionOp.getResult(i));
2081+
for (size_t i = 0; i < mergeBlock->getNumArguments(); ++i)
2082+
mergeBlock->eraseArgument(i);
2083+
}
2084+
20032085
// Check that whether some op in the to-be-erased blocks still has uses. Those
20042086
// uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
20052087
// region. We cannot handle such cases given that once a value is sinked into
2006-
// the SelectionOp/LoopOp's region, there is no escape for it:
2007-
// SelectionOp/LooOp does not support yield values right now.
2088+
// the SelectionOp/LoopOp's region, there is no escape for it.
20082089
for (auto *block : constructBlocks) {
20092090
for (Operation &op : *block)
20102091
if (!op.use_empty())
2011-
return op.emitOpError(
2012-
"failed control flow structurization: it has uses outside of the "
2013-
"enclosing selection/loop construct");
2092+
return op.emitOpError("failed control flow structurization: value has "
2093+
"uses outside of the "
2094+
"enclosing selection/loop construct");
2095+
for (BlockArgument &arg : block->getArguments())
2096+
if (!arg.use_empty())
2097+
return emitError(arg.getLoc(), "failed control flow structurization: "
2098+
"block argument has uses outside of the "
2099+
"enclosing selection/loop construct");
20142100
}
20152101

20162102
// Then erase all old blocks.
@@ -2236,7 +2322,7 @@ LogicalResult spirv::Deserializer::structurizeControlFlow() {
22362322

22372323
auto *mergeBlock = mergeInfo.mergeBlock;
22382324
assert(mergeBlock && "merge block cannot be nullptr");
2239-
if (!mergeBlock->args_empty())
2325+
if (mergeInfo.continueBlock != nullptr && !mergeBlock->args_empty())
22402326
return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
22412327
LLVM_DEBUG({
22422328
logger.startLine() << "[cf] merge block " << mergeBlock << ":\n";

mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,14 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
447447
auto mergeID = getBlockID(mergeBlock);
448448
auto loc = selectionOp.getLoc();
449449

450+
// Before we do anything wire yielded values with the result of the selection
451+
// operation. The selection op is being flatten so we do not have to worry
452+
// about values being defined inside a region and used outside it anymore.
453+
auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back());
454+
assert(selectionOp.getNumResults() == mergeOp.getNumOperands());
455+
for (unsigned i = 0; i < selectionOp.getNumResults(); ++i)
456+
selectionOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i));
457+
450458
// This SelectionOp is in some MLIR block with preceding and following ops. In
451459
// the binary format, it should reside in separate SPIR-V blocks from its
452460
// preceding and following ops. So we need to emit unconditional branches to
@@ -483,6 +491,12 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
483491
// instruction to start a new SPIR-V block for ops following this SelectionOp.
484492
// The block should use the <id> for the merge block.
485493
encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
494+
495+
// We do not process the mergeBlock but we still need to generate phi
496+
// functions from its block arguments.
497+
if (failed(emitPhiForBlockArguments(mergeBlock)))
498+
return failure();
499+
486500
LLVM_DEBUG(llvm::dbgs() << "done merge ");
487501
LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
488502
LLVM_DEBUG(llvm::dbgs() << "\n");

mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -765,6 +765,44 @@ func.func @missing_entry_block() -> () {
765765

766766
// -----
767767

768+
func.func @selection_yield(%cond: i1) -> () {
769+
%zero = spirv.Constant 0: i32
770+
%var1 = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
771+
%var2 = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
772+
773+
// CHECK: {{%.*}}:2 = spirv.mlir.selection -> i32, i32 {
774+
%yield:2 = spirv.mlir.selection -> i32, i32 {
775+
// CHECK-NEXT: spirv.BranchConditional {{%.*}}, ^bb1, ^bb2
776+
spirv.BranchConditional %cond, ^then, ^else
777+
778+
// CHECK: ^bb1
779+
^then:
780+
%one = spirv.Constant 1: i32
781+
%three = spirv.Constant 3: i32
782+
// CHECK: spirv.Branch ^bb3({{%.*}}, {{%.*}} : i32, i32)
783+
spirv.Branch ^merge(%one, %three : i32, i32)
784+
785+
// CHECK: ^bb2
786+
^else:
787+
%two = spirv.Constant 2: i32
788+
%four = spirv.Constant 4 : i32
789+
// CHECK: spirv.Branch ^bb3({{%.*}}, {{%.*}} : i32, i32)
790+
spirv.Branch ^merge(%two, %four : i32, i32)
791+
792+
// CHECK: ^bb3({{%.*}}: i32, {{%.*}}: i32)
793+
^merge(%merged_1_2: i32, %merged_3_4: i32):
794+
// CHECK-NEXT: spirv.mlir.merge {{%.*}}, {{%.*}} : i32, i32
795+
spirv.mlir.merge %merged_1_2, %merged_3_4 : i32, i32
796+
}
797+
798+
spirv.Store "Function" %var1, %yield#0 : i32
799+
spirv.Store "Function" %var2, %yield#1 : i32
800+
801+
spirv.Return
802+
}
803+
804+
// -----
805+
768806
//===----------------------------------------------------------------------===//
769807
// spirv.Unreachable
770808
//===----------------------------------------------------------------------===//

mlir/test/Target/SPIRV/selection.mlir

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,55 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
150150
spirv.Return
151151
}
152152
}
153+
154+
// -----
155+
156+
// Selection yielding values
157+
158+
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
159+
// CHECK-LABEL: @selection_yield
160+
spirv.func @selection_yield(%cond: i1) -> () "None" {
161+
// CHECK-NEXT: spirv.Constant 0
162+
// CHECK-NEXT: spirv.Variable
163+
%zero = spirv.Constant 0 : i32
164+
%var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
165+
166+
// CHECK: spirv.Branch ^[[BB:.+]]
167+
// CHECK-NEXT: ^[[BB]]:
168+
// CHECK-NEXT: spirv.mlir.selection -> i32
169+
%yield = spirv.mlir.selection -> i32 {
170+
// CHECK-NEXT: spirv.BranchConditional %{{.*}} [5, 10], ^[[THEN:.+]], ^[[ELSE:.+]]
171+
spirv.BranchConditional %cond [5, 10], ^then, ^else
172+
173+
// CHECK-NEXT: ^[[THEN]]:
174+
^then:
175+
// CHECK-NEXT: spirv.Constant 1
176+
%one = spirv.Constant 1: i32
177+
178+
// CHECK-NEXT: spirv.Branch ^[[MERGE:.+]]({{%.*}} : i32)
179+
spirv.Branch ^merge(%one : i32)
180+
181+
// CHECK-NEXT: ^[[ELSE]]:
182+
^else:
183+
// CHECK-NEXT: spirv.Constant 2
184+
%two = spirv.Constant 2: i32
185+
// CHECK-NEXT: spirv.Branch ^[[MERGE]]({{%.*}} : i32)
186+
spirv.Branch ^merge(%two : i32)
187+
188+
// CHECK-NEXT: ^[[MERGE]]({{%.*}}: i32):
189+
^merge(%merged: i32):
190+
// CHECK-NEXT: spirv.mlir.merge {{%.*}} : i32
191+
spirv.mlir.merge %merged : i32
192+
}
193+
194+
spirv.Store "Function" %var, %yield : i32
195+
196+
spirv.Return
197+
}
198+
199+
spirv.func @main() -> () "None" {
200+
spirv.Return
201+
}
202+
spirv.EntryPoint "GLCompute" @main
203+
spirv.ExecutionMode @main "LocalSize", 1, 1, 1
204+
}

0 commit comments

Comments
 (0)