-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][Vector] Fix `vector.shuffle` folder for poison indices
#124863
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Diego Caballero (dcaballe) Changes: This PR fixes the folder of a `vector.shuffle` with constant input vectors in the presence of a poison index. Full diff: https://github.com/llvm/llvm-project/pull/124863.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b35422f4ca3a9f..08f75944d23086 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2673,43 +2673,51 @@ static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
}
OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
- VectorType v1Type = getV1VectorType();
+ auto v1Type = getV1VectorType();
+ auto v2Type = getV2VectorType();
+
+ assert(!v1Type.isScalable() && !v2Type.isScalable() &&
+ "Vector shuffle does not support scalable vectors");
+
// For consistency: 0-D shuffle return type is 1-D, this cannot be a folding
// but must be a canonicalization into a vector.broadcast.
if (v1Type.getRank() == 0)
return {};
- // fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1
- if (!v1Type.isScalable() &&
- isStepIndexArray(getMask(), 0, v1Type.getDimSize(0)))
+ // Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
+ if (isStepIndexArray(getMask(), 0, v1Type.getDimSize(0)))
return getV1();
- // fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2
- if (!getV1VectorType().isScalable() && !getV2VectorType().isScalable() &&
- isStepIndexArray(getMask(), getV1VectorType().getDimSize(0),
- getV2VectorType().getDimSize(0)))
+ // Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
+ if (isStepIndexArray(getMask(), v1Type.getDimSize(0), v2Type.getDimSize(0)))
return getV2();
- Attribute lhs = adaptor.getV1(), rhs = adaptor.getV2();
- if (!lhs || !rhs)
+ Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
+ if (!v1Attr || !v2Attr)
return {};
- auto lhsType =
- llvm::cast<VectorType>(llvm::cast<DenseElementsAttr>(lhs).getType());
// Only support 1-D for now to avoid complicated n-D DenseElementsAttr
// manipulation.
- if (lhsType.getRank() != 1)
+ if (v1Type.getRank() != 1)
return {};
- int64_t lhsSize = lhsType.getDimSize(0);
+
+ int64_t v1Size = v1Type.getDimSize(0);
SmallVector<Attribute> results;
- auto lhsElements = llvm::cast<DenseElementsAttr>(lhs).getValues<Attribute>();
- auto rhsElements = llvm::cast<DenseElementsAttr>(rhs).getValues<Attribute>();
- for (int64_t i : this->getMask()) {
- if (i >= lhsSize) {
- results.push_back(rhsElements[i - lhsSize]);
- } else {
- results.push_back(lhsElements[i]);
+ auto v1Elements = cast<DenseElementsAttr>(v1Attr).getValues<Attribute>();
+ auto v2Elements = cast<DenseElementsAttr>(v2Attr).getValues<Attribute>();
+ for (int64_t maskIdx : this->getMask()) {
+ Attribute indexedElm;
+ // Select v1[0] for poison indices.
+ // TODO: Return a partial poison vector when supported by the UB dialect.
+ if (maskIdx == ShuffleOp::kPoisonIndex) {
+ indexedElm = v1Elements[0];
+ }
+ else {
+ indexedElm =
+ maskIdx < v1Size ? v1Elements[maskIdx] : v2Elements[maskIdx - v1Size];
}
+
+ results.push_back(indexedElm);
}
return DenseElementsAttr::get(getResultVectorType(), results);
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index f9e3b772f9f0a2..070135828de901 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2006,6 +2006,20 @@ func.func @shuffle_1d() -> vector<4xi32> {
return %shuffle : vector<4xi32>
}
+// -----
+
+// CHECK-LABEL: func @shuffle_1d_poison_idx
+// CHECK: %[[V:.+]] = arith.constant dense<[2, 5, 0, 5]> : vector<4xi32>
+// CHECK: return %[[V]]
+func.func @shuffle_1d_poison_idx() -> vector<4xi32> {
+ %v0 = arith.constant dense<[5, 4, 3]> : vector<3xi32>
+ %v1 = arith.constant dense<[2, 1, 0]> : vector<3xi32>
+ %shuffle = vector.shuffle %v0, %v1 [3, -1, 5, -1] : vector<3xi32>, vector<3xi32>
+ return %shuffle : vector<4xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @shuffle_canonicalize_0d
func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
// CHECK: vector.broadcast %{{.*}} : vector<i32> to vector<1xi32>
@@ -2013,6 +2027,8 @@ func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vect
return %shuffle : vector<1xi32>
}
+// -----
+
// CHECK-LABEL: func @shuffle_fold1
// CHECK: %arg0 : vector<4xi32>
func.func @shuffle_fold1(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<4xi32> {
@@ -2020,6 +2036,8 @@ func.func @shuffle_fold1(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<4xi
return %shuffle : vector<4xi32>
}
+// -----
+
// CHECK-LABEL: func @shuffle_fold2
// CHECK: %arg1 : vector<2xi32>
func.func @shuffle_fold2(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<2xi32> {
@@ -2027,6 +2045,8 @@ func.func @shuffle_fold2(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<2xi
return %shuffle : vector<2xi32>
}
+// -----
+
// CHECK-LABEL: func @shuffle_fold3
// CHECK: return %arg0 : vector<4x5x6xi32>
func.func @shuffle_fold3(%v0 : vector<4x5x6xi32>, %v1 : vector<2x5x6xi32>) -> vector<4x5x6xi32> {
@@ -2034,6 +2054,8 @@ func.func @shuffle_fold3(%v0 : vector<4x5x6xi32>, %v1 : vector<2x5x6xi32>) -> ve
return %shuffle : vector<4x5x6xi32>
}
+// -----
+
// CHECK-LABEL: func @shuffle_fold4
// CHECK: return %arg1 : vector<2x5x6xi32>
func.func @shuffle_fold4(%v0 : vector<4x5x6xi32>, %v1 : vector<2x5x6xi32>) -> vector<2x5x6xi32> {
@@ -2041,6 +2063,8 @@ func.func @shuffle_fold4(%v0 : vector<4x5x6xi32>, %v1 : vector<2x5x6xi32>) -> ve
return %shuffle : vector<2x5x6xi32>
}
+// -----
+
// CHECK-LABEL: func @shuffle_nofold1
// CHECK: %[[V:.+]] = vector.shuffle %arg0, %arg1 [0, 1, 2, 3, 4] : vector<4xi32>, vector<2xi32>
// CHECK: return %[[V]]
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
if (!v1Type.isScalable() && | ||
isStepIndexArray(getMask(), 0, v1Type.getDimSize(0))) | ||
// Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1. | ||
if (isStepIndexArray(getMask(), 0, v1Type.getDimSize(0))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: hoist mask and dim sizes into local variables, since these are queried multiple times in this function?
b959c33
to
b3cd0f7
Compare
@@ -2006,41 +2006,65 @@ func.func @shuffle_1d() -> vector<4xi32> { | |||
return %shuffle : vector<4xi32> | |||
} | |||
|
|||
// ----- | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nit] Could you add a comment explaining what happens for poison idxs? It's quite non-obvious 😅🙏🏻
This PR fixes the folder of a `vector.shuffle` with constant input vectors in the presence of a poison index. Partially poison vectors are currently not supported in the UB dialect, so the folder selects v1[0] for elements indexed by poison.
b3cd0f7
to
5df26fc
Compare
#124863 added folding support for poison indices to `vector.shuffle`. This PR adds support for folding `vector.shuffle` ops with one or two poison input vectors.
… (#125608) llvm/llvm-project#124863 added folding support for poison indices to `vector.shuffle`. This PR adds support for folding `vector.shuffle` ops with one or two poison input vectors.
llvm#124863 added folding support for poison indices to `vector.shuffle`. This PR adds support for folding `vector.shuffle` ops with one or two poison input vectors.
This PR fixes the folder of a `vector.shuffle` with constant input vectors in the presence of a poison index. Partially poison vectors are currently not supported in the UB dialect, so the folder selects v1[0] for elements indexed by poison.