Skip to content

[mlir][Vector] Fix vector.shuffle folder for poison indices #124863

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 28 additions & 20 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2673,43 +2673,51 @@ static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
}

OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
VectorType v1Type = getV1VectorType();
auto v1Type = getV1VectorType();
auto v2Type = getV2VectorType();

assert(!v1Type.isScalable() && !v2Type.isScalable() &&
"Vector shuffle does not support scalable vectors");

// For consistency: 0-D shuffle return type is 1-D, this cannot be a folding
// but must be a canonicalization into a vector.broadcast.
if (v1Type.getRank() == 0)
return {};

// fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1
if (!v1Type.isScalable() &&
isStepIndexArray(getMask(), 0, v1Type.getDimSize(0)))
// Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
auto mask = getMask();
if (isStepIndexArray(mask, 0, v1Type.getDimSize(0)))
return getV1();
// fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2
if (!getV1VectorType().isScalable() && !getV2VectorType().isScalable() &&
isStepIndexArray(getMask(), getV1VectorType().getDimSize(0),
getV2VectorType().getDimSize(0)))
// Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0)))
return getV2();

Attribute lhs = adaptor.getV1(), rhs = adaptor.getV2();
if (!lhs || !rhs)
Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
if (!v1Attr || !v2Attr)
return {};

auto lhsType =
llvm::cast<VectorType>(llvm::cast<DenseElementsAttr>(lhs).getType());
// Only support 1-D for now to avoid complicated n-D DenseElementsAttr
// manipulation.
if (lhsType.getRank() != 1)
if (v1Type.getRank() != 1)
return {};
int64_t lhsSize = lhsType.getDimSize(0);

int64_t v1Size = v1Type.getDimSize(0);

SmallVector<Attribute> results;
auto lhsElements = llvm::cast<DenseElementsAttr>(lhs).getValues<Attribute>();
auto rhsElements = llvm::cast<DenseElementsAttr>(rhs).getValues<Attribute>();
for (int64_t i : this->getMask()) {
if (i >= lhsSize) {
results.push_back(rhsElements[i - lhsSize]);
auto v1Elements = cast<DenseElementsAttr>(v1Attr).getValues<Attribute>();
auto v2Elements = cast<DenseElementsAttr>(v2Attr).getValues<Attribute>();
for (int64_t maskIdx : mask) {
Attribute indexedElm;
// Select v1[0] for poison indices.
// TODO: Return a partial poison vector when supported by the UB dialect.
if (maskIdx == ShuffleOp::kPoisonIndex) {
indexedElm = v1Elements[0];
} else {
results.push_back(lhsElements[i]);
indexedElm =
maskIdx < v1Size ? v1Elements[maskIdx] : v2Elements[maskIdx - v1Size];
}

results.push_back(indexedElm);
}

return DenseElementsAttr::get(getResultVectorType(), results);
Expand Down
27 changes: 27 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2006,41 +2006,68 @@ func.func @shuffle_1d() -> vector<4xi32> {
return %shuffle : vector<4xi32>
}

// -----

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] Could you add a comment explaining what happens for poison idxs? It's quite non-obvious 😅🙏🏻

// Check that poison mask indices fold to the first element of the first input
// vector, i.e. %v0[0] (the value 5) in this test. The -1 mask entries are the
// poison indices exercised here (presumably ShuffleOp::kPoisonIndex; confirm
// against the op definition).
//
// Mask [3, -1, 5, -1] over %v0 = [5, 4, 3] and %v1 = [2, 1, 0] selects:
//   3 -> %v1[0] = 2, -1 -> %v0[0] = 5, 5 -> %v1[2] = 0, -1 -> %v0[0] = 5.

// CHECK-LABEL: func @shuffle_1d_poison_idx
// CHECK: %[[V:.+]] = arith.constant dense<[2, 5, 0, 5]> : vector<4xi32>
// CHECK: return %[[V]]
func.func @shuffle_1d_poison_idx() -> vector<4xi32> {
%v0 = arith.constant dense<[5, 4, 3]> : vector<3xi32>
%v1 = arith.constant dense<[2, 1, 0]> : vector<3xi32>
%shuffle = vector.shuffle %v0, %v1 [3, -1, 5, -1] : vector<3xi32>, vector<3xi32>
return %shuffle : vector<4xi32>
}

// -----

// A shuffle of 0-D vectors produces a 1-D result, so it cannot be folded to
// one of its operands; the expected output is a vector.broadcast of a 0-D
// operand to vector<1xi32>.

// CHECK-LABEL: func @shuffle_canonicalize_0d
func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
// CHECK: vector.broadcast %{{.*}} : vector<i32> to vector<1xi32>
%shuffle = vector.shuffle %v0, %v1 [0] : vector<i32>, vector<i32>
return %shuffle : vector<1xi32>
}

// -----

// The mask [0, 1, 2, 3] selects every element of the first operand in order,
// so the shuffle is expected to fold away and the first argument is used
// directly.

// CHECK-LABEL: func @shuffle_fold1
// CHECK: %arg0 : vector<4xi32>
func.func @shuffle_fold1(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<4xi32> {
%shuffle = vector.shuffle %v0, %v1 [0, 1, 2, 3] : vector<4xi32>, vector<2xi32>
return %shuffle : vector<4xi32>
}

// -----

// The mask [4, 5] selects every element of the second operand in order (mask
// indices past the 4 elements of the first operand address the second one),
// so the shuffle is expected to fold away and the second argument is used
// directly.

// CHECK-LABEL: func @shuffle_fold2
// CHECK: %arg1 : vector<2xi32>
func.func @shuffle_fold2(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<2xi32> {
%shuffle = vector.shuffle %v0, %v1 [4, 5] : vector<4xi32>, vector<2xi32>
return %shuffle : vector<2xi32>
}

// -----

// Same as the 1-D first-operand case, but with 3-D operands: the mask
// [0, 1, 2, 3] covers the whole leading dimension of the first operand, so the
// first argument is expected to be returned directly.

// CHECK-LABEL: func @shuffle_fold3
// CHECK: return %arg0 : vector<4x5x6xi32>
func.func @shuffle_fold3(%v0 : vector<4x5x6xi32>, %v1 : vector<2x5x6xi32>) -> vector<4x5x6xi32> {
%shuffle = vector.shuffle %v0, %v1 [0, 1, 2, 3] : vector<4x5x6xi32>, vector<2x5x6xi32>
return %shuffle : vector<4x5x6xi32>
}

// -----

// Same as the 1-D second-operand case, but with 3-D operands: the mask [4, 5]
// covers the whole leading dimension of the second operand, so the second
// argument is expected to be returned directly.

// CHECK-LABEL: func @shuffle_fold4
// CHECK: return %arg1 : vector<2x5x6xi32>
func.func @shuffle_fold4(%v0 : vector<4x5x6xi32>, %v1 : vector<2x5x6xi32>) -> vector<2x5x6xi32> {
%shuffle = vector.shuffle %v0, %v1 [4, 5] : vector<4x5x6xi32>, vector<2x5x6xi32>
return %shuffle : vector<2x5x6xi32>
}

// -----

// CHECK-LABEL: func @shuffle_nofold1
// CHECK: %[[V:.+]] = vector.shuffle %arg0, %arg1 [0, 1, 2, 3, 4] : vector<4xi32>, vector<2xi32>
// CHECK: return %[[V]]
Expand Down