Skip to content

[mlir][Vector] Verify that masked ops implement MaskableOpInterface #108123

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 19, 2024

Conversation

dcaballe
Copy link
Contributor

This PR fixes a bug in MaskOp::verifier that allowed vector.mask to mask operations that did not implement the MaskableOpInterface.

This PR fixes a bug in `MaskOp::verifier` that allowed `vector.mask` to mask
operations that did not implement the MaskableOpInterface.
@llvmbot
Copy link
Member

llvmbot commented Sep 11, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Diego Caballero (dcaballe)

Changes

This PR fixes a bug in MaskOp::verifier that allowed vector.mask to mask operations that did not implement the MaskableOpInterface.


Full diff: https://github.com/llvm/llvm-project/pull/108123.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+8-3)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+7-5)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (+8)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d3aef4ac38af03..62f9943e93b9cf 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6124,7 +6124,9 @@ LogicalResult MaskOp::verify() {
   Block &block = getMaskRegion().getBlocks().front();
   if (block.getOperations().empty())
     return emitOpError("expects a terminator within the mask region");
-  if (block.getOperations().size() > 2)
+
+  unsigned numMaskRegionOps = block.getOperations().size();
+  if (numMaskRegionOps > 2)
     return emitOpError("expects only one operation to mask");
 
   // Terminator checks.
@@ -6136,11 +6138,14 @@ LogicalResult MaskOp::verify() {
     return emitOpError(
         "expects number of results to match mask region yielded values");
 
-  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
   // Empty vector.mask. Nothing else to check.
-  if (!maskableOp)
+  if (numMaskRegionOps == 1)
     return success();
 
+  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
+  if (!maskableOp)
+    return emitOpError("expects a MaskableOpInterface within the mask region");
+
   // Result checks.
   if (maskableOp->getNumResults() != getNumResults())
     return emitOpError("expects number of results to match maskable operation "
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index e71a6eb02ea46c..b7c78de4b5bd89 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2471,13 +2471,15 @@ func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1
 // -----
 
 // CHECK-LABEL: func @all_true_vector_mask
-//  CHECK-SAME:     %[[IN:.*]]: vector<3x4xf32>
-func.func @all_true_vector_mask(%a : vector<3x4xf32>) -> vector<3x4xf32> {
+//  CHECK-SAME:     %[[IN:.*]]: tensor<3x4xf32>
+func.func @all_true_vector_mask(%ta : tensor<3x4xf32>) -> vector<3x4xf32> {
 //   CHECK-NOT:   vector.mask
-//       CHECK:   %[[ADD:.*]] = arith.addf %[[IN]], %[[IN]] : vector<3x4xf32>
-//       CHECK:   return %[[ADD]] : vector<3x4xf32>
+//       CHECK:   %[[LD:.*]] = vector.transfer_read %[[IN]]
+//       CHECK:   return %[[LD]] : vector<3x4xf32>
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
   %all_true = vector.constant_mask [3, 4] : vector<3x4xi1>
-  %0 = vector.mask %all_true { arith.addf %a, %a : vector<3x4xf32> } : vector<3x4xi1> -> vector<3x4xf32>
+  %0 = vector.mask %all_true { vector.transfer_read %ta[%c0, %c0], %cf0 : tensor<3x4xf32>, vector<3x4xf32> } : vector<3x4xi1> -> vector<3x4xf32>
   return %0 : vector<3x4xf32>
 }
 
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index c95b8bd5ed6147..e2bc5ef6128e7d 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1724,6 +1724,14 @@ func.func @vector_mask_passthru_no_return(%val: vector<16xf32>, %t0: tensor<?xf3
   vector.mask %m0, %pt0 { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, tensor<?xf32> } : vector<16xi1> -> vector<16xf32>
   return
 }
+// -----
+
+func.func @vector_mask_non_maskable_op(%a : vector<3x4xf32>) -> vector<3x4xf32> {
+   %m0 = vector.constant_mask [2, 2] : vector<3x4xi1>
+  // expected-error@+1 {{'vector.mask' op expects a MaskableOpInterface within the mask region}}
+   %0 = vector.mask %m0 { arith.addf %a, %a : vector<3x4xf32> } : vector<3x4xi1> -> vector<3x4xf32>
+   return %0 : vector<3x4xf32>
+}
 
 // -----
 

if (block.getOperations().size() > 2)

unsigned numMaskRegionOps = block.getOperations().size();
if (numMaskRegionOps > 2)
return emitOpError("expects only one operation to mask");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like MaskOp::getMaskableOp() is redundantly testing for this condition unnecessarily, and the MaskOpRewritePattern is then using cast_or_null where cast could be used?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MaskOp can be empty so that's why getMaskableOp() is checking for something similar to return nullptr. That's also the reason to use cast_or_null

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh right!

The code is confusing, and not the most efficient: if (block->getOperations().size() < 2) (size() on a linked list is O(N)).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, that's why I introduced numMaskRegionOps instead of calling size() again. However, a valid vector mask would always have 2 ops at most so I thought it would be acceptable. I could implement similar logic with block.getOperations().empty() + compare block.getOperations().begin() and block.getOperations().end() but I think it would make it more confusing... WDYT?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG

Copy link
Contributor

@CoTinker CoTinker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

// Empty vector.mask. Nothing else to check.
if (!maskableOp)
if (numMaskRegionOps == 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be comparing against 0?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a terminator/

@dcaballe dcaballe merged commit bcd65ba into llvm:main Sep 19, 2024
12 checks passed
tmsri pushed a commit to tmsri/llvm-project that referenced this pull request Sep 19, 2024
…lvm#108123)

This PR fixes a bug in `MaskOp::verifier` that allowed `vector.mask` to
mask operations that did not implement the MaskableOpInterface.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants