Skip to content

[InstCombine] Scalarize (vec_ops (insert ?, X, Idx)) when only one element is demanded #84645

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 139 additions & 30 deletions llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2579,6 +2579,97 @@ static Instruction *foldIdentityExtractShuffle(ShuffleVectorInst &Shuf) {
return new ShuffleVectorInst(X, Y, NewMask);
}

// Extract `(scalar_ops... x)` from `(vector_ops... (insert ?, x, C)`
static Value *
getScalarizationOfInsertElement(Value *V, int ReqIndexC,
InstCombiner::BuilderTy &Builder) {
Value *X, *Base;
ConstantInt *IndexC;
// Found a select.
if (match(V, m_InsertElt(m_Value(Base), m_Value(X), m_ConstantInt(IndexC)))) {
// See if matches the index we need.
if (match(IndexC, m_SpecificInt(ReqIndexC)))
return X;
// Otherwise continue searching. This is necessary for finding both elements
// in the common pattern:
// V0 = (insert poison x, 0)
// V1 = (insert V0, y, 1)
return getScalarizationOfInsertElement(Base, ReqIndexC, Builder);
}

// We can search through a splat of a single element for an insert.
int SplatIndex;
if (match(V, m_Shuffle(m_Value(Base), m_Value(X),
m_SplatOrUndefMask(SplatIndex))) &&
SplatIndex >= 0) {
if (auto *VType = dyn_cast<FixedVectorType>(V->getType())) {
// Chase whichever vector (Base/X) we are splatting from.
if (static_cast<unsigned>(SplatIndex) >= VType->getNumElements())
return getScalarizationOfInsertElement(
X, SplatIndex - VType->getNumElements(), Builder);
// New index we need to find is the index we are splatting from.
return getScalarizationOfInsertElement(Base, SplatIndex, Builder);
}
return nullptr;
}

// We don't want to duplicate `vector_ops...` if they have multiple uses.
if (!V->hasOneUse())
return nullptr;

Value *R = nullptr;
// Scalarize any unary op.
if (match(V, m_UnOp(m_Value(X)))) {
if (auto *Scalar = getScalarizationOfInsertElement(X, ReqIndexC, Builder))
R = Builder.CreateUnOp(cast<UnaryOperator>(V)->getOpcode(), Scalar);
}

// Scalarize any cast but bitcast.
// TODO: We skip bitcasts, but they would be okay if they are elementwise.
if (isa<CastInst>(V) && !match(V, m_BitCast(m_Value()))) {
X = cast<CastInst>(V)->getOperand(0);
if (auto *Scalar = getScalarizationOfInsertElement(X, ReqIndexC, Builder))
R = Builder.CreateCast(cast<CastInst>(V)->getOpcode(), Scalar,
V->getType()->getScalarType());
}

// Binop with a constant.
Constant *C;
if (match(V, m_c_BinOp(m_Value(X), m_ImmConstant(C)))) {
BinaryOperator *BO = cast<BinaryOperator>(V);
if (isSafeToSpeculativelyExecute(BO)) {
if (auto *Scalar =
getScalarizationOfInsertElement(X, ReqIndexC, Builder)) {
auto *ScalarC =
ConstantExpr::getExtractElement(C, Builder.getInt64(ReqIndexC));

BinaryOperator::BinaryOps Opc = BO->getOpcode();
if (match(V, m_c_BinOp(m_Value(X), m_ImmConstant(C))))
R = Builder.CreateBinOp(Opc, Scalar, ScalarC);
else
R = Builder.CreateBinOp(Opc, ScalarC, Scalar);
}
}
}

// Cmp with a constant.
CmpInst::Predicate Pred;
if (match(V, m_Cmp(Pred, m_Value(X), m_ImmConstant(C)))) {
if (auto *Scalar = getScalarizationOfInsertElement(X, ReqIndexC, Builder)) {
auto *ScalarC =
ConstantExpr::getExtractElement(C, Builder.getInt64(ReqIndexC));
R = Builder.CreateCmp(Pred, Scalar, ScalarC);
}
}
// TODO: Intrinsics

// If we created a new scalar instruction, copy flags from the vec version.
if (R != nullptr)
cast<Instruction>(R)->copyIRFlags(V);

return R;
}

/// Try to replace a shuffle with an insertelement or try to replace a shuffle
/// operand with the operand of an insertelement.
static Instruction *foldShuffleWithInsert(ShuffleVectorInst &Shuf,
Expand Down Expand Up @@ -2616,13 +2707,11 @@ static Instruction *foldShuffleWithInsert(ShuffleVectorInst &Shuf,
if (NumElts != InpNumElts)
return nullptr;

// shuffle (insert ?, Scalar, IndexC), V1, Mask --> insert V1, Scalar, IndexC'
auto isShufflingScalarIntoOp1 = [&](Value *&Scalar, ConstantInt *&IndexC) {
// We need an insertelement with a constant index.
if (!match(V0, m_InsertElt(m_Value(), m_Value(Scalar),
m_ConstantInt(IndexC))))
return false;

// (shuffle (vec_ops... (insert ?, Scalar, IndexC)), V1, Mask)
// --> insert V1, (scalar_ops... Scalar), IndexC'
auto GetScalarizationOfInsertEle =
[&Mask, &NumElts, &IC](Value *V,
int OtherIdx) -> std::pair<Value *, int> {
// Test the shuffle mask to see if it splices the inserted scalar into the
// operand 1 vector of the shuffle.
int NewInsIndex = -1;
Expand All @@ -2631,40 +2720,60 @@ static Instruction *foldShuffleWithInsert(ShuffleVectorInst &Shuf,
if (Mask[i] == -1)
continue;

// The shuffle takes elements of operand 1 without lane changes.
if (Mask[i] == NumElts + i)
// The shuffle takes element we are overwriting with the other insert.
if (i == OtherIdx)
continue;

// The shuffle takes elements of operand 1 either without modifying its
// position.
if (Mask[i] == (NumElts + i))
continue;

// The shuffle must choose the inserted scalar exactly once.
if (NewInsIndex != -1 || Mask[i] != IndexC->getSExtValue())
return false;
if (NewInsIndex != -1)
return {nullptr, -1};

// The shuffle is placing the inserted scalar into element i.
// The shuffle is placing the inserted scalar into element i from operand
// 0.
NewInsIndex = i;
}

assert(NewInsIndex != -1 && "Did not fold shuffle with unused operand?");
// Operand is unused.
if (NewInsIndex < 0)
return {nullptr, -1};

// Index is updated to the potentially translated insertion lane.
IndexC = ConstantInt::get(IndexC->getIntegerType(), NewInsIndex);
return true;
};
Value *Scalar =
getScalarizationOfInsertElement(V, Mask[NewInsIndex], IC.Builder);

// If the shuffle is unnecessary, insert the scalar operand directly into
// operand 1 of the shuffle. Example:
// shuffle (insert ?, S, 1), V1, <1, 5, 6, 7> --> insert V1, S, 0
Value *Scalar;
ConstantInt *IndexC;
if (isShufflingScalarIntoOp1(Scalar, IndexC))
return InsertElementInst::Create(V1, Scalar, IndexC);
return {Scalar, NewInsIndex};
};

// Try again after commuting shuffle. Example:
// shuffle V0, (insert ?, S, 0), <0, 1, 2, 4> -->
// shuffle (insert ?, S, 0), V0, <4, 5, 6, 0> --> insert V0, S, 3
std::swap(V0, V1);
auto [V0Scalar, V0NewInsertIdx] = GetScalarizationOfInsertEle(V0, -1);
ShuffleVectorInst::commuteShuffleMask(Mask, NumElts);
if (isShufflingScalarIntoOp1(Scalar, IndexC))
return InsertElementInst::Create(V1, Scalar, IndexC);
auto [V1Scalar, V1NewInsertIdx] =
GetScalarizationOfInsertEle(V1, V0NewInsertIdx);

// We failed to scalarize V0 because the shuffle was doing more than just
// blend on V1. But since we found an insert for V1, that might be covering
// the index we where having trouble with.
if (V0NewInsertIdx == -1 && V1Scalar != nullptr) {
ShuffleVectorInst::commuteShuffleMask(Mask, NumElts);
std::tie(V0Scalar, V0NewInsertIdx) =
GetScalarizationOfInsertEle(V0, V1NewInsertIdx);
}

if (V0Scalar != nullptr && V1Scalar != nullptr) {
Value *R = IC.Builder.CreateInsertElement(Shuf.getType(), V0Scalar,
V0NewInsertIdx);
return InsertElementInst::Create(R, V1Scalar,
IC.Builder.getInt64(V1NewInsertIdx));
} else if (V0Scalar != nullptr) {
return InsertElementInst::Create(V1, V0Scalar,
IC.Builder.getInt64(V0NewInsertIdx));
} else if (V1Scalar != nullptr) {
return InsertElementInst::Create(V0, V1Scalar,
IC.Builder.getInt64(V1NewInsertIdx));
}

return nullptr;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ define <2 x i16> @test_udiv(i16 %a, i1 %cmp) {
; shufflevector is eliminated here.
define <2 x float> @test_fdiv(float %a, float %b, i1 %cmp) {
; CHECK-LABEL: @test_fdiv(
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x float> poison, float [[A:%.*]], i64 1
; CHECK-NEXT: [[SPLAT_OP:%.*]] = fdiv <2 x float> [[TMP1]], <float undef, float 3.000000e+00>
; CHECK-NEXT: [[T2:%.*]] = select i1 [[CMP:%.*]], <2 x float> <float 7.700000e+01, float 9.900000e+01>, <2 x float> [[SPLAT_OP]]
; CHECK-NEXT: [[A:%.*]] = fdiv float [[A1:%.*]], 3.000000e+00
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x float> poison, float [[A]], i64 1
; CHECK-NEXT: [[T2:%.*]] = select i1 [[CMP:%.*]], <2 x float> <float 7.700000e+01, float 9.900000e+01>, <2 x float> [[TMP1]]
; CHECK-NEXT: ret <2 x float> [[T2]]
;
%splatinsert = insertelement <2 x float> poison, float %a, i32 0
Expand All @@ -105,9 +105,9 @@ define <2 x float> @test_fdiv(float %a, float %b, i1 %cmp) {
; shufflevector is eliminated here.
define <2 x float> @test_frem(float %a, float %b, i1 %cmp) {
; CHECK-LABEL: @test_frem(
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x float> poison, float [[A:%.*]], i64 1
; CHECK-NEXT: [[SPLAT_OP:%.*]] = frem <2 x float> [[TMP1]], <float undef, float 3.000000e+00>
; CHECK-NEXT: [[T2:%.*]] = select i1 [[CMP:%.*]], <2 x float> <float 7.700000e+01, float 9.900000e+01>, <2 x float> [[SPLAT_OP]]
; CHECK-NEXT: [[A:%.*]] = frem float [[A1:%.*]], 3.000000e+00
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x float> poison, float [[A]], i64 1
; CHECK-NEXT: [[T2:%.*]] = select i1 [[CMP:%.*]], <2 x float> <float 7.700000e+01, float 9.900000e+01>, <2 x float> [[TMP1]]
; CHECK-NEXT: ret <2 x float> [[T2]]
;
%splatinsert = insertelement <2 x float> poison, float %a, i32 0
Expand Down
12 changes: 6 additions & 6 deletions llvm/test/Transforms/InstCombine/shufflevector-div-rem.ll
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ define <2 x i16> @test_udiv(i16 %a, i1 %cmp) {
; shufflevector is eliminated here.
define <2 x float> @test_fdiv(float %a, float %b, i1 %cmp) {
; CHECK-LABEL: @test_fdiv(
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x float> poison, float [[A:%.*]], i64 1
; CHECK-NEXT: [[SPLAT_OP:%.*]] = fdiv <2 x float> [[TMP1]], <float undef, float 3.000000e+00>
; CHECK-NEXT: [[T2:%.*]] = select i1 [[CMP:%.*]], <2 x float> <float 7.700000e+01, float 9.900000e+01>, <2 x float> [[SPLAT_OP]]
; CHECK-NEXT: [[A:%.*]] = fdiv float [[A1:%.*]], 3.000000e+00
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x float> poison, float [[A]], i64 1
; CHECK-NEXT: [[T2:%.*]] = select i1 [[CMP:%.*]], <2 x float> <float 7.700000e+01, float 9.900000e+01>, <2 x float> [[TMP1]]
; CHECK-NEXT: ret <2 x float> [[T2]]
;
%splatinsert = insertelement <2 x float> undef, float %a, i32 0
Expand All @@ -105,9 +105,9 @@ define <2 x float> @test_fdiv(float %a, float %b, i1 %cmp) {
; shufflevector is eliminated here.
define <2 x float> @test_frem(float %a, float %b, i1 %cmp) {
; CHECK-LABEL: @test_frem(
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x float> poison, float [[A:%.*]], i64 1
; CHECK-NEXT: [[SPLAT_OP:%.*]] = frem <2 x float> [[TMP1]], <float undef, float 3.000000e+00>
; CHECK-NEXT: [[T2:%.*]] = select i1 [[CMP:%.*]], <2 x float> <float 7.700000e+01, float 9.900000e+01>, <2 x float> [[SPLAT_OP]]
; CHECK-NEXT: [[A:%.*]] = frem float [[A1:%.*]], 3.000000e+00
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x float> poison, float [[A]], i64 1
; CHECK-NEXT: [[T2:%.*]] = select i1 [[CMP:%.*]], <2 x float> <float 7.700000e+01, float 9.900000e+01>, <2 x float> [[TMP1]]
; CHECK-NEXT: ret <2 x float> [[T2]]
;
%splatinsert = insertelement <2 x float> undef, float %a, i32 0
Expand Down
Loading