Skip to content

Commit f981ee7

Browse files
authored
[MLIR] extend getCompressedMaskOp support in VectorEmulateNarrowType (#116122)
Previously, when `numFrontPadElems` is not zero, `getCompressedMaskOp` produces an incorrect result if the mask-generating op is a `vector.create_mask`. This patch resolves the issue by including `numFrontPadElems` in the mask generation. Signed-off-by: Alan Li <[email protected]>
1 parent 066dd91 commit f981ee7

File tree

3 files changed

+65
-13
lines changed

3 files changed

+65
-13
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,13 +110,16 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
110110
.Case<vector::CreateMaskOp>(
111111
[&](auto createMaskOp) -> std::optional<Operation *> {
112112
OperandRange maskOperands = createMaskOp.getOperands();
113-
size_t numMaskOperands = maskOperands.size();
113+
// The `vector.create_mask` op creates a mask arrangement
114+
// without any zeros at the front. Also, because
115+
// `numFrontPadElems` is strictly smaller than
116+
// `numSrcElemsPerDest`, the compressed mask generated by
117+
// padding the original mask by `numFrontPadElems` will not
118+
// have any zeros at the front as well.
114119
AffineExpr s0;
115120
bindSymbols(rewriter.getContext(), s0);
116-
s0 = s0 + numSrcElemsPerDest - 1;
117-
s0 = s0.floorDiv(numSrcElemsPerDest);
118-
OpFoldResult origIndex =
119-
getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
121+
s0 = (s0 + numFrontPadElems).ceilDiv(numSrcElemsPerDest);
122+
OpFoldResult origIndex = getAsOpFoldResult(maskOperands.back());
120123
OpFoldResult maskIndex = affine::makeComposedFoldedAffineApply(
121124
rewriter, loc, s0, origIndex);
122125
SmallVector<Value> newMaskOperands(maskOperands.drop_back());

mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ func.func @vector_transfer_read_i2() -> vector<3xi2> {
4242

4343
// -----
4444

45-
func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
45+
func.func @vector_constant_mask_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
4646
%0 = memref.alloc() : memref<3x5xi2>
4747
%cst = arith.constant dense<0> : vector<3x5xi2>
4848
%mask = vector.constant_mask [3] : vector<5xi1>
@@ -54,7 +54,7 @@ func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
5454
return %2 : vector<3x5xi2>
5555
}
5656

57-
// CHECK-LABEL: func @vector_cst_maskedload_i2(
57+
// CHECK-LABEL: func @vector_constant_mask_maskedload_i2(
5858
// CHECK-SAME: %[[ARG0:.+]]: vector<5xi2>) -> vector<3x5xi2>
5959
// CHECK: %[[ORIGINMASK:.+]] = vector.constant_mask [3] : vector<5xi1>
6060
// CHECK: %[[NEWMASK:.+]] = arith.constant dense<true> : vector<2xi1>
@@ -74,6 +74,55 @@ func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
7474

7575
// -----
7676

77+
// This tests the correctness of generating compressed mask with `vector.create_mask` on a static input and dynamic indices.
78+
// Specifically, the program masked-loads a vector<5xi2> from `vector<3x5xi2>[1, 0]`, with an unknown mask generator `m`.
79+
// After the emulation transformation, it masked-loads 2 bytes at linearized index `vector<4xi8>[1]`, with a new compressed mask
80+
// given by `ceildiv(m + 1, 4)`.
81+
func.func @unaligned_create_mask_dynamic_i2(%m : index, %passthru: vector<5xi2>) -> vector<5xi2> {
82+
%0 = memref.alloc() : memref<3x5xi2>
83+
%c0 = arith.constant 0 : index
84+
%c1 = arith.constant 1 : index
85+
%mask = vector.create_mask %m : vector<5xi1>
86+
%1 = vector.maskedload %0[%c1, %c0], %mask, %passthru :
87+
memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
88+
return %1 : vector<5xi2>
89+
}
90+
91+
// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> ((s0 + 1) ceildiv 4)>
92+
// CHECK: func @unaligned_create_mask_dynamic_i2(
93+
// CHECK-SAME: %[[NUM_ELEMS_TO_LOAD:.+]]: index, %[[PASSTHRU:.+]]: vector<5xi2>)
94+
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<4xi8>
95+
// CHECK: %[[COMPRESSED_MASK:.+]] = affine.apply #map()[%[[NUM_ELEMS_TO_LOAD]]]
96+
// CHECK: vector.create_mask %[[COMPRESSED_MASK]] : vector<2xi1>
97+
// CHECK: %[[C1:.+]] = arith.constant 1 : index
98+
// CHECK: vector.maskedload %[[ALLOC]][%[[C1]]]
99+
100+
// -----
101+
102+
// This tests the correctness of the generated compressed mask with `vector.create_mask`, and a static input.
103+
// Much the same as the previous test, but the mask generator operand is a static value.
104+
// In this case, the desired slice `vector<7xi2>` spans over 3 bytes.
105+
func.func @check_unaligned_create_mask_static_i2(%passthru: vector<7xi2>) -> vector<7xi2> {
106+
%0 = memref.alloc() : memref<3x7xi2>
107+
%c0 = arith.constant 0 : index
108+
%c1 = arith.constant 1 : index
109+
%c3 = arith.constant 3 : index
110+
%mask = vector.create_mask %c3 : vector<7xi1>
111+
%1 = vector.maskedload %0[%c1, %c0], %mask, %passthru :
112+
memref<3x7xi2>, vector<7xi1>, vector<7xi2> into vector<7xi2>
113+
return %1 : vector<7xi2>
114+
}
115+
116+
// CHECK: func @check_unaligned_create_mask_static_i2(
117+
// CHECK-SAME: %[[PASSTHRU:[a-zA-Z0-9]+]]: vector<7xi2>)
118+
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
119+
// CHECK: %[[C2:.+]] = arith.constant 2 : index
120+
// CHECK: %[[COMP_MASK:.+]] = vector.create_mask %[[C2]] : vector<3xi1>
121+
// CHECK: %[[C1:.+]] = arith.constant 1 : index
122+
// CHECK: %4 = vector.maskedload %[[ALLOC]][%[[C1]]], %[[COMP_MASK]]
123+
124+
// -----
125+
77126
func.func @vector_load_i2_dynamic_indexing(%idx1: index, %idx2: index) -> vector<3xi2> {
78127
%0 = memref.alloc() : memref<3x3xi2>
79128
%cst = arith.constant dense<0> : vector<3x3xi2>

mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ func.func @vector_maskedload_i8(%arg1: index, %arg2: index, %arg3: index, %passt
141141
// CHECK-NEXT: return
142142

143143
// CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 4)>
144-
// CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> ((s0 + 3) floordiv 4)>
144+
// CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
145145
// CHECK32: func @vector_maskedload_i8(
146146
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
147147
// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: vector<4xi8>)
@@ -169,7 +169,7 @@ func.func @vector_maskedload_i4(%arg1: index, %arg2: index, %arg3: index, %passt
169169
return %2 : vector<3x8xi4>
170170
}
171171
// CHECK-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
172-
// CHECK-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> ((s0 + 1) floordiv 2)>
172+
// CHECK-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
173173
// CHECK: func @vector_maskedload_i4(
174174
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
175175
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: vector<8xi4>)
@@ -185,7 +185,7 @@ func.func @vector_maskedload_i4(%arg1: index, %arg2: index, %arg3: index, %passt
185185
// CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG3]] : vector<8xi1>, vector<8xi4>
186186

187187
// CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
188-
// CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> ((s0 + 7) floordiv 8)>
188+
// CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
189189
// CHECK32: func @vector_maskedload_i4(
190190
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
191191
// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: vector<8xi4>)
@@ -497,7 +497,7 @@ func.func @vector_maskedstore_i8(%arg0: index, %arg1: index, %arg2: index, %valu
497497
// CHECK-NEXT: return
498498

499499
// CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 2 + s1 floordiv 4)>
500-
// CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> ((s0 + 3) floordiv 4)>
500+
// CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
501501
// CHECK32: func @vector_maskedstore_i8(
502502
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]
503503
// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9]+]]
@@ -530,7 +530,7 @@ func.func @vector_maskedstore_i4(
530530
return
531531
}
532532
// CHECK: #[[$ATTR_10:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
533-
// CHECK: #[[$ATTR_11:.+]] = affine_map<()[s0] -> ((s0 + 1) floordiv 2)>
533+
// CHECK: #[[$ATTR_11:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
534534

535535
// CHECK-LABEL: func.func @vector_maskedstore_i4(
536536
// CHECK-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
@@ -550,7 +550,7 @@ func.func @vector_maskedstore_i4(
550550
// CHECK: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<12xi8>, vector<4xi1>, vector<4xi8>
551551

552552
// CHECK32: #[[$ATTR_17:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
553-
// CHECK32: #[[$ATTR_18:.+]] = affine_map<()[s0] -> ((s0 + 7) floordiv 8)>
553+
// CHECK32: #[[$ATTR_18:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
554554

555555
// CHECK32-LABEL: func.func @vector_maskedstore_i4(
556556
// CHECK32-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,

0 commit comments

Comments
 (0)