Skip to content

Commit 1878259

Browse files
authored
[mlir][spirv] Update verifier for spirv.mlir.merge (#133427)
- Moves the verification logic to the `verifyRegions` method of the parent operation. - Fixes a crash during verification when the last block lacks a terminator. Fixes #132850.
1 parent 5b3e152 commit 1878259

File tree

3 files changed

+34
-29
lines changed

3 files changed

+34
-29
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,8 @@ def SPIRV_LoopOp : SPIRV_Op<"mlir.loop", [InFunctionScope]> {
351351

352352
// -----
353353

354-
def SPIRV_MergeOp : SPIRV_Op<"mlir.merge", [Pure, Terminator]> {
354+
def SPIRV_MergeOp : SPIRV_Op<"mlir.merge", [
355+
Pure, Terminator, ParentOneOf<["SelectionOp", "LoopOp"]>]> {
355356
let summary = "A special terminator for merging a structured selection/loop.";
356357

357358
let description = [{
@@ -371,6 +372,8 @@ def SPIRV_MergeOp : SPIRV_Op<"mlir.merge", [Pure, Terminator]> {
371372
let hasOpcode = 0;
372373

373374
let autogenSerialization = 0;
375+
376+
let hasVerifier = 0;
374377
}
375378

376379
// -----

mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,13 @@ static bool isMergeBlock(Block &block) {
259259
isa<spirv::MergeOp>(block.front());
260260
}
261261

262+
/// Returns true if a `spirv.mlir.merge` op outside the merge block.
263+
static bool hasOtherMerge(Region &region) {
264+
return !region.empty() && llvm::any_of(region.getOps(), [&](Operation &op) {
265+
return isa<spirv::MergeOp>(op) && op.getBlock() != &region.back();
266+
});
267+
}
268+
262269
LogicalResult LoopOp::verifyRegions() {
263270
auto *op = getOperation();
264271

@@ -298,6 +305,9 @@ LogicalResult LoopOp::verifyRegions() {
298305
if (!isMergeBlock(merge))
299306
return emitOpError("last block must be the merge block with only one "
300307
"'spirv.mlir.merge' op");
308+
if (hasOtherMerge(region))
309+
return emitOpError(
310+
"should not have 'spirv.mlir.merge' op outside the merge block");
301311

302312
if (std::next(region.begin()) == region.end())
303313
return emitOpError(
@@ -377,24 +387,6 @@ void LoopOp::addEntryAndMergeBlock(OpBuilder &builder) {
377387
builder.create<spirv::MergeOp>(getLoc());
378388
}
379389

380-
//===----------------------------------------------------------------------===//
381-
// spirv.mlir.merge
382-
//===----------------------------------------------------------------------===//
383-
384-
LogicalResult MergeOp::verify() {
385-
auto *parentOp = (*this)->getParentOp();
386-
if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp))
387-
return emitOpError(
388-
"expected parent op to be 'spirv.mlir.selection' or 'spirv.mlir.loop'");
389-
390-
// TODO: This check should be done in `verifyRegions` of parent op.
391-
Block &parentLastBlock = (*this)->getParentRegion()->back();
392-
if (getOperation() != parentLastBlock.getTerminator())
393-
return emitOpError("can only be used in the last block of "
394-
"'spirv.mlir.selection' or 'spirv.mlir.loop'");
395-
return success();
396-
}
397-
398390
//===----------------------------------------------------------------------===//
399391
// spirv.Return
400392
//===----------------------------------------------------------------------===//
@@ -507,6 +499,9 @@ LogicalResult SelectionOp::verifyRegions() {
507499
if (!isMergeBlock(region.back()))
508500
return emitOpError("last block must be the merge block with only one "
509501
"'spirv.mlir.merge' op");
502+
if (hasOtherMerge(region))
503+
return emitOpError(
504+
"should not have 'spirv.mlir.merge' op outside the merge block");
510505

511506
if (std::next(region.begin()) == region.end())
512507
return emitOpError("must have a selection header block");

mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -431,42 +431,49 @@ func.func @only_entry_and_continue_branch_to_header() -> () {
431431
//===----------------------------------------------------------------------===//
432432

433433
func.func @merge() -> () {
434-
// expected-error @+1 {{expected parent op to be 'spirv.mlir.selection' or 'spirv.mlir.loop'}}
434+
// expected-error @+1 {{expects parent op to be one of 'spirv.mlir.selection, spirv.mlir.loop'}}
435435
spirv.mlir.merge
436436
}
437437

438438
// -----
439439

440440
func.func @only_allowed_in_last_block(%cond : i1) -> () {
441-
%zero = spirv.Constant 0: i32
442-
%one = spirv.Constant 1: i32
443-
%var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
444-
441+
// expected-error @+1 {{'spirv.mlir.selection' op should not have 'spirv.mlir.merge' op outside the merge block}}
445442
spirv.mlir.selection {
446443
spirv.BranchConditional %cond, ^then, ^merge
447-
448444
^then:
449-
spirv.Store "Function" %var, %one : i32
450-
// expected-error @+1 {{can only be used in the last block of 'spirv.mlir.selection' or 'spirv.mlir.loop'}}
451445
spirv.mlir.merge
452-
453446
^merge:
454447
spirv.mlir.merge
455448
}
449+
spirv.Return
450+
}
456451

452+
// -----
453+
454+
// Ensure this case not crash
455+
456+
func.func @last_block_no_terminator(%cond : i1) -> () {
457+
// expected-error @+1 {{empty block: expect at least a terminator}}
458+
spirv.mlir.selection {
459+
spirv.BranchConditional %cond, ^then, ^merge
460+
^then:
461+
spirv.mlir.merge
462+
^merge:
463+
}
457464
spirv.Return
458465
}
459466

460467
// -----
461468

462469
func.func @only_allowed_in_last_block() -> () {
463470
%true = spirv.Constant true
471+
// expected-error @+1 {{'spirv.mlir.loop' op should not have 'spirv.mlir.merge' op outside the merge block}}
464472
spirv.mlir.loop {
465473
spirv.Branch ^header
466474
^header:
467475
spirv.BranchConditional %true, ^body, ^merge
468476
^body:
469-
// expected-error @+1 {{can only be used in the last block of 'spirv.mlir.selection' or 'spirv.mlir.loop'}}
470477
spirv.mlir.merge
471478
^continue:
472479
spirv.Branch ^header

0 commit comments

Comments
 (0)