-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR] extend getCompressedMaskOp
support in VectorEmulateNarrowType
#116122
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: lialan (lialan) ChangesPreviously when This patch resolves the issue by including Full diff: https://github.com/llvm/llvm-project/pull/116122.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index e5f2a847994aee..0b5b8e0559cd2b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -104,10 +104,14 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
if (createMaskOp) {
OperandRange maskOperands = createMaskOp.getOperands();
size_t numMaskOperands = maskOperands.size();
+ // The `vector.create_mask` op creates a mask arrangement without any zeros
+ // at the front. Also, because `numFrontPadElems` is strictly smaller than
+ // `numSrcElemsPerDest`, the compressed mask generated by shifting the
+ // original mask by `numFrontPadElems` will not have any zeros at the front
+ // as well.
AffineExpr s0;
bindSymbols(rewriter.getContext(), s0);
- s0 = s0 + numSrcElemsPerDest - 1;
- s0 = s0.floorDiv(numSrcElemsPerDest);
+ s0 = (s0 + numFrontPadElems).ceilDiv(numSrcElemsPerDest);
OpFoldResult origIndex =
getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
OpFoldResult maskIndex =
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 7ed75ff7f1579c..327364ce820da7 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -74,6 +74,55 @@ func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
// -----
+// This tests the correctness of generating compressed mask with `vector.create_mask` and a dynamic input.
+// Specifically, the program masked loads a vector<5xi2> from `vector<3x5xi2>[1, 0]`, with an unknown mask generator `m`.
+// After emulation transformation, it masked loads 2 bytes from linearized index `vector<4xi8>[1]`, with a new compressed mask
+// given by `ceildiv(m + 1, 4)`.
+func.func @check_unaligned_create_mask_dynamic_i2(%m : index, %passthru: vector<5xi2>) -> vector<5xi2> {
+ %0 = memref.alloc() : memref<3x5xi2>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %mask = vector.create_mask %m : vector<5xi1>
+ %1 = vector.maskedload %0[%c1, %c0], %mask, %passthru :
+ memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
+ return %1 : vector<5xi2>
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> ((s0 + 1) ceildiv 4)>
+// CHECK: func @check_unaligned_create_mask_dynamic_i2(
+// CHECK-SAME: %[[MASK:.+]]: index, %[[PASSTHRU:.+]]: vector<5xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<4xi8>
+// CHECK: %[[COMP_MASK:.+]] = affine.apply #map()[%[[MASK]]]
+// CHECK: vector.create_mask %[[COMP_MASK]] : vector<2xi1>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: vector.maskedload %[[ALLOC]][%[[C1]]]
+
+// -----
+
+// This tests the correctness of generated compressed mask with `vector.create_mask`, and a static input.
+// Quite the same as the previous test, but the mask generator is a static value.
+// In this case, the desired slice `vector<7xi2>` spans over 3 bytes.
+func.func @check_unaligned_create_mask_static_i2(%passthru: vector<7xi2>) -> vector<7xi2> {
+ %0 = memref.alloc() : memref<3x7xi2>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c3 = arith.constant 3 : index
+ %mask = vector.create_mask %c3 : vector<7xi1>
+ %1 = vector.maskedload %0[%c1, %c0], %mask, %passthru :
+ memref<3x7xi2>, vector<7xi1>, vector<7xi2> into vector<7xi2>
+ return %1 : vector<7xi2>
+}
+
+// CHECK: func @check_unaligned_create_mask_static_i2(
+// CHECK-SAME: %[[PASSTHRU:[a-zA-Z0-9]+]]: vector<7xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[COMP_MASK:.+]] = vector.create_mask %[[C2]] : vector<3xi1>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %4 = vector.maskedload %[[ALLOC]][%[[C1]]], %[[COMP_MASK]]
+
+// -----
+
func.func @vector_load_i2_dynamic_indexing(%idx1: index, %idx2: index) -> vector<3xi2> {
%0 = memref.alloc() : memref<3x3xi2>
%cst = arith.constant dense<0> : vector<3x3xi2>
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index 034bd47f6163e6..c68909061d8f3c 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -141,7 +141,7 @@ func.func @vector_maskedload_i8(%arg1: index, %arg2: index, %arg3: index, %passt
// CHECK-NEXT: return
// CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 4)>
-// CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> ((s0 + 3) floordiv 4)>
+// CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
// CHECK32: func @vector_maskedload_i8(
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: vector<4xi8>)
@@ -169,7 +169,7 @@ func.func @vector_maskedload_i4(%arg1: index, %arg2: index, %arg3: index, %passt
return %2 : vector<3x8xi4>
}
// CHECK-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
-// CHECK-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> ((s0 + 1) floordiv 2)>
+// CHECK-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
// CHECK: func @vector_maskedload_i4(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: vector<8xi4>)
@@ -185,7 +185,7 @@ func.func @vector_maskedload_i4(%arg1: index, %arg2: index, %arg3: index, %passt
// CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG3]] : vector<8xi1>, vector<8xi4>
// CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
-// CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> ((s0 + 7) floordiv 8)>
+// CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
// CHECK32: func @vector_maskedload_i4(
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: vector<8xi4>)
@@ -473,7 +473,7 @@ func.func @vector_maskedstore_i8(%arg0: index, %arg1: index, %arg2: index, %valu
// CHECK-NEXT: return
// CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 2 + s1 floordiv 4)>
-// CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> ((s0 + 3) floordiv 4)>
+// CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
// CHECK32: func @vector_maskedstore_i8(
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9]+]]
@@ -506,7 +506,7 @@ func.func @vector_maskedstore_i4(
return
}
// CHECK: #[[$ATTR_10:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
-// CHECK: #[[$ATTR_11:.+]] = affine_map<()[s0] -> ((s0 + 1) floordiv 2)>
+// CHECK: #[[$ATTR_11:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
// CHECK-LABEL: func.func @vector_maskedstore_i4(
// CHECK-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
@@ -526,7 +526,7 @@ func.func @vector_maskedstore_i4(
// CHECK: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<12xi8>, vector<4xi1>, vector<4xi8>
// CHECK32: #[[$ATTR_17:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
-// CHECK32: #[[$ATTR_18:.+]] = affine_map<()[s0] -> ((s0 + 7) floordiv 8)>
+// CHECK32: #[[$ATTR_18:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
// CHECK32-LABEL: func.func @vector_maskedstore_i4(
// CHECK32-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG, thanks! Please give some time for others to review.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The updated logic makes sense, thanks! A few small suggestions.
mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
Outdated
Show resolved
Hide resolved
b129fc0
to
f251be0
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Alan, it looks good to me. I'll help land the PR once we get an approval from @banach-space
f251be0
to
d960ae8
Compare
mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM % a couple of nits
Previously when `numFrontPadElems`is not zero, `getCompressedMaskOp` produces wrong result if the mask generator op is `vector.create_mask`. This patch resolves such issue when `numFrontPadElems` is not zero. Signed-off-by: Alan Li <[email protected]>
d960ae8
to
9e6310a
Compare
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/161/builds/3285 Here is the relevant piece of the build log for the reference
|
Previously when
numFrontPadElems
is not zero,getCompressedMaskOp
produces wrong result if the mask generator op is avector.create_mask
.This patch resolves the issue by including
numFrontPadElems
into the mask generation.