Skip to content

[mlir][Vector] Fix vector.shuffle folder for poison indices #124863

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 31, 2025

Conversation

dcaballe
Copy link
Contributor

This PR fixes the folder of a vector.shuffle with constant input vectors in the presence of a poison index. Partially poison vectors are currently not supported in UB so the folder select v1[0] for elements indexed by poison.

@llvmbot
Copy link
Member

llvmbot commented Jan 29, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Diego Caballero (dcaballe)

Changes

This PR fixes the folder of a vector.shuffle with constant input vectors in the presence of a poison index. Partially poison vectors are currently not supported in UB so the folder select v1[0] for elements indexed by poison.


Full diff: https://github.com/llvm/llvm-project/pull/124863.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+29-21)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+24)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b35422f4ca3a9f..08f75944d23086 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2673,43 +2673,51 @@ static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
 }
 
 OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
-  VectorType v1Type = getV1VectorType();
+  auto v1Type = getV1VectorType();
+  auto v2Type = getV2VectorType();
+
+  assert(!v1Type.isScalable() && !v2Type.isScalable() &&
+         "Vector shuffle does not support scalable vectors");
+
   // For consistency: 0-D shuffle return type is 1-D, this cannot be a folding
   // but must be a canonicalization into a vector.broadcast.
   if (v1Type.getRank() == 0)
     return {};
 
-  // fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1
-  if (!v1Type.isScalable() &&
-      isStepIndexArray(getMask(), 0, v1Type.getDimSize(0)))
+  // Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
+  if (isStepIndexArray(getMask(), 0, v1Type.getDimSize(0)))
     return getV1();
-  // fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2
-  if (!getV1VectorType().isScalable() && !getV2VectorType().isScalable() &&
-      isStepIndexArray(getMask(), getV1VectorType().getDimSize(0),
-                       getV2VectorType().getDimSize(0)))
+  // Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
+  if (isStepIndexArray(getMask(), v1Type.getDimSize(0), v2Type.getDimSize(0)))
     return getV2();
 
-  Attribute lhs = adaptor.getV1(), rhs = adaptor.getV2();
-  if (!lhs || !rhs)
+  Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
+  if (!v1Attr || !v2Attr)
     return {};
 
-  auto lhsType =
-      llvm::cast<VectorType>(llvm::cast<DenseElementsAttr>(lhs).getType());
   // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
   // manipulation.
-  if (lhsType.getRank() != 1)
+  if (v1Type.getRank() != 1)
     return {};
-  int64_t lhsSize = lhsType.getDimSize(0);
+
+  int64_t v1Size = v1Type.getDimSize(0);
 
   SmallVector<Attribute> results;
-  auto lhsElements = llvm::cast<DenseElementsAttr>(lhs).getValues<Attribute>();
-  auto rhsElements = llvm::cast<DenseElementsAttr>(rhs).getValues<Attribute>();
-  for (int64_t i : this->getMask()) {
-    if (i >= lhsSize) {
-      results.push_back(rhsElements[i - lhsSize]);
-    } else {
-      results.push_back(lhsElements[i]);
+  auto v1Elements = cast<DenseElementsAttr>(v1Attr).getValues<Attribute>();
+  auto v2Elements = cast<DenseElementsAttr>(v2Attr).getValues<Attribute>();
+  for (int64_t maskIdx : this->getMask()) {
+    Attribute indexedElm;
+    // Select v1[0] for poison indices.
+    // TODO: Return a partial poison vector when supported by the UB dialect.
+    if (maskIdx == ShuffleOp::kPoisonIndex) {
+      indexedElm = v1Elements[0];
+    }
+    else {
+      indexedElm =
+          maskIdx < v1Size ? v1Elements[maskIdx] : v2Elements[maskIdx - v1Size];
     }
+
+    results.push_back(indexedElm);
   }
 
   return DenseElementsAttr::get(getResultVectorType(), results);
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index f9e3b772f9f0a2..070135828de901 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2006,6 +2006,20 @@ func.func @shuffle_1d() -> vector<4xi32> {
   return %shuffle : vector<4xi32>
 }
 
+// -----
+
+// CHECK-LABEL: func @shuffle_1d_poison_idx
+//       CHECK:   %[[V:.+]] = arith.constant dense<[2, 5, 0, 5]> : vector<4xi32>
+//       CHECK:   return %[[V]]
+func.func @shuffle_1d_poison_idx() -> vector<4xi32> {
+  %v0 = arith.constant dense<[5, 4, 3]> : vector<3xi32>
+  %v1 = arith.constant dense<[2, 1, 0]> : vector<3xi32>
+  %shuffle = vector.shuffle %v0, %v1 [3, -1, 5, -1] : vector<3xi32>, vector<3xi32>
+  return %shuffle : vector<4xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @shuffle_canonicalize_0d
 func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
   // CHECK: vector.broadcast %{{.*}} : vector<i32> to vector<1xi32>
@@ -2013,6 +2027,8 @@ func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vect
   return %shuffle : vector<1xi32>
 }
 
+// -----
+
 // CHECK-LABEL: func @shuffle_fold1
 //       CHECK:   %arg0 : vector<4xi32>
 func.func @shuffle_fold1(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<4xi32> {
@@ -2020,6 +2036,8 @@ func.func @shuffle_fold1(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<4xi
   return %shuffle : vector<4xi32>
 }
 
+// -----
+
 // CHECK-LABEL: func @shuffle_fold2
 //       CHECK:   %arg1 : vector<2xi32>
 func.func @shuffle_fold2(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<2xi32> {
@@ -2027,6 +2045,8 @@ func.func @shuffle_fold2(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<2xi
   return %shuffle : vector<2xi32>
 }
 
+// -----
+
 // CHECK-LABEL: func @shuffle_fold3
 //       CHECK:   return %arg0 : vector<4x5x6xi32>
 func.func @shuffle_fold3(%v0 : vector<4x5x6xi32>, %v1 : vector<2x5x6xi32>) -> vector<4x5x6xi32> {
@@ -2034,6 +2054,8 @@ func.func @shuffle_fold3(%v0 : vector<4x5x6xi32>, %v1 : vector<2x5x6xi32>) -> ve
   return %shuffle : vector<4x5x6xi32>
 }
 
+// -----
+
 // CHECK-LABEL: func @shuffle_fold4
 //       CHECK:   return %arg1 : vector<2x5x6xi32>
 func.func @shuffle_fold4(%v0 : vector<4x5x6xi32>, %v1 : vector<2x5x6xi32>) -> vector<2x5x6xi32> {
@@ -2041,6 +2063,8 @@ func.func @shuffle_fold4(%v0 : vector<4x5x6xi32>, %v1 : vector<2x5x6xi32>) -> ve
   return %shuffle : vector<2x5x6xi32>
 }
 
+// -----
+
 // CHECK-LABEL: func @shuffle_nofold1
 //       CHECK:   %[[V:.+]] = vector.shuffle %arg0, %arg1 [0, 1, 2, 3, 4] : vector<4xi32>, vector<2xi32>
 //       CHECK:   return %[[V]]

Copy link

github-actions bot commented Jan 29, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

if (!v1Type.isScalable() &&
isStepIndexArray(getMask(), 0, v1Type.getDimSize(0)))
// Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
if (isStepIndexArray(getMask(), 0, v1Type.getDimSize(0)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: hoist mask and dim sizes into local variables, since these are queried multiple times in this function?

@dcaballe dcaballe force-pushed the vector-shuffle-poison-fold branch from b959c33 to b3cd0f7 Compare January 31, 2025 00:21
@@ -2006,41 +2006,65 @@ func.func @shuffle_1d() -> vector<4xi32> {
return %shuffle : vector<4xi32>
}

// -----

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] Could you add a comment explaining what happens for poison idxs? It's quite non-obvious 😅🙏🏻

This PR fixes the folder of a `vector.shuffle` with constant input vectors
in the presence of a poison index. Partially poison vectors are currently
not supported in UB so the folder select v1[0] for elements indexed by poison.
@dcaballe dcaballe force-pushed the vector-shuffle-poison-fold branch from b3cd0f7 to 5df26fc Compare January 31, 2025 22:44
@dcaballe dcaballe merged commit c3c3262 into llvm:main Jan 31, 2025
6 of 7 checks passed
dcaballe added a commit that referenced this pull request Feb 5, 2025
#124863 added folding support
for poison indices to `vector.shuffle`. This PR adds support for folding
`vector.shuffle` ops with one or two poison input vectors.
github-actions bot pushed a commit to arm/arm-toolchain that referenced this pull request Feb 5, 2025
… (#125608)

llvm/llvm-project#124863 added folding support
for poison indices to `vector.shuffle`. This PR adds support for folding
`vector.shuffle` ops with one or two poison input vectors.
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
llvm#124863 added folding support
for poison indices to `vector.shuffle`. This PR adds support for folding
`vector.shuffle` ops with one or two poison input vectors.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants