
[MLIR] extend getCompressedMaskOp support in VectorEmulateNarrowType #116122


Merged
1 commit merged into llvm:main from lialan/fix_create_mask on Nov 20, 2024

Conversation

lialan
Member

@lialan lialan commented Nov 13, 2024

Previously, when `numFrontPadElems` is not zero, `getCompressedMaskOp` produces an incorrect result if the mask-generating op is a `vector.create_mask`.

This patch resolves the issue by including `numFrontPadElems` in the mask generation.
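As a concrete illustration (not part of the patch): with i2 elements packed four to an i8 byte (`numSrcElemsPerDest = 4`) and a load that starts one element into a byte (`numFrontPadElems = 1`), the old bound `(m + numSrcElemsPerDest - 1) floordiv numSrcElemsPerDest` ignores the front padding, whereas the new bound `(m + numFrontPadElems) ceildiv numSrcElemsPerDest` covers every byte the shifted mask touches. A minimal standalone C++ sketch of the two computations, with illustrative helper names that are not from the patch:

```cpp
#include <cstdio>

// Illustrative helper, not from the patch: integer ceiling division.
static int ceilDiv(int a, int b) { return (a + b - 1) / b; }

// Old behavior: ceildiv(m, numSrcElemsPerDest); the front padding is ignored.
static int oldCompressedBound(int m, int numSrcElemsPerDest) {
  return (m + numSrcElemsPerDest - 1) / numSrcElemsPerDest;
}

// New behavior: shift the mask by numFrontPadElems before compressing.
static int newCompressedBound(int m, int numFrontPadElems,
                              int numSrcElemsPerDest) {
  return ceilDiv(m + numFrontPadElems, numSrcElemsPerDest);
}

int main() {
  const int numSrcElemsPerDest = 4; // four i2 elements per i8 byte
  const int numFrontPadElems = 1;   // access starts one i2 element into a byte

  for (int m = 1; m <= 5; ++m)
    std::printf("m=%d  old=%d  new=%d\n", m,
                oldCompressedBound(m, numSrcElemsPerDest),
                newCompressedBound(m, numFrontPadElems, numSrcElemsPerDest));
  return 0;
}
```

With `m = 4`, the shifted mask covers element positions 1 through 4, which straddle two bytes: the old bound yields 1 and under-covers the load, while the new bound yields 2.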

@llvmbot
Member

llvmbot commented Nov 13, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: lialan (lialan)

Changes

Previously, when `numFrontPadElems` is not zero, `getCompressedMaskOp` produces an incorrect result if the mask-generating op is a `vector.create_mask`.

This patch resolves the issue by including `numFrontPadElems` in the mask generation.


Full diff: https://github.com/llvm/llvm-project/pull/116122.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+6-2)
  • (modified) mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir (+49)
  • (modified) mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir (+6-6)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index e5f2a847994aee..0b5b8e0559cd2b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -104,10 +104,14 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
   if (createMaskOp) {
     OperandRange maskOperands = createMaskOp.getOperands();
     size_t numMaskOperands = maskOperands.size();
+    // The `vector.create_mask` op creates a mask arrangement without any zeros
+    // at the front. Also, because `numFrontPadElems` is strictly smaller than
+    // `numSrcElemsPerDest`, the compressed mask generated by shifting the
+    // original mask by `numFrontPadElems` will not have any zeros at the front
+    // as well.
     AffineExpr s0;
     bindSymbols(rewriter.getContext(), s0);
-    s0 = s0 + numSrcElemsPerDest - 1;
-    s0 = s0.floorDiv(numSrcElemsPerDest);
+    s0 = (s0 + numFrontPadElems).ceilDiv(numSrcElemsPerDest);
     OpFoldResult origIndex =
         getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
     OpFoldResult maskIndex =
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 7ed75ff7f1579c..327364ce820da7 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -74,6 +74,55 @@ func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
 
 // -----
 
+// This tests the correctness of generating compressed mask with `vector.create_mask` and a dynamic input.
+// Specifically, the program masked loads a vector<5xi2> from `vector<3x5xi2>[1, 0]`, with an unknown mask generator `m`.
+// After emulation transformation, it masked loads 2 bytes from linearized index `vector<4xi8>[1]`, with a new compressed mask
+// given by `ceildiv(m + 1, 4)`.
+func.func @check_unaligned_create_mask_dynamic_i2(%m : index, %passthru: vector<5xi2>) -> vector<5xi2> {
+    %0 = memref.alloc() : memref<3x5xi2>
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %mask = vector.create_mask %m : vector<5xi1>
+    %1 = vector.maskedload %0[%c1, %c0], %mask, %passthru :
+      memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
+    return %1 : vector<5xi2>
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> ((s0 + 1) ceildiv 4)>
+// CHECK: func @check_unaligned_create_mask_dynamic_i2(
+// CHECK-SAME:     %[[MASK:.+]]: index, %[[PASSTHRU:.+]]: vector<5xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<4xi8>
+// CHECK: %[[COMP_MASK:.+]] = affine.apply #map()[%[[MASK]]]
+// CHECK: vector.create_mask %[[COMP_MASK]] : vector<2xi1>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: vector.maskedload %[[ALLOC]][%[[C1]]]
+
+// -----
+
+// This tests the correctness of generated compressed mask with `vector.create_mask`, and a static input.
+// Quite the same as the previous test, but the mask generator is a static value.
+// In this case, the desired slice `vector<7xi2>` spans over 3 bytes.
+func.func @check_unaligned_create_mask_static_i2(%passthru: vector<7xi2>) -> vector<7xi2> {
+    %0 = memref.alloc() : memref<3x7xi2>
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c3 = arith.constant 3 : index
+    %mask = vector.create_mask %c3 : vector<7xi1>
+    %1 = vector.maskedload %0[%c1, %c0], %mask, %passthru :
+      memref<3x7xi2>, vector<7xi1>, vector<7xi2> into vector<7xi2>
+    return %1 : vector<7xi2>
+}
+
+// CHECK: func @check_unaligned_create_mask_static_i2(
+// CHECK-SAME:     %[[PASSTHRU:[a-zA-Z0-9]+]]: vector<7xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[COMP_MASK:.+]] = vector.create_mask %[[C2]] : vector<3xi1>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %4 = vector.maskedload %[[ALLOC]][%[[C1]]], %[[COMP_MASK]]
+
+// -----
+
 func.func @vector_load_i2_dynamic_indexing(%idx1: index, %idx2: index) -> vector<3xi2> {
   %0 = memref.alloc() : memref<3x3xi2>
   %cst = arith.constant dense<0> : vector<3x3xi2>
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index 034bd47f6163e6..c68909061d8f3c 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -141,7 +141,7 @@ func.func @vector_maskedload_i8(%arg1: index, %arg2: index, %arg3: index, %passt
 // CHECK-NEXT:   return
 
 //  CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 4)>
-//  CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> ((s0 + 3) floordiv 4)>
+//  CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
 //      CHECK32: func @vector_maskedload_i8(
 // CHECK32-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
 // CHECK32-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: vector<4xi8>)
@@ -169,7 +169,7 @@ func.func @vector_maskedload_i4(%arg1: index, %arg2: index, %arg3: index, %passt
     return %2 : vector<3x8xi4>
 }
 //  CHECK-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
-//  CHECK-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> ((s0 + 1) floordiv 2)>
+//  CHECK-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
 //      CHECK: func @vector_maskedload_i4(
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: vector<8xi4>)
@@ -185,7 +185,7 @@ func.func @vector_maskedload_i4(%arg1: index, %arg2: index, %arg3: index, %passt
 //      CHECK:   %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG3]] : vector<8xi1>, vector<8xi4>
 
 //  CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
-//  CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> ((s0 + 7) floordiv 8)>
+//  CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
 //      CHECK32: func @vector_maskedload_i4(
 // CHECK32-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
 // CHECK32-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: vector<8xi4>)
@@ -473,7 +473,7 @@ func.func @vector_maskedstore_i8(%arg0: index, %arg1: index, %arg2: index, %valu
 // CHECK-NEXT:   return
 
 // CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 2 + s1 floordiv 4)>
-// CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> ((s0 + 3) floordiv 4)>
+// CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
 // CHECK32:     func @vector_maskedstore_i8(
 // CHECK32-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK32-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
@@ -506,7 +506,7 @@ func.func @vector_maskedstore_i4(
     return
 }
 // CHECK: #[[$ATTR_10:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
-// CHECK: #[[$ATTR_11:.+]] = affine_map<()[s0] -> ((s0 + 1) floordiv 2)>
+// CHECK: #[[$ATTR_11:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
 
 // CHECK-LABEL:   func.func @vector_maskedstore_i4(
 // CHECK-SAME:      %[[IDX_1:[a-zA-Z0-9]+]]: index,
@@ -526,7 +526,7 @@ func.func @vector_maskedstore_i4(
 // CHECK:           vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<12xi8>, vector<4xi1>, vector<4xi8>
 
 // CHECK32: #[[$ATTR_17:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
-// CHECK32: #[[$ATTR_18:.+]] = affine_map<()[s0] -> ((s0 + 7) floordiv 8)>
+// CHECK32: #[[$ATTR_18:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
 
 // CHECK32-LABEL:   func.func @vector_maskedstore_i4(
 // CHECK32-SAME:      %[[IDX_1:[a-zA-Z0-9]+]]: index,

@hanhanW hanhanW requested a review from banach-space November 14, 2024 00:27
Contributor

@dcaballe dcaballe left a comment


LG, thanks! Please give some time for others to review.

Contributor

@banach-space banach-space left a comment


The updated logic makes sense, thanks! A few small suggestions.

@lialan lialan force-pushed the lialan/fix_create_mask branch from b129fc0 to f251be0 Compare November 14, 2024 23:03
@lialan lialan requested a review from banach-space November 14, 2024 23:07
Contributor

@hanhanW hanhanW left a comment


Thanks Alan, it looks good to me. I'll help land the PR once we get an approval from @banach-space

@lialan lialan force-pushed the lialan/fix_create_mask branch from f251be0 to d960ae8 Compare November 19, 2024 07:14
Contributor

@banach-space banach-space left a comment


LGTM % a couple of nits

Previously, when `numFrontPadElems` is not zero, `getCompressedMaskOp` produces
an incorrect result if the mask generator op is `vector.create_mask`. This patch
resolves the issue when `numFrontPadElems` is not zero.

Signed-off-by: Alan Li <[email protected]>
@lialan lialan force-pushed the lialan/fix_create_mask branch from d960ae8 to 9e6310a Compare November 19, 2024 17:03
@hanhanW hanhanW merged commit f981ee7 into llvm:main Nov 20, 2024
8 checks passed
@llvm-ci
Collaborator

llvm-ci commented Nov 20, 2024

LLVM Buildbot has detected a new failure on builder clang-arm64-windows-msvc running on linaro-armv8-windows-msvc-04 while building mlir at step 5 "ninja check 1".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/161/builds/3285

Here is the relevant piece of the build log for reference:
Step 5 (ninja check 1) failure: stage 1 checked (failure)
******************** TEST 'Clang-Unit :: DirectoryWatcher/./DirectoryWatcherTests.exe/5/8' FAILED ********************
Script(shard):
--
GTEST_OUTPUT=json:C:\Users\tcwg\llvm-worker\clang-arm64-windows-msvc\stage1\tools\clang\unittests\DirectoryWatcher\.\DirectoryWatcherTests.exe-Clang-Unit-28880-5-8.json GTEST_SHUFFLE=0 GTEST_TOTAL_SHARDS=8 GTEST_SHARD_INDEX=5 C:\Users\tcwg\llvm-worker\clang-arm64-windows-msvc\stage1\tools\clang\unittests\DirectoryWatcher\.\DirectoryWatcherTests.exe
--

Script:
--
C:\Users\tcwg\llvm-worker\clang-arm64-windows-msvc\stage1\tools\clang\unittests\DirectoryWatcher\.\DirectoryWatcherTests.exe --gtest_filter=DirectoryWatcherTest.DeleteWatchedDir
--
C:/Users/tcwg/llvm-worker/clang-arm64-windows-msvc/llvm/clang/unittests/DirectoryWatcher/DirectoryWatcherTest.cpp(259): error: Value of: WaitForExpectedStateResult.wait_for(EventualResultTimeout) == std::future_status::ready
  Actual: false
Expected: true
The expected result state wasn't reached before the time-out.

C:/Users/tcwg/llvm-worker/clang-arm64-windows-msvc/llvm/clang/unittests/DirectoryWatcher/DirectoryWatcherTest.cpp(262): error: Value of: TestConsumer.result().has_value()
  Actual: false
Expected: true


C:/Users/tcwg/llvm-worker/clang-arm64-windows-msvc/llvm/clang/unittests/DirectoryWatcher/DirectoryWatcherTest.cpp:259
Value of: WaitForExpectedStateResult.wait_for(EventualResultTimeout) == std::future_status::ready
  Actual: false
Expected: true
The expected result state wasn't reached before the time-out.

C:/Users/tcwg/llvm-worker/clang-arm64-windows-msvc/llvm/clang/unittests/DirectoryWatcher/DirectoryWatcherTest.cpp:262
Value of: TestConsumer.result().has_value()
  Actual: false
Expected: true



********************

