Skip to content

Commit 3291372

Browse files
authored
[mlir][vector] Fix 0-d vector transfer mask inference (#116526)
When inferring the mask of a transfer operation that results in a single `i1` element, we could represent it using either `vector<i1>` or vector<1xi1>. To avoid type mismatches, this PR updates the mask inference logic to consistently generate `vector<1xi1>` for these cases. We can enable 0-D masks if they are needed in the future. See: #116197
1 parent 197fb27 commit 3291372

File tree

4 files changed

+37
-1
lines changed

4 files changed

+37
-1
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2475,7 +2475,9 @@ def Vector_MaskOp : Vector_Op<"mask", [
24752475
should not. The `vector.mask` operation returns the value produced by the
24762476
masked execution of the nested operation, if any. The masked-off lanes in
24772477
the result vector are taken from the corresponding lanes of the pass-thru
2478-
argument, if provided, or left unmodified, otherwise.
2478+
argument, if provided, or left unmodified, otherwise. At this point, 0-D
2479+
vectors are not supported by `vector.mask`. They may be supported in the
2480+
future.
24792481

24802482
The `vector.mask` operation does not prescribe how a maskable operation
24812483
should be masked or how a masked operation should be lowered. Masking

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4122,6 +4122,11 @@ VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
41224122
assert(invPermMap && "Inversed permutation map couldn't be computed");
41234123
SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
41244124

4125+
// The MaskOp specification doesn't support 0-D vectors at the moment. Turn a
4126+
// 0-D mask into a single-element 1-D mask.
4127+
if (maskShape.empty())
4128+
maskShape.push_back(1);
4129+
41254130
SmallVector<bool> scalableDims =
41264131
applyPermutationMap(invPermMap, vecType.getScalableDims());
41274132

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1752,6 +1752,21 @@ func.func @vector_mask_non_maskable_op(%a : vector<3x4xf32>) -> vector<3x4xf32>
17521752

17531753
// -----
17541754

1755+
func.func @vector_mask_0d_mask(%arg0: tensor<2x4xi32>,
1756+
%idx0: index, %idx1: index,
1757+
%m0: vector<i1>) -> vector<1x1x4xi32> {
1758+
%cst = arith.constant 0 : i32
1759+
// expected-error@+1 {{'vector.mask' op operand #0 must be vector of 1-bit signless integer values, but got 'vector<i1>'}}
1760+
%res = vector.mask %m0 {
1761+
%0 = vector.transfer_read %arg0[%idx0, %idx1], %cst {permutation_map = affine_map<(d0, d1) -> (0, 0, 0)>}
1762+
: tensor<2x4xi32>, vector<1x1x4xi32>
1763+
vector.yield %0 : vector<1x1x4xi32>
1764+
} : vector<i1> -> vector<1x1x4xi32>
1765+
return %res : vector<1x1x4xi32>
1766+
}
1767+
1768+
// -----
1769+
17551770
func.func @vector_scalable_insert_unaligned(%subv: vector<4xi32>, %vec: vector<[16]xi32>) {
17561771
// expected-error@+1 {{op failed to verify that position is a multiple of the source length.}}
17571772
%0 = vector.scalable.insert %subv, %vec[2] : vector<4xi32> into vector<[16]xi32>

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,6 +1028,20 @@ func.func @vector_mask_empty_return(%m0: vector<16xi1>, %arg0: vector<16xf32>) -
10281028
return %0 : vector<16xf32>
10291029
}
10301030

1031+
// CHECK-LABEL: func @vector_mask_scalar_broadcast_transfer
1032+
func.func @vector_mask_scalar_broadcast_transfer(%arg0: tensor<2x4xi32>,
1033+
%idx0: index, %idx1: index,
1034+
%m0: vector<1xi1>) -> vector<1x1x4xi32> {
1035+
%cst = arith.constant 0 : i32
1036+
// CHECK: vector.mask %{{.*}} { vector.transfer_read {{.*}} } : vector<1xi1> -> vector<1x1x4xi32>
1037+
%res = vector.mask %m0 {
1038+
%0 = vector.transfer_read %arg0[%idx0, %idx1], %cst {permutation_map = affine_map<(d0, d1) -> (0, 0, 0)>}
1039+
: tensor<2x4xi32>, vector<1x1x4xi32>
1040+
vector.yield %0 : vector<1x1x4xi32>
1041+
} : vector<1xi1> -> vector<1x1x4xi32>
1042+
return %res : vector<1x1x4xi32>
1043+
}
1044+
10311045
// CHECK-LABEL: func @vector_scalable_insert(
10321046
// CHECK-SAME: %[[SUB0:.*]]: vector<4xi32>, %[[SUB1:.*]]: vector<8xi32>,
10331047
// CHECK-SAME: %[[SUB2:.*]]: vector<[4]xi32>, %[[SV:.*]]: vector<[8]xi32>

0 commit comments

Comments
 (0)