Skip to content

[mlir][Vector] Improve vector.mask verifier #139823

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2482,8 +2482,13 @@ def Vector_MaskOp : Vector_Op<"mask", [
masked. Values used within the region are captured from above. Only one
*maskable* operation can be masked with a `vector.mask` operation at a time.
An operation is *maskable* if it implements the `MaskableOpInterface`. The
terminator yields all results of the maskable operation to the result of
this operation.
terminator yields all results from the maskable operation to the result of
this operation. No other values are allowed to be yielded.

An empty `vector.mask` operation is currently legal to enable optimizations
across the `vector.mask` region. However, this might change in the future
once vector transformations gain better support for `vector.mask`.
TODO: Consider making empty `vector.mask` illegal.

The vector mask argument holds a bit for each vector lane and determines
which vector lanes should execute the maskable operation and which ones
Expand Down
40 changes: 24 additions & 16 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6550,29 +6550,33 @@ void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
}

void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
MaskOp>::ensureTerminator(region, builder, loc);
// Keep the default yield terminator if the number of masked operations is not
// the expected. This case will trigger a verification failure.
// 1. For an empty `vector.mask`, create a default terminator.
if (region.empty() || region.front().empty()) {
OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
MaskOp>::ensureTerminator(region, builder, loc);
return;
}

// 2. For a non-empty `vector.mask` with an explicit terminator, do nothing.
Block &block = region.front();
if (block.getOperations().size() != 2)
if (isa<vector::YieldOp>(block.back()))
return;

// Replace default yield terminator with a new one that returns the results
// from the masked operation.
OpBuilder opBuilder(builder.getContext());
Operation *maskedOp = &block.front();
Operation *oldYieldOp = &block.back();
assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp");
// 3. For a non-empty `vector.mask` without an explicit terminator:

// Empty vector.mask op.
if (maskedOp == oldYieldOp)
// Create default terminator if the number of masked operations is not
// one. This case will trigger a verification failure.
if (block.getOperations().size() != 1) {
OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
MaskOp>::ensureTerminator(region, builder, loc);
return;
}

opBuilder.setInsertionPoint(oldYieldOp);
// Create a terminator that yields the results from the masked operation.
OpBuilder opBuilder(builder.getContext());
Operation *maskedOp = &block.front();
opBuilder.setInsertionPointToEnd(&block);
opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());
oldYieldOp->dropAllReferences();
oldYieldOp->erase();
}

LogicalResult MaskOp::verify() {
Expand Down Expand Up @@ -6607,6 +6611,10 @@ LogicalResult MaskOp::verify() {
return emitOpError("expects number of results to match maskable operation "
"number of results");

if (!llvm::equal(maskableOp->getResults(), terminator.getOperands()))
return emitOpError("expects all the results from the MaskableOpInterface "
"to match all the values returned by the terminator");
Comment on lines +6615 to +6616
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After seeing the comment from @newling , I've realised that we are special-casing an empty vector.mask. This example will not trigger the error:

%0 = vector.mask %mask { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>

That's a bit of inconsistency. Perhaps leave a TODO to address this at some point?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, my intent with this PR is to make clearer that an empty vector.mask is not a common valid case to mask operations and that may eventually go away. Let me clarify that a bit better in the doc. We would need the CSE equivalence issue to be fixed and improve some of the existing vector transformations. Definitely a target we should be moving towards.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

I would still appreciate a TODO or some comment - mostly for my future self as a reminder about this conversation :)


if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
return emitOpError(
"expects result type to match maskable operation result type");
Expand Down
28 changes: 28 additions & 0 deletions mlir/test/Dialect/Vector/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1756,6 +1756,20 @@ func.func @vector_mask_empty_passthru_no_return_type(%mask : vector<8xi1>,

// -----

func.func @vector_mask_non_empty_external_return(%t: tensor<?xf32>, %idx: index,
%m: vector<16xi1>, %ext: vector<16xf32>) -> vector<16xf32> {
%ft0 = arith.constant 0.0 : f32
// expected-error@+1 {{'vector.mask' op expects all the results from the MaskableOpInterface to match all the values returned by the terminator}}
%0 = vector.mask %m {
%1 =vector.transfer_read %t[%idx], %ft0 : tensor<?xf32>, vector<16xf32>
vector.yield %ext : vector<16xf32>
} : vector<16xi1> -> vector<16xf32>

return %0 : vector<16xf32>
}

// -----

func.func @vector_mask_empty_passthru_empty_return_type(%mask : vector<8xi1>,
%passthru : vector<8xi32>) {
// expected-error@+1 {{'vector.mask' expects a result if passthru operand is provided}}
Expand All @@ -1765,6 +1779,20 @@ func.func @vector_mask_empty_passthru_empty_return_type(%mask : vector<8xi1>,

// -----

func.func @vector_mask_non_empty_mixed_return(%t: tensor<?xf32>, %idx: index,
%m: vector<16xi1>, %ext: vector<16xf32>) -> (vector<16xf32>, vector<16xf32>) {
%ft0 = arith.constant 0.0 : f32
// expected-error@+1 {{'vector.mask' op expects number of results to match maskable operation number of results}}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fails today with this error:

 error: 'vector.mask' op expects number of results to match mask region yielded values

Its not clear to me what generates this new error.

Copy link
Contributor Author

@dcaballe dcaballe May 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original code in ensureTerminator was replacing the explicitly-provided terminator with a different one.

%0:2 = vector.mask %m {
%1 =vector.transfer_read %t[%idx], %ft0 : tensor<?xf32>, vector<16xf32>
vector.yield %1, %ext : vector<16xf32>, vector<16xf32>
} : vector<16xi1> -> (vector<16xf32>, vector<16xf32>)

return %0#0, %0#1 : vector<16xf32>, vector<16xf32>
}

// -----

func.func @vector_scalable_insert_unaligned(%subv: vector<4xi32>, %vec: vector<[16]xi32>) {
// expected-error@+1 {{op failed to verify that position is a multiple of the source length.}}
%0 = vector.scalable.insert %subv, %vec[2] : vector<4xi32> into vector<[16]xi32>
Expand Down
Loading