Skip to content

Commit 6da9abb

Browse files
committed
Review feedback
1 parent 30fc0e2 commit 6da9abb

File tree

3 files changed

+18
-18
lines changed

3 files changed

+18
-18
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2485,10 +2485,9 @@ def Vector_MaskOp : Vector_Op<"mask", [
24852485
terminator yields all results from the maskable operation to the result of
24862486
this operation. No other values are allowed to be yielded.
24872487

2488-
An empty `vector.mask` operation is considered ill-formed but legal to
2489-
facilitate optimizations across the `vector.mask` operation. It is considered
2490-
a no-op regardless of its returned values and will be removed by the
2491-
canonicalizer.
2488+
An empty `vector.mask` operation is legal to facilitate optimizations across
2489+
the `vector.mask` operation. However, it is considered a no-op regardless of
2490+
its returned values and will be removed by the canonicalizer.
24922491

24932492
The vector mask argument holds a bit for each vector lane and determines
24942493
which vector lanes should execute the maskable operation and which ones

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6543,18 +6543,20 @@ void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
65436543
}
65446544

65456545
void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
6546-
// Create default terminator if there are no ops to mask.
6546+
// 1. For an empty `vector.mask`, create a default terminator.
65476547
if (region.empty() || region.front().empty()) {
65486548
OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
65496549
MaskOp>::ensureTerminator(region, builder, loc);
65506550
return;
65516551
}
65526552

6553-
// If region has an explicit terminator, we don't modify it.
6553+
// 2. For a non-empty `vector.mask` with an explicit terminator, do nothing.
65546554
Block &block = region.front();
65556555
if (isa<vector::YieldOp>(block.back()))
65566556
return;
65576557

6558+
// 3. For a non-empty `vector.mask` without an explicit terminator:
6559+
65586560
// Create default terminator if the number of masked operations is not
65596561
// one. This case will trigger a verification failure.
65606562
if (block.getOperations().size() != 1) {
@@ -6603,9 +6605,8 @@ LogicalResult MaskOp::verify() {
66036605
"number of results");
66046606

66056607
if (!llvm::equal(maskableOp->getResults(), terminator.getOperands()))
6606-
return emitOpError(
6607-
"expects all the results from the MaskableOpInterface to "
6608-
"be returned by the terminator");
6608+
return emitOpError("expects all the results from the MaskableOpInterface "
6609+
"to match all the values returned by the terminator");
66096610

66106611
if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
66116612
return emitOpError(

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1747,12 +1747,12 @@ func.func @vector_mask_0d_mask(%arg0: tensor<2x4xi32>,
17471747

17481748
// -----
17491749

1750-
func.func @vector_mask_non_empty_external_return(%t0: tensor<?xf32>, %idx: index,
1751-
%m0: vector<16xi1>, %ext: vector<16xf32>) -> vector<16xf32> {
1750+
func.func @vector_mask_non_empty_external_return(%t: tensor<?xf32>, %idx: index,
1751+
%m: vector<16xi1>, %ext: vector<16xf32>) -> vector<16xf32> {
17521752
%ft0 = arith.constant 0.0 : f32
1753-
// expected-error@+1 {{'vector.mask' op expects all the results from the MaskableOpInterface to be returned by the terminator}}
1754-
%0 = vector.mask %m0 {
1755-
%1 =vector.transfer_read %t0[%idx], %ft0 : tensor<?xf32>, vector<16xf32>
1753+
// expected-error@+1 {{'vector.mask' op expects all the results from the MaskableOpInterface to match all the values returned by the terminator}}
1754+
%0 = vector.mask %m {
1755+
%1 =vector.transfer_read %t[%idx], %ft0 : tensor<?xf32>, vector<16xf32>
17561756
vector.yield %ext : vector<16xf32>
17571757
} : vector<16xi1> -> vector<16xf32>
17581758

@@ -1761,12 +1761,12 @@ func.func @vector_mask_non_empty_external_return(%t0: tensor<?xf32>, %idx: index
17611761

17621762
// -----
17631763

1764-
func.func @vector_mask_non_empty_mixed_return(%t0: tensor<?xf32>, %idx: index,
1765-
%m0: vector<16xi1>, %ext: vector<16xf32>) -> (vector<16xf32>, vector<16xf32>) {
1764+
func.func @vector_mask_non_empty_mixed_return(%t: tensor<?xf32>, %idx: index,
1765+
%m: vector<16xi1>, %ext: vector<16xf32>) -> (vector<16xf32>, vector<16xf32>) {
17661766
%ft0 = arith.constant 0.0 : f32
17671767
// expected-error@+1 {{'vector.mask' op expects number of results to match maskable operation number of results}}
1768-
%0:2 = vector.mask %m0 {
1769-
%1 =vector.transfer_read %t0[%idx], %ft0 : tensor<?xf32>, vector<16xf32>
1768+
%0:2 = vector.mask %m {
1769+
%1 =vector.transfer_read %t[%idx], %ft0 : tensor<?xf32>, vector<16xf32>
17701770
vector.yield %1, %ext : vector<16xf32>, vector<16xf32>
17711771
} : vector<16xi1> -> (vector<16xf32>, vector<16xf32>)
17721772

0 commit comments

Comments
 (0)