@@ -6550,29 +6550,33 @@ void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
6550
6550
}
6551
6551
6552
6552
void MaskOp::ensureTerminator (Region ®ion, Builder &builder, Location loc) {
6553
- OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
6554
- MaskOp>::ensureTerminator (region, builder, loc);
6555
- // Keep the default yield terminator if the number of masked operations is not
6556
- // the expected. This case will trigger a verification failure.
6553
+ // 1. For an empty `vector.mask`, create a default terminator.
6554
+ if (region.empty () || region.front ().empty ()) {
6555
+ OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
6556
+ MaskOp>::ensureTerminator (region, builder, loc);
6557
+ return ;
6558
+ }
6559
+
6560
+ // 2. For a non-empty `vector.mask` with an explicit terminator, do nothing.
6557
6561
Block &block = region.front ();
6558
- if (block.getOperations (). size () != 2 )
6562
+ if (isa<vector::YieldOp>( block.back ()) )
6559
6563
return ;
6560
6564
6561
- // Replace default yield terminator with a new one that returns the results
6562
- // from the masked operation.
6563
- OpBuilder opBuilder (builder.getContext ());
6564
- Operation *maskedOp = &block.front ();
6565
- Operation *oldYieldOp = &block.back ();
6566
- assert (isa<vector::YieldOp>(oldYieldOp) && " Expected vector::YieldOp" );
6565
+ // 3. For a non-empty `vector.mask` without an explicit terminator:
6567
6566
6568
- // Empty vector.mask op.
6569
- if (maskedOp == oldYieldOp)
6567
+ // Create default terminator if the number of masked operations is not
6568
+ // one. This case will trigger a verification failure.
6569
+ if (block.getOperations ().size () != 1 ) {
6570
+ OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
6571
+ MaskOp>::ensureTerminator (region, builder, loc);
6570
6572
return ;
6573
+ }
6571
6574
6572
- opBuilder.setInsertionPoint (oldYieldOp);
6575
+ // Create a terminator that yields the results from the masked operation.
6576
+ OpBuilder opBuilder (builder.getContext ());
6577
+ Operation *maskedOp = &block.front ();
6578
+ opBuilder.setInsertionPointToEnd (&block);
6573
6579
opBuilder.create <vector::YieldOp>(loc, maskedOp->getResults ());
6574
- oldYieldOp->dropAllReferences ();
6575
- oldYieldOp->erase ();
6576
6580
}
6577
6581
6578
6582
LogicalResult MaskOp::verify () {
@@ -6607,6 +6611,10 @@ LogicalResult MaskOp::verify() {
6607
6611
return emitOpError (" expects number of results to match maskable operation "
6608
6612
" number of results" );
6609
6613
6614
+ if (!llvm::equal (maskableOp->getResults (), terminator.getOperands ()))
6615
+ return emitOpError (" expects all the results from the MaskableOpInterface "
6616
+ " to match all the values returned by the terminator" );
6617
+
6610
6618
if (!llvm::equal (maskableOp->getResultTypes (), getResultTypes ()))
6611
6619
return emitOpError (
6612
6620
" expects result type to match maskable operation result type" );
0 commit comments