Skip to content

Commit bcd65ba

Browse files
authored
[mlir][Vector] Verify that masked ops implement MaskableOpInterface (#108123)
This PR fixes a bug in `MaskOp::verifier` that allowed `vector.mask` to mask operations that did not implement the MaskableOpInterface.
1 parent 8a34f6d commit bcd65ba

File tree

3 files changed

+23
-8
lines changed

3 files changed

+23
-8
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6131,7 +6131,9 @@ LogicalResult MaskOp::verify() {
61316131
Block &block = getMaskRegion().getBlocks().front();
61326132
if (block.getOperations().empty())
61336133
return emitOpError("expects a terminator within the mask region");
6134-
if (block.getOperations().size() > 2)
6134+
6135+
unsigned numMaskRegionOps = block.getOperations().size();
6136+
if (numMaskRegionOps > 2)
61356137
return emitOpError("expects only one operation to mask");
61366138

61376139
// Terminator checks.
@@ -6143,11 +6145,14 @@ LogicalResult MaskOp::verify() {
61436145
return emitOpError(
61446146
"expects number of results to match mask region yielded values");
61456147

6146-
auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
61476148
// Empty vector.mask. Nothing else to check.
6148-
if (!maskableOp)
6149+
if (numMaskRegionOps == 1)
61496150
return success();
61506151

6152+
auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
6153+
if (!maskableOp)
6154+
return emitOpError("expects a MaskableOpInterface within the mask region");
6155+
61516156
// Result checks.
61526157
if (maskableOp->getNumResults() != getNumResults())
61536158
return emitOpError("expects number of results to match maskable operation "

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2471,13 +2471,15 @@ func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1
24712471
// -----
24722472

24732473
// CHECK-LABEL: func @all_true_vector_mask
2474-
// CHECK-SAME: %[[IN:.*]]: vector<3x4xf32>
2475-
func.func @all_true_vector_mask(%a : vector<3x4xf32>) -> vector<3x4xf32> {
2474+
// CHECK-SAME: %[[IN:.*]]: tensor<3x4xf32>
2475+
func.func @all_true_vector_mask(%ta : tensor<3x4xf32>) -> vector<3x4xf32> {
24762476
// CHECK-NOT: vector.mask
2477-
// CHECK: %[[ADD:.*]] = arith.addf %[[IN]], %[[IN]] : vector<3x4xf32>
2478-
// CHECK: return %[[ADD]] : vector<3x4xf32>
2477+
// CHECK: %[[LD:.*]] = vector.transfer_read %[[IN]]
2478+
// CHECK: return %[[LD]] : vector<3x4xf32>
2479+
%c0 = arith.constant 0 : index
2480+
%cf0 = arith.constant 0.0 : f32
24792481
%all_true = vector.constant_mask [3, 4] : vector<3x4xi1>
2480-
%0 = vector.mask %all_true { arith.addf %a, %a : vector<3x4xf32> } : vector<3x4xi1> -> vector<3x4xf32>
2482+
%0 = vector.mask %all_true { vector.transfer_read %ta[%c0, %c0], %cf0 : tensor<3x4xf32>, vector<3x4xf32> } : vector<3x4xi1> -> vector<3x4xf32>
24812483
return %0 : vector<3x4xf32>
24822484
}
24832485

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1724,6 +1724,14 @@ func.func @vector_mask_passthru_no_return(%val: vector<16xf32>, %t0: tensor<?xf3
17241724
vector.mask %m0, %pt0 { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, tensor<?xf32> } : vector<16xi1> -> vector<16xf32>
17251725
return
17261726
}
1727+
// -----
1728+
1729+
func.func @vector_mask_non_maskable_op(%a : vector<3x4xf32>) -> vector<3x4xf32> {
1730+
%m0 = vector.constant_mask [2, 2] : vector<3x4xi1>
1731+
// expected-error@+1 {{'vector.mask' op expects a MaskableOpInterface within the mask region}}
1732+
%0 = vector.mask %m0 { arith.addf %a, %a : vector<3x4xf32> } : vector<3x4xi1> -> vector<3x4xf32>
1733+
return %0 : vector<3x4xf32>
1734+
}
17271735

17281736
// -----
17291737

0 commit comments

Comments
 (0)