Skip to content

Commit b56bb3a

Browse files
committed
[InstCombine] Scalarize (vec_ops (insert ?, X, Idx)) when only one element is demanded
This came as a result of PR #84389. The SLP vectorizer can vectorize into a pattern like: ``` (blend (vec_ops0... (insert ?,X,0)), (vec_ops1... (insert ?,Y,1)) ) ``` In this case, `vec_ops0...` and `vec_ops1...` are essentially doing scalar transforms. We previously handled things like: `(blend (insert ?,X,0), (insert ?,Y,0))` This patch extends that to look through `vec_ops...` that can be scalarized, and if it's possible to scalarize all ops, it transforms the input to: ``` (blend (insert ?,(scalar_ops0... X), 0), (insert ?,(scalar_ops1... Y), 0) ) ```
1 parent a489bf4 commit b56bb3a

11 files changed

+219
-145
lines changed

llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp

Lines changed: 123 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2579,6 +2579,97 @@ static Instruction *foldIdentityExtractShuffle(ShuffleVectorInst &Shuf) {
25792579
return new ShuffleVectorInst(X, Y, NewMask);
25802580
}
25812581

2582+
// Extract `(scalar_ops... x)` from `(vector_ops... (insert ?, x, C)`
2583+
static Value *
2584+
getScalarizationOfInsertElement(Value *V, int ReqIndexC,
2585+
InstCombiner::BuilderTy &Builder) {
2586+
Value *X, *Base;
2587+
ConstantInt *IndexC;
2588+
// Found a select.
2589+
if (match(V, m_InsertElt(m_Value(Base), m_Value(X), m_ConstantInt(IndexC)))) {
2590+
// See if matches the index we need.
2591+
if (match(IndexC, m_SpecificInt(ReqIndexC)))
2592+
return X;
2593+
// Otherwise continue searching. This is necessary for finding both elements
2594+
// in the common pattern:
2595+
// V0 = (insert poison x, 0)
2596+
// V1 = (insert V0, y, 1)
2597+
return getScalarizationOfInsertElement(Base, ReqIndexC, Builder);
2598+
}
2599+
2600+
// We can search through a splat of a single element for an insert.
2601+
int SplatIndex;
2602+
if (match(V, m_Shuffle(m_Value(Base), m_Value(X),
2603+
m_SplatOrUndefMask(SplatIndex))) &&
2604+
SplatIndex >= 0) {
2605+
if (auto *VType = dyn_cast<FixedVectorType>(V->getType())) {
2606+
// Chase whichever vector (Base/X) we are splatting from.
2607+
if (static_cast<unsigned>(SplatIndex) >= VType->getNumElements())
2608+
return getScalarizationOfInsertElement(
2609+
X, SplatIndex - VType->getNumElements(), Builder);
2610+
// New index we need to find is the index we are splatting from.
2611+
return getScalarizationOfInsertElement(Base, SplatIndex, Builder);
2612+
}
2613+
return nullptr;
2614+
}
2615+
2616+
// We don't want to duplicate `vector_ops...` if they have multiple uses.
2617+
if (!V->hasOneUse())
2618+
return nullptr;
2619+
2620+
Value *R = nullptr;
2621+
// Scalarize any unary op.
2622+
if (match(V, m_UnOp(m_Value(X)))) {
2623+
if (auto *Scalar = getScalarizationOfInsertElement(X, ReqIndexC, Builder))
2624+
R = Builder.CreateUnOp(cast<UnaryOperator>(V)->getOpcode(), Scalar);
2625+
}
2626+
2627+
// Scalarize any cast but bitcast.
2628+
// TODO: We skip bitcasts, but they would be okay if they are elementwise.
2629+
if (isa<CastInst>(V) && !match(V, m_BitCast(m_Value()))) {
2630+
X = cast<CastInst>(V)->getOperand(0);
2631+
if (auto *Scalar = getScalarizationOfInsertElement(X, ReqIndexC, Builder))
2632+
R = Builder.CreateCast(cast<CastInst>(V)->getOpcode(), Scalar,
2633+
V->getType()->getScalarType());
2634+
}
2635+
2636+
// Binop with a constant.
2637+
Constant *C;
2638+
if (match(V, m_c_BinOp(m_Value(X), m_ImmConstant(C)))) {
2639+
BinaryOperator *BO = cast<BinaryOperator>(V);
2640+
if (isSafeToSpeculativelyExecute(BO)) {
2641+
if (auto *Scalar =
2642+
getScalarizationOfInsertElement(X, ReqIndexC, Builder)) {
2643+
auto *ScalarC =
2644+
ConstantExpr::getExtractElement(C, Builder.getInt64(ReqIndexC));
2645+
2646+
BinaryOperator::BinaryOps Opc = BO->getOpcode();
2647+
if (match(V, m_c_BinOp(m_Value(X), m_ImmConstant(C))))
2648+
R = Builder.CreateBinOp(Opc, Scalar, ScalarC);
2649+
else
2650+
R = Builder.CreateBinOp(Opc, ScalarC, Scalar);
2651+
}
2652+
}
2653+
}
2654+
2655+
// Cmp with a constant.
2656+
CmpInst::Predicate Pred;
2657+
if (match(V, m_Cmp(Pred, m_Value(X), m_ImmConstant(C)))) {
2658+
if (auto *Scalar = getScalarizationOfInsertElement(X, ReqIndexC, Builder)) {
2659+
auto *ScalarC =
2660+
ConstantExpr::getExtractElement(C, Builder.getInt64(ReqIndexC));
2661+
R = Builder.CreateCmp(Pred, Scalar, ScalarC);
2662+
}
2663+
}
2664+
// TODO: Intrinsics
2665+
2666+
// If we created a new scalar instruction, copy flags from the vec version.
2667+
if (R != nullptr)
2668+
cast<Instruction>(R)->copyIRFlags(V);
2669+
2670+
return R;
2671+
}
2672+
25822673
/// Try to replace a shuffle with an insertelement or try to replace a shuffle
25832674
/// operand with the operand of an insertelement.
25842675
static Instruction *foldShuffleWithInsert(ShuffleVectorInst &Shuf,
@@ -2616,13 +2707,11 @@ static Instruction *foldShuffleWithInsert(ShuffleVectorInst &Shuf,
26162707
if (NumElts != InpNumElts)
26172708
return nullptr;
26182709

2619-
// shuffle (insert ?, Scalar, IndexC), V1, Mask --> insert V1, Scalar, IndexC'
2620-
auto isShufflingScalarIntoOp1 = [&](Value *&Scalar, ConstantInt *&IndexC) {
2621-
// We need an insertelement with a constant index.
2622-
if (!match(V0, m_InsertElt(m_Value(), m_Value(Scalar),
2623-
m_ConstantInt(IndexC))))
2624-
return false;
26252710

2711+
// (shuffle (vec_ops... (insert ?, Scalar, IndexC)), V1, Mask)
2712+
// --> insert V1, (scalar_ops... Scalar), IndexC'
2713+
auto GetScalarizationOfInsertEle =
2714+
[&Mask, &NumElts, &IC](Value *V) -> std::pair<Value *, int> {
26262715
// Test the shuffle mask to see if it splices the inserted scalar into the
26272716
// operand 1 vector of the shuffle.
26282717
int NewInsIndex = -1;
@@ -2631,40 +2720,45 @@ static Instruction *foldShuffleWithInsert(ShuffleVectorInst &Shuf,
26312720
if (Mask[i] == -1)
26322721
continue;
26332722

2634-
// The shuffle takes elements of operand 1 without lane changes.
2635-
if (Mask[i] == NumElts + i)
2723+
// The shuffle takes elements of operand 1.
2724+
if (Mask[i] >= NumElts)
26362725
continue;
26372726

26382727
// The shuffle must choose the inserted scalar exactly once.
2639-
if (NewInsIndex != -1 || Mask[i] != IndexC->getSExtValue())
2640-
return false;
2728+
if (NewInsIndex != -1)
2729+
return {nullptr, -1};
26412730

2642-
// The shuffle is placing the inserted scalar into element i.
2731+
// The shuffle is placing the inserted scalar into element i from operand
2732+
// 0.
26432733
NewInsIndex = i;
26442734
}
26452735

2646-
assert(NewInsIndex != -1 && "Did not fold shuffle with unused operand?");
2736+
// Operand is unused.
2737+
if (NewInsIndex < 0)
2738+
return {nullptr, -1};
26472739

2648-
// Index is updated to the potentially translated insertion lane.
2649-
IndexC = ConstantInt::get(IndexC->getIntegerType(), NewInsIndex);
2650-
return true;
2651-
};
2740+
Value *Scalar =
2741+
getScalarizationOfInsertElement(V, Mask[NewInsIndex], IC.Builder);
26522742

2653-
// If the shuffle is unnecessary, insert the scalar operand directly into
2654-
// operand 1 of the shuffle. Example:
2655-
// shuffle (insert ?, S, 1), V1, <1, 5, 6, 7> --> insert V1, S, 0
2656-
Value *Scalar;
2657-
ConstantInt *IndexC;
2658-
if (isShufflingScalarIntoOp1(Scalar, IndexC))
2659-
return InsertElementInst::Create(V1, Scalar, IndexC);
2743+
return {Scalar, NewInsIndex};
2744+
};
26602745

2661-
// Try again after commuting shuffle. Example:
2662-
// shuffle V0, (insert ?, S, 0), <0, 1, 2, 4> -->
2663-
// shuffle (insert ?, S, 0), V0, <4, 5, 6, 0> --> insert V0, S, 3
2664-
std::swap(V0, V1);
2746+
auto [V0Scalar, V0NewInsertIdx] = GetScalarizationOfInsertEle(V0);
26652747
ShuffleVectorInst::commuteShuffleMask(Mask, NumElts);
2666-
if (isShufflingScalarIntoOp1(Scalar, IndexC))
2667-
return InsertElementInst::Create(V1, Scalar, IndexC);
2748+
auto [V1Scalar, V1NewInsertIdx] = GetScalarizationOfInsertEle(V1);
2749+
2750+
if (V0Scalar != nullptr && V1Scalar != nullptr) {
2751+
Value *R = IC.Builder.CreateInsertElement(Shuf.getType(), V0Scalar,
2752+
V0NewInsertIdx);
2753+
return InsertElementInst::Create(R, V1Scalar,
2754+
IC.Builder.getInt64(V1NewInsertIdx));
2755+
} else if (V0Scalar != nullptr) {
2756+
return InsertElementInst::Create(V1, V0Scalar,
2757+
IC.Builder.getInt64(V0NewInsertIdx));
2758+
} else if (V1Scalar != nullptr) {
2759+
return InsertElementInst::Create(V0, V1Scalar,
2760+
IC.Builder.getInt64(V1NewInsertIdx));
2761+
}
26682762

26692763
return nullptr;
26702764
}

llvm/test/Transforms/InstCombine/insert-extract-shuffle-inseltpoison.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -547,8 +547,7 @@ define <4 x float> @insert_in_splat_variable_index(float %x, i32 %y) {
547547

548548
define <4 x float> @insert_in_nonsplat(float %x, <4 x float> %y) {
549549
; CHECK-LABEL: @insert_in_nonsplat(
550-
; CHECK-NEXT: [[XV:%.*]] = insertelement <4 x float> poison, float [[X:%.*]], i64 0
551-
; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <4 x float> [[XV]], <4 x float> [[Y:%.*]], <4 x i32> <i32 poison, i32 0, i32 4, i32 poison>
550+
; CHECK-NEXT: [[SPLAT:%.*]] = insertelement <4 x float> [[Y:%.*]], float [[X:%.*]], i64 1
552551
; CHECK-NEXT: [[R:%.*]] = insertelement <4 x float> [[SPLAT]], float [[X]], i64 3
553552
; CHECK-NEXT: ret <4 x float> [[R]]
554553
;

llvm/test/Transforms/InstCombine/insert-extract-shuffle.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -547,8 +547,7 @@ define <4 x float> @insert_in_splat_variable_index(float %x, i32 %y) {
547547

548548
define <4 x float> @insert_in_nonsplat(float %x, <4 x float> %y) {
549549
; CHECK-LABEL: @insert_in_nonsplat(
550-
; CHECK-NEXT: [[XV:%.*]] = insertelement <4 x float> poison, float [[X:%.*]], i64 0
551-
; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <4 x float> [[XV]], <4 x float> [[Y:%.*]], <4 x i32> <i32 poison, i32 0, i32 4, i32 poison>
550+
; CHECK-NEXT: [[SPLAT:%.*]] = insertelement <4 x float> [[Y:%.*]], float [[X:%.*]], i64 1
552551
; CHECK-NEXT: [[R:%.*]] = insertelement <4 x float> [[SPLAT]], float [[X]], i64 3
553552
; CHECK-NEXT: ret <4 x float> [[R]]
554553
;

llvm/test/Transforms/InstCombine/shufflevector-div-rem-inseltpoison.ll

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,9 @@ define <2 x i16> @test_udiv(i16 %a, i1 %cmp) {
8888
; shufflevector is eliminated here.
8989
define <2 x float> @test_fdiv(float %a, float %b, i1 %cmp) {
9090
; CHECK-LABEL: @test_fdiv(
91-
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x float> poison, float [[A:%.*]], i64 1
92-
; CHECK-NEXT: [[SPLAT_OP:%.*]] = fdiv <2 x float> [[TMP1]], <float undef, float 3.000000e+00>
93-
; CHECK-NEXT: [[T2:%.*]] = select i1 [[CMP:%.*]], <2 x float> <float 7.700000e+01, float 9.900000e+01>, <2 x float> [[SPLAT_OP]]
91+
; CHECK-NEXT: [[A:%.*]] = fdiv float [[A1:%.*]], 3.000000e+00
92+
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x float> poison, float [[A]], i64 1
93+
; CHECK-NEXT: [[T2:%.*]] = select i1 [[CMP:%.*]], <2 x float> <float 7.700000e+01, float 9.900000e+01>, <2 x float> [[TMP1]]
9494
; CHECK-NEXT: ret <2 x float> [[T2]]
9595
;
9696
%splatinsert = insertelement <2 x float> poison, float %a, i32 0
@@ -105,9 +105,9 @@ define <2 x float> @test_fdiv(float %a, float %b, i1 %cmp) {
105105
; shufflevector is eliminated here.
106106
define <2 x float> @test_frem(float %a, float %b, i1 %cmp) {
107107
; CHECK-LABEL: @test_frem(
108-
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x float> poison, float [[A:%.*]], i64 1
109-
; CHECK-NEXT: [[SPLAT_OP:%.*]] = frem <2 x float> [[TMP1]], <float undef, float 3.000000e+00>
110-
; CHECK-NEXT: [[T2:%.*]] = select i1 [[CMP:%.*]], <2 x float> <float 7.700000e+01, float 9.900000e+01>, <2 x float> [[SPLAT_OP]]
108+
; CHECK-NEXT: [[A:%.*]] = frem float [[A1:%.*]], 3.000000e+00
109+
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x float> poison, float [[A]], i64 1
110+
; CHECK-NEXT: [[T2:%.*]] = select i1 [[CMP:%.*]], <2 x float> <float 7.700000e+01, float 9.900000e+01>, <2 x float> [[TMP1]]
111111
; CHECK-NEXT: ret <2 x float> [[T2]]
112112
;
113113
%splatinsert = insertelement <2 x float> poison, float %a, i32 0

llvm/test/Transforms/InstCombine/shufflevector-div-rem.ll

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,9 @@ define <2 x i16> @test_udiv(i16 %a, i1 %cmp) {
8888
; shufflevector is eliminated here.
8989
define <2 x float> @test_fdiv(float %a, float %b, i1 %cmp) {
9090
; CHECK-LABEL: @test_fdiv(
91-
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x float> poison, float [[A:%.*]], i64 1
92-
; CHECK-NEXT: [[SPLAT_OP:%.*]] = fdiv <2 x float> [[TMP1]], <float undef, float 3.000000e+00>
93-
; CHECK-NEXT: [[T2:%.*]] = select i1 [[CMP:%.*]], <2 x float> <float 7.700000e+01, float 9.900000e+01>, <2 x float> [[SPLAT_OP]]
91+
; CHECK-NEXT: [[A:%.*]] = fdiv float [[A1:%.*]], 3.000000e+00
92+
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x float> poison, float [[A]], i64 1
93+
; CHECK-NEXT: [[T2:%.*]] = select i1 [[CMP:%.*]], <2 x float> <float 7.700000e+01, float 9.900000e+01>, <2 x float> [[TMP1]]
9494
; CHECK-NEXT: ret <2 x float> [[T2]]
9595
;
9696
%splatinsert = insertelement <2 x float> undef, float %a, i32 0
@@ -105,9 +105,9 @@ define <2 x float> @test_fdiv(float %a, float %b, i1 %cmp) {
105105
; shufflevector is eliminated here.
106106
define <2 x float> @test_frem(float %a, float %b, i1 %cmp) {
107107
; CHECK-LABEL: @test_frem(
108-
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x float> poison, float [[A:%.*]], i64 1
109-
; CHECK-NEXT: [[SPLAT_OP:%.*]] = frem <2 x float> [[TMP1]], <float undef, float 3.000000e+00>
110-
; CHECK-NEXT: [[T2:%.*]] = select i1 [[CMP:%.*]], <2 x float> <float 7.700000e+01, float 9.900000e+01>, <2 x float> [[SPLAT_OP]]
108+
; CHECK-NEXT: [[A:%.*]] = frem float [[A1:%.*]], 3.000000e+00
109+
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x float> poison, float [[A]], i64 1
110+
; CHECK-NEXT: [[T2:%.*]] = select i1 [[CMP:%.*]], <2 x float> <float 7.700000e+01, float 9.900000e+01>, <2 x float> [[TMP1]]
111111
; CHECK-NEXT: ret <2 x float> [[T2]]
112112
;
113113
%splatinsert = insertelement <2 x float> undef, float %a, i32 0

0 commit comments

Comments
 (0)