Skip to content

Commit c6eef00

Browse files
authored
[mlir][Vector] Add vector.shuffle fold for poison inputs (#125608)
#124863 added folding support for poison indices to `vector.shuffle`. This PR adds support for folding `vector.shuffle` ops with one or two poison input vectors.
1 parent 6422882 commit c6eef00

File tree

2 files changed

+74
-14
lines changed

2 files changed

+74
-14
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
#include "mlir/IR/AffineMap.h"
2727
#include "mlir/IR/Builders.h"
2828
#include "mlir/IR/BuiltinAttributes.h"
29-
#include "mlir/IR/BuiltinOps.h"
3029
#include "mlir/IR/BuiltinTypes.h"
3130
#include "mlir/IR/DialectImplementation.h"
3231
#include "mlir/IR/IRMapping.h"
@@ -42,7 +41,6 @@
4241
#include "llvm/ADT/SmallVector.h"
4342
#include "llvm/ADT/StringSet.h"
4443
#include "llvm/ADT/TypeSwitch.h"
45-
#include "llvm/ADT/bit.h"
4644

4745
#include <cassert>
4846
#include <cstdint>
@@ -2684,25 +2682,45 @@ OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
26842682
if (!v1Attr || !v2Attr)
26852683
return {};
26862684

2685+
// Fold shuffle poison, poison -> poison.
2686+
bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
2687+
bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
2688+
if (isV1Poison && isV2Poison)
2689+
return ub::PoisonAttr::get(getContext());
2690+
26872691
// Only support 1-D for now to avoid complicated n-D DenseElementsAttr
26882692
// manipulation.
26892693
if (v1Type.getRank() != 1)
26902694
return {};
26912695

2692-
int64_t v1Size = v1Type.getDimSize(0);
2696+
// Poison input attributes need special handling as they are not
2697+
// DenseElementsAttr. If an index is poison, we select the first element of
2698+
// the first non-poison input.
2699+
SmallVector<Attribute> v1Elements, v2Elements;
2700+
Attribute poisonElement;
2701+
if (!isV2Poison) {
2702+
v2Elements =
2703+
to_vector(cast<DenseElementsAttr>(v2Attr).getValues<Attribute>());
2704+
poisonElement = v2Elements[0];
2705+
}
2706+
if (!isV1Poison) {
2707+
v1Elements =
2708+
to_vector(cast<DenseElementsAttr>(v1Attr).getValues<Attribute>());
2709+
poisonElement = v1Elements[0];
2710+
}
26932711

26942712
SmallVector<Attribute> results;
2695-
auto v1Elements = cast<DenseElementsAttr>(v1Attr).getValues<Attribute>();
2696-
auto v2Elements = cast<DenseElementsAttr>(v2Attr).getValues<Attribute>();
2713+
int64_t v1Size = v1Type.getDimSize(0);
26972714
for (int64_t maskIdx : mask) {
26982715
Attribute indexedElm;
2699-
// Select v1[0] for poison indices.
27002716
// TODO: Return a partial poison vector when supported by the UB dialect.
27012717
if (maskIdx == ShuffleOp::kPoisonIndex) {
2702-
indexedElm = v1Elements[0];
2718+
indexedElm = poisonElement;
27032719
} else {
2704-
indexedElm =
2705-
maskIdx < v1Size ? v1Elements[maskIdx] : v2Elements[maskIdx - v1Size];
2720+
if (maskIdx < v1Size)
2721+
indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
2722+
else
2723+
indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
27062724
}
27072725

27082726
results.push_back(indexedElm);
@@ -3319,13 +3337,15 @@ class InsertStridedSliceConstantFolder final
33193337
!destVector.hasOneUse())
33203338
return failure();
33213339

3322-
auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
3323-
33243340
TypedValue<VectorType> sourceValue = op.getSource();
33253341
Attribute sourceCst;
33263342
if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
33273343
return failure();
33283344

3345+
// TODO: Support poison.
3346+
if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
3347+
return failure();
3348+
33293349
// TODO: Handle non-unit strides when they become available.
33303350
if (op.hasNonUnitStrides())
33313351
return failure();
@@ -3342,6 +3362,7 @@ class InsertStridedSliceConstantFolder final
33423362
// increasing linearized position indices.
33433363
// Because the destination may have higher dimensionality then the slice,
33443364
// we keep track of two overlapping sets of positions and offsets.
3365+
auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
33453366
auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
33463367
auto sliceValuesIt = denseSlice.value_begin<Attribute>();
33473368
auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2012,17 +2012,56 @@ func.func @shuffle_1d() -> vector<4xi32> {
20122012
// input vector. That is, %v[0] (i.e., 5) in this test.
20132013

20142014
// CHECK-LABEL: func @shuffle_1d_poison_idx
2015-
// CHECK: %[[V:.+]] = arith.constant dense<[2, 5, 0, 5]> : vector<4xi32>
2015+
// CHECK: %[[V:.+]] = arith.constant dense<[13, 10, 15, 10]> : vector<4xi32>
20162016
// CHECK: return %[[V]]
20172017
func.func @shuffle_1d_poison_idx() -> vector<4xi32> {
2018-
%v0 = arith.constant dense<[5, 4, 3]> : vector<3xi32>
2019-
%v1 = arith.constant dense<[2, 1, 0]> : vector<3xi32>
2018+
%v0 = arith.constant dense<[10, 11, 12]> : vector<3xi32>
2019+
%v1 = arith.constant dense<[13, 14, 15]> : vector<3xi32>
20202020
%shuffle = vector.shuffle %v0, %v1 [3, -1, 5, -1] : vector<3xi32>, vector<3xi32>
20212021
return %shuffle : vector<4xi32>
20222022
}
20232023

20242024
// -----
20252025

2026+
// CHECK-LABEL: func @shuffle_1d_rhs_lhs_poison
2027+
// CHECK-NOT: vector.shuffle
2028+
// CHECK: %[[V:.+]] = ub.poison : vector<4xi32>
2029+
// CHECK: return %[[V]]
2030+
func.func @shuffle_1d_rhs_lhs_poison() -> vector<4xi32> {
2031+
%v0 = ub.poison : vector<3xi32>
2032+
%v1 = ub.poison : vector<3xi32>
2033+
%shuffle = vector.shuffle %v0, %v1 [3, 1, 5, 4] : vector<3xi32>, vector<3xi32>
2034+
return %shuffle : vector<4xi32>
2035+
}
2036+
2037+
// -----
2038+
2039+
// CHECK-LABEL: func @shuffle_1d_lhs_poison
2040+
// CHECK-NOT: vector.shuffle
2041+
// CHECK: %[[V:.+]] = arith.constant dense<[11, 12, 11, 11]> : vector<4xi32>
2042+
// CHECK: return %[[V]]
2043+
func.func @shuffle_1d_lhs_poison() -> vector<4xi32> {
2044+
%v0 = arith.constant dense<[11, 12, 13]> : vector<3xi32>
2045+
%v1 = ub.poison : vector<3xi32>
2046+
%shuffle = vector.shuffle %v0, %v1 [3, 1, 5, 4] : vector<3xi32>, vector<3xi32>
2047+
return %shuffle : vector<4xi32>
2048+
}
2049+
2050+
// -----
2051+
2052+
// CHECK-LABEL: func @shuffle_1d_rhs_poison
2053+
// CHECK-NOT: vector.shuffle
2054+
// CHECK: %[[V:.+]] = arith.constant dense<[11, 11, 13, 12]> : vector<4xi32>
2055+
// CHECK: return %[[V]]
2056+
func.func @shuffle_1d_rhs_poison() -> vector<4xi32> {
2057+
%v0 = ub.poison : vector<3xi32>
2058+
%v1 = arith.constant dense<[11, 12, 13]> : vector<3xi32>
2059+
%shuffle = vector.shuffle %v0, %v1 [3, 1, 5, 4] : vector<3xi32>, vector<3xi32>
2060+
return %shuffle : vector<4xi32>
2061+
}
2062+
2063+
// -----
2064+
20262065
// CHECK-LABEL: func @shuffle_canonicalize_0d
20272066
func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
20282067
// CHECK: vector.broadcast %{{.*}} : vector<i32> to vector<1xi32>

0 commit comments

Comments
 (0)