Skip to content

Commit c3c3262

Browse files
authored
[mlir][Vector] Fix vector.shuffle folder for poison indices (#124863)
This PR fixes the folder of a `vector.shuffle` with constant input vectors in the presence of a poison index. Partially poison vectors are currently not supported by the UB dialect, so the folder selects v1[0] for elements indexed by poison.
1 parent db7e2e5 commit c3c3262

File tree

2 files changed

+55
-20
lines changed

2 files changed

+55
-20
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2673,43 +2673,51 @@ static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
26732673
}
26742674

26752675
OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
2676-
VectorType v1Type = getV1VectorType();
2676+
auto v1Type = getV1VectorType();
2677+
auto v2Type = getV2VectorType();
2678+
2679+
assert(!v1Type.isScalable() && !v2Type.isScalable() &&
2680+
"Vector shuffle does not support scalable vectors");
2681+
26772682
// For consistency: 0-D shuffle return type is 1-D, this cannot be a folding
26782683
// but must be a canonicalization into a vector.broadcast.
26792684
if (v1Type.getRank() == 0)
26802685
return {};
26812686

2682-
// fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1
2683-
if (!v1Type.isScalable() &&
2684-
isStepIndexArray(getMask(), 0, v1Type.getDimSize(0)))
2687+
// Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
2688+
auto mask = getMask();
2689+
if (isStepIndexArray(mask, 0, v1Type.getDimSize(0)))
26852690
return getV1();
2686-
// fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2
2687-
if (!getV1VectorType().isScalable() && !getV2VectorType().isScalable() &&
2688-
isStepIndexArray(getMask(), getV1VectorType().getDimSize(0),
2689-
getV2VectorType().getDimSize(0)))
2691+
// Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
2692+
if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0)))
26902693
return getV2();
26912694

2692-
Attribute lhs = adaptor.getV1(), rhs = adaptor.getV2();
2693-
if (!lhs || !rhs)
2695+
Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
2696+
if (!v1Attr || !v2Attr)
26942697
return {};
26952698

2696-
auto lhsType =
2697-
llvm::cast<VectorType>(llvm::cast<DenseElementsAttr>(lhs).getType());
26982699
// Only support 1-D for now to avoid complicated n-D DenseElementsAttr
26992700
// manipulation.
2700-
if (lhsType.getRank() != 1)
2701+
if (v1Type.getRank() != 1)
27012702
return {};
2702-
int64_t lhsSize = lhsType.getDimSize(0);
2703+
2704+
int64_t v1Size = v1Type.getDimSize(0);
27032705

27042706
SmallVector<Attribute> results;
2705-
auto lhsElements = llvm::cast<DenseElementsAttr>(lhs).getValues<Attribute>();
2706-
auto rhsElements = llvm::cast<DenseElementsAttr>(rhs).getValues<Attribute>();
2707-
for (int64_t i : this->getMask()) {
2708-
if (i >= lhsSize) {
2709-
results.push_back(rhsElements[i - lhsSize]);
2707+
auto v1Elements = cast<DenseElementsAttr>(v1Attr).getValues<Attribute>();
2708+
auto v2Elements = cast<DenseElementsAttr>(v2Attr).getValues<Attribute>();
2709+
for (int64_t maskIdx : mask) {
2710+
Attribute indexedElm;
2711+
// Select v1[0] for poison indices.
2712+
// TODO: Return a partial poison vector when supported by the UB dialect.
2713+
if (maskIdx == ShuffleOp::kPoisonIndex) {
2714+
indexedElm = v1Elements[0];
27102715
} else {
2711-
results.push_back(lhsElements[i]);
2716+
indexedElm =
2717+
maskIdx < v1Size ? v1Elements[maskIdx] : v2Elements[maskIdx - v1Size];
27122718
}
2719+
2720+
results.push_back(indexedElm);
27132721
}
27142722

27152723
return DenseElementsAttr::get(getResultVectorType(), results);

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2006,41 +2006,68 @@ func.func @shuffle_1d() -> vector<4xi32> {
20062006
return %shuffle : vector<4xi32>
20072007
}
20082008

2009+
// -----
2010+
2011+
// Check that poison indices pick the first element of the first non-poison
2012+
// input vector. That is, %v0[0] (i.e., 5) in this test.
2013+
2014+
// CHECK-LABEL: func @shuffle_1d_poison_idx
2015+
// CHECK: %[[V:.+]] = arith.constant dense<[2, 5, 0, 5]> : vector<4xi32>
2016+
// CHECK: return %[[V]]
2017+
func.func @shuffle_1d_poison_idx() -> vector<4xi32> {
2018+
%v0 = arith.constant dense<[5, 4, 3]> : vector<3xi32>
2019+
%v1 = arith.constant dense<[2, 1, 0]> : vector<3xi32>
2020+
%shuffle = vector.shuffle %v0, %v1 [3, -1, 5, -1] : vector<3xi32>, vector<3xi32>
2021+
return %shuffle : vector<4xi32>
2022+
}
2023+
2024+
// -----
2025+
20092026
// CHECK-LABEL: func @shuffle_canonicalize_0d
20102027
func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
20112028
// CHECK: vector.broadcast %{{.*}} : vector<i32> to vector<1xi32>
20122029
%shuffle = vector.shuffle %v0, %v1 [0] : vector<i32>, vector<i32>
20132030
return %shuffle : vector<1xi32>
20142031
}
20152032

2033+
// -----
2034+
20162035
// CHECK-LABEL: func @shuffle_fold1
20172036
// CHECK: %arg0 : vector<4xi32>
20182037
func.func @shuffle_fold1(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<4xi32> {
20192038
%shuffle = vector.shuffle %v0, %v1 [0, 1, 2, 3] : vector<4xi32>, vector<2xi32>
20202039
return %shuffle : vector<4xi32>
20212040
}
20222041

2042+
// -----
2043+
20232044
// CHECK-LABEL: func @shuffle_fold2
20242045
// CHECK: %arg1 : vector<2xi32>
20252046
func.func @shuffle_fold2(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<2xi32> {
20262047
%shuffle = vector.shuffle %v0, %v1 [4, 5] : vector<4xi32>, vector<2xi32>
20272048
return %shuffle : vector<2xi32>
20282049
}
20292050

2051+
// -----
2052+
20302053
// CHECK-LABEL: func @shuffle_fold3
20312054
// CHECK: return %arg0 : vector<4x5x6xi32>
20322055
func.func @shuffle_fold3(%v0 : vector<4x5x6xi32>, %v1 : vector<2x5x6xi32>) -> vector<4x5x6xi32> {
20332056
%shuffle = vector.shuffle %v0, %v1 [0, 1, 2, 3] : vector<4x5x6xi32>, vector<2x5x6xi32>
20342057
return %shuffle : vector<4x5x6xi32>
20352058
}
20362059

2060+
// -----
2061+
20372062
// CHECK-LABEL: func @shuffle_fold4
20382063
// CHECK: return %arg1 : vector<2x5x6xi32>
20392064
func.func @shuffle_fold4(%v0 : vector<4x5x6xi32>, %v1 : vector<2x5x6xi32>) -> vector<2x5x6xi32> {
20402065
%shuffle = vector.shuffle %v0, %v1 [4, 5] : vector<4x5x6xi32>, vector<2x5x6xi32>
20412066
return %shuffle : vector<2x5x6xi32>
20422067
}
20432068

2069+
// -----
2070+
20442071
// CHECK-LABEL: func @shuffle_nofold1
20452072
// CHECK: %[[V:.+]] = vector.shuffle %arg0, %arg1 [0, 1, 2, 3, 4] : vector<4xi32>, vector<2xi32>
20462073
// CHECK: return %[[V]]

0 commit comments

Comments
 (0)