-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][Vector] Verify that masked ops implement MaskableOpInterface #108123
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This PR fixes a bug in `MaskOp::verifier` that allowed `vector.mask` to mask operations that did not implement the MaskableOpInterface.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector Author: Diego Caballero (dcaballe) ChangesThis PR fixes a bug in Full diff: https://github.com/llvm/llvm-project/pull/108123.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d3aef4ac38af03..62f9943e93b9cf 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6124,7 +6124,9 @@ LogicalResult MaskOp::verify() {
Block &block = getMaskRegion().getBlocks().front();
if (block.getOperations().empty())
return emitOpError("expects a terminator within the mask region");
- if (block.getOperations().size() > 2)
+
+ unsigned numMaskRegionOps = block.getOperations().size();
+ if (numMaskRegionOps > 2)
return emitOpError("expects only one operation to mask");
// Terminator checks.
@@ -6136,11 +6138,14 @@ LogicalResult MaskOp::verify() {
return emitOpError(
"expects number of results to match mask region yielded values");
- auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
// Empty vector.mask. Nothing else to check.
- if (!maskableOp)
+ if (numMaskRegionOps == 1)
return success();
+ auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
+ if (!maskableOp)
+ return emitOpError("expects a MaskableOpInterface within the mask region");
+
// Result checks.
if (maskableOp->getNumResults() != getNumResults())
return emitOpError("expects number of results to match maskable operation "
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index e71a6eb02ea46c..b7c78de4b5bd89 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2471,13 +2471,15 @@ func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1
// -----
// CHECK-LABEL: func @all_true_vector_mask
-// CHECK-SAME: %[[IN:.*]]: vector<3x4xf32>
-func.func @all_true_vector_mask(%a : vector<3x4xf32>) -> vector<3x4xf32> {
+// CHECK-SAME: %[[IN:.*]]: tensor<3x4xf32>
+func.func @all_true_vector_mask(%ta : tensor<3x4xf32>) -> vector<3x4xf32> {
// CHECK-NOT: vector.mask
-// CHECK: %[[ADD:.*]] = arith.addf %[[IN]], %[[IN]] : vector<3x4xf32>
-// CHECK: return %[[ADD]] : vector<3x4xf32>
+// CHECK: %[[LD:.*]] = vector.transfer_read %[[IN]]
+// CHECK: return %[[LD]] : vector<3x4xf32>
+ %c0 = arith.constant 0 : index
+ %cf0 = arith.constant 0.0 : f32
%all_true = vector.constant_mask [3, 4] : vector<3x4xi1>
- %0 = vector.mask %all_true { arith.addf %a, %a : vector<3x4xf32> } : vector<3x4xi1> -> vector<3x4xf32>
+ %0 = vector.mask %all_true { vector.transfer_read %ta[%c0, %c0], %cf0 : tensor<3x4xf32>, vector<3x4xf32> } : vector<3x4xi1> -> vector<3x4xf32>
return %0 : vector<3x4xf32>
}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index c95b8bd5ed6147..e2bc5ef6128e7d 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1724,6 +1724,14 @@ func.func @vector_mask_passthru_no_return(%val: vector<16xf32>, %t0: tensor<?xf3
vector.mask %m0, %pt0 { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, tensor<?xf32> } : vector<16xi1> -> vector<16xf32>
return
}
+// -----
+
+func.func @vector_mask_non_maskable_op(%a : vector<3x4xf32>) -> vector<3x4xf32> {
+ %m0 = vector.constant_mask [2, 2] : vector<3x4xi1>
+ // expected-error@+1 {{'vector.mask' op expects a MaskableOpInterface within the mask region}}
+ %0 = vector.mask %m0 { arith.addf %a, %a : vector<3x4xf32> } : vector<3x4xi1> -> vector<3x4xf32>
+ return %0 : vector<3x4xf32>
+}
// -----
|
if (block.getOperations().size() > 2) | ||
|
||
unsigned numMaskRegionOps = block.getOperations().size(); | ||
if (numMaskRegionOps > 2) | ||
return emitOpError("expects only one operation to mask"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems like MaskOp::getMaskableOp()
is redundantly testing for this condition unnecessarily, and the MaskOpRewritePattern
is then using cast_or_null
where cast
could be used?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MaskOp can be empty so that's why getMaskableOp()
is checking for something similar to return nullptr
. That's also the reason to use cast_or_null
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh right!
The code is confusing, and not the most efficient: if (block->getOperations().size() < 2)
(size()
on a linked list is O(N)
).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep, that's why I introduced numMaskRegionOps
instead of calling size()
again. However, a valid vector mask would always have 2 ops at most so I thought it would be acceptable. I could implement similar logic with block.getOperations().empty()
+ compare block.getOperations().begin()
and block.getOperations().end()
but I think it would make it more confusing... WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
// Empty vector.mask. Nothing else to check. | ||
if (!maskableOp) | ||
if (numMaskRegionOps == 1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't this be comparing against 0?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a terminator/
…lvm#108123) This PR fixes a bug in `MaskOp::verifier` that allowed `vector.mask` to mask operations that did not implement the MaskableOpInterface.
This PR fixes a bug in
MaskOp::verifier
that allowedvector.mask
to mask operations that did not implement the MaskableOpInterface.