Skip to content

Commit 81d9ed6

Browse files
committed
[SLP]Do extra analysis int minbitwidth if some checks return false.
The instruction itself can be considered good for minbitwidth casting, even if one of the operand checks returns false. Reviewers: RKSimon Reviewed By: RKSimon Pull Request: llvm#84363
1 parent fb5fd2d commit 81d9ed6

File tree

3 files changed

+133
-56
lines changed

3 files changed

+133
-56
lines changed

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 61 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -12225,7 +12225,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
1222512225
return E->VectorizedValue;
1222612226
}
1222712227
if (True->getType() != VecTy || False->getType() != VecTy) {
12228-
assert((getOperandEntry(E, 1)->State == TreeEntry::NeedToGather ||
12228+
assert((It != MinBWs.end() ||
12229+
getOperandEntry(E, 1)->State == TreeEntry::NeedToGather ||
1222912230
getOperandEntry(E, 2)->State == TreeEntry::NeedToGather ||
1223012231
MinBWs.contains(getOperandEntry(E, 1)) ||
1223112232
MinBWs.contains(getOperandEntry(E, 2))) &&
@@ -12297,7 +12298,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
1229712298
return E->VectorizedValue;
1229812299
}
1229912300
if (LHS->getType() != VecTy || RHS->getType() != VecTy) {
12300-
assert((getOperandEntry(E, 0)->State == TreeEntry::NeedToGather ||
12301+
assert((It != MinBWs.end() ||
12302+
getOperandEntry(E, 0)->State == TreeEntry::NeedToGather ||
1230112303
getOperandEntry(E, 1)->State == TreeEntry::NeedToGather ||
1230212304
MinBWs.contains(getOperandEntry(E, 0)) ||
1230312305
MinBWs.contains(getOperandEntry(E, 1))) &&
@@ -12543,7 +12545,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
1254312545
((Instruction::isBinaryOp(E->getOpcode()) &&
1254412546
(LHS->getType() != VecTy || RHS->getType() != VecTy)) ||
1254512547
(isa<CmpInst>(VL0) && LHS->getType() != RHS->getType()))) {
12546-
assert((getOperandEntry(E, 0)->State == TreeEntry::NeedToGather ||
12548+
assert((It != MinBWs.end() ||
12549+
getOperandEntry(E, 0)->State == TreeEntry::NeedToGather ||
1254712550
getOperandEntry(E, 1)->State == TreeEntry::NeedToGather ||
1254812551
MinBWs.contains(getOperandEntry(E, 0)) ||
1254912552
MinBWs.contains(getOperandEntry(E, 1))) &&
@@ -12559,9 +12562,9 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
1255912562
else
1256012563
CastTy = LHS->getType();
1256112564
}
12562-
if (LHS->getType() != VecTy)
12565+
if (LHS->getType() != CastTy)
1256312566
LHS = Builder.CreateIntCast(LHS, CastTy, GetOperandSignedness(0));
12564-
if (RHS->getType() != VecTy)
12567+
if (RHS->getType() != CastTy)
1256512568
RHS = Builder.CreateIntCast(RHS, CastTy, GetOperandSignedness(1));
1256612569
}
1256712570

@@ -13988,15 +13991,6 @@ bool BoUpSLP::collectValuesToDemote(
1398813991
// If the value is not a vectorized instruction in the expression and not used
1398913992
// by the insertelement instruction and not used in multiple vector nodes, it
1399013993
// cannot be demoted.
13991-
// TODO: improve handling of gathered values and others.
13992-
auto *I = dyn_cast<Instruction>(V);
13993-
const TreeEntry *ITE = I ? getTreeEntry(I) : nullptr;
13994-
if (!ITE || !Visited.insert(I).second || MultiNodeScalars.contains(I) ||
13995-
all_of(I->users(), [&](User *U) {
13996-
return isa<InsertElementInst>(U) && !getTreeEntry(U);
13997-
}))
13998-
return false;
13999-
1400013994
auto IsPotentiallyTruncated = [&](Value *V, unsigned &BitWidth) -> bool {
1400113995
if (MultiNodeScalars.contains(V))
1400213996
return false;
@@ -14011,8 +14005,44 @@ bool BoUpSLP::collectValuesToDemote(
1401114005
BitWidth = std::max(BitWidth, BitWidth1);
1401214006
return BitWidth > 0 && OrigBitWidth >= (BitWidth * 2);
1401314007
};
14008+
auto FinalAnalysis = [&](const TreeEntry *ITE = nullptr) {
14009+
if (!IsProfitableToDemote)
14010+
return false;
14011+
return (ITE && ITE->UserTreeIndices.size() > 1) ||
14012+
IsPotentiallyTruncated(V, BitWidth);
14013+
};
14014+
// TODO: improve handling of gathered values and others.
14015+
auto *I = dyn_cast<Instruction>(V);
14016+
const TreeEntry *ITE = I ? getTreeEntry(I) : nullptr;
14017+
if (!ITE || !Visited.insert(I).second || MultiNodeScalars.contains(I) ||
14018+
all_of(I->users(), [&](User *U) {
14019+
return isa<InsertElementInst>(U) && !getTreeEntry(U);
14020+
}))
14021+
return FinalAnalysis();
14022+
1401414023
unsigned Start = 0;
1401514024
unsigned End = I->getNumOperands();
14025+
14026+
auto ProcessOperands = [&](ArrayRef<Value *> Operands, bool &NeedToExit) {
14027+
NeedToExit = false;
14028+
unsigned InitLevel = MaxDepthLevel;
14029+
for (Value *IncValue : Operands) {
14030+
unsigned Level = InitLevel;
14031+
if (!collectValuesToDemote(IncValue, IsProfitableToDemoteRoot, BitWidth,
14032+
ToDemote, DemotedConsts, Visited, Level,
14033+
IsProfitableToDemote, IsTruncRoot)) {
14034+
if (!IsProfitableToDemote)
14035+
return false;
14036+
NeedToExit = true;
14037+
if (!FinalAnalysis(ITE))
14038+
return false;
14039+
continue;
14040+
}
14041+
MaxDepthLevel = std::max(MaxDepthLevel, Level);
14042+
}
14043+
return true;
14044+
};
14045+
bool NeedToExit = false;
1401614046
switch (I->getOpcode()) {
1401714047

1401814048
// We can always demote truncations and extensions. Since truncations can
@@ -14038,35 +14068,21 @@ bool BoUpSLP::collectValuesToDemote(
1403814068
case Instruction::And:
1403914069
case Instruction::Or:
1404014070
case Instruction::Xor: {
14041-
unsigned Level1 = MaxDepthLevel, Level2 = MaxDepthLevel;
14042-
if ((ITE->UserTreeIndices.size() > 1 &&
14043-
!IsPotentiallyTruncated(I, BitWidth)) ||
14044-
!collectValuesToDemote(I->getOperand(0), IsProfitableToDemoteRoot,
14045-
BitWidth, ToDemote, DemotedConsts, Visited,
14046-
Level1, IsProfitableToDemote, IsTruncRoot) ||
14047-
!collectValuesToDemote(I->getOperand(1), IsProfitableToDemoteRoot,
14048-
BitWidth, ToDemote, DemotedConsts, Visited,
14049-
Level2, IsProfitableToDemote, IsTruncRoot))
14071+
if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
14072+
return false;
14073+
if (!ProcessOperands({I->getOperand(0), I->getOperand(1)}, NeedToExit))
1405014074
return false;
14051-
MaxDepthLevel = std::max(Level1, Level2);
1405214075
break;
1405314076
}
1405414077

1405514078
// We can demote selects if we can demote their true and false values.
1405614079
case Instruction::Select: {
14080+
if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
14081+
return false;
1405714082
Start = 1;
14058-
unsigned Level1 = MaxDepthLevel, Level2 = MaxDepthLevel;
14059-
SelectInst *SI = cast<SelectInst>(I);
14060-
if ((ITE->UserTreeIndices.size() > 1 &&
14061-
!IsPotentiallyTruncated(I, BitWidth)) ||
14062-
!collectValuesToDemote(SI->getTrueValue(), IsProfitableToDemoteRoot,
14063-
BitWidth, ToDemote, DemotedConsts, Visited,
14064-
Level1, IsProfitableToDemote, IsTruncRoot) ||
14065-
!collectValuesToDemote(SI->getFalseValue(), IsProfitableToDemoteRoot,
14066-
BitWidth, ToDemote, DemotedConsts, Visited,
14067-
Level2, IsProfitableToDemote, IsTruncRoot))
14083+
auto *SI = cast<SelectInst>(I);
14084+
if (!ProcessOperands({SI->getTrueValue(), SI->getFalseValue()}, NeedToExit))
1406814085
return false;
14069-
MaxDepthLevel = std::max(Level1, Level2);
1407014086
break;
1407114087
}
1407214088

@@ -14076,23 +14092,20 @@ bool BoUpSLP::collectValuesToDemote(
1407614092
PHINode *PN = cast<PHINode>(I);
1407714093
if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
1407814094
return false;
14079-
unsigned InitLevel = MaxDepthLevel;
14080-
for (Value *IncValue : PN->incoming_values()) {
14081-
unsigned Level = InitLevel;
14082-
if (!collectValuesToDemote(IncValue, IsProfitableToDemoteRoot, BitWidth,
14083-
ToDemote, DemotedConsts, Visited, Level,
14084-
IsProfitableToDemote, IsTruncRoot))
14085-
return false;
14086-
MaxDepthLevel = std::max(MaxDepthLevel, Level);
14087-
}
14095+
SmallVector<Value *> Ops(PN->incoming_values().begin(),
14096+
PN->incoming_values().end());
14097+
if (!ProcessOperands(Ops, NeedToExit))
14098+
return false;
1408814099
break;
1408914100
}
1409014101

1409114102
// Otherwise, conservatively give up.
1409214103
default:
1409314104
MaxDepthLevel = 1;
14094-
return IsProfitableToDemote && IsPotentiallyTruncated(I, BitWidth);
14105+
return FinalAnalysis();
1409514106
}
14107+
if (NeedToExit)
14108+
return true;
1409614109

1409714110
++MaxDepthLevel;
1409814111
// Gather demoted constant operands.
@@ -14131,15 +14144,17 @@ void BoUpSLP::computeMinimumValueSizes() {
1413114144

1413214145
// The first value node for store/insertelement is sext/zext/trunc? Skip it,
1413314146
// resize to the final type.
14147+
bool IsTruncRoot = false;
1413414148
bool IsProfitableToDemoteRoot = !IsStoreOrInsertElt;
1413514149
if (NodeIdx != 0 &&
1413614150
VectorizableTree[NodeIdx]->State == TreeEntry::Vectorize &&
1413714151
(VectorizableTree[NodeIdx]->getOpcode() == Instruction::ZExt ||
1413814152
VectorizableTree[NodeIdx]->getOpcode() == Instruction::SExt ||
1413914153
VectorizableTree[NodeIdx]->getOpcode() == Instruction::Trunc)) {
1414014154
assert(IsStoreOrInsertElt && "Expected store/insertelement seeded graph.");
14141-
++NodeIdx;
14155+
IsTruncRoot = VectorizableTree[NodeIdx]->getOpcode() == Instruction::Trunc;
1414214156
IsProfitableToDemoteRoot = true;
14157+
++NodeIdx;
1414314158
}
1414414159

1414514160
// Analyzed in reduction already and not profitable - exit.
@@ -14271,7 +14286,6 @@ void BoUpSLP::computeMinimumValueSizes() {
1427114286
ReductionBitWidth = bit_ceil(ReductionBitWidth);
1427214287
}
1427314288
bool IsTopRoot = NodeIdx == 0;
14274-
bool IsTruncRoot = false;
1427514289
while (NodeIdx < VectorizableTree.size() &&
1427614290
VectorizableTree[NodeIdx]->State == TreeEntry::Vectorize &&
1427714291
VectorizableTree[NodeIdx]->getOpcode() == Instruction::Trunc) {

llvm/test/Transforms/SLPVectorizer/AArch64/horizontal.ll

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ for.end: ; preds = %for.end.loopexit, %
228228
; YAML-NEXT: Function: test_unrolled_select
229229
; YAML-NEXT: Args:
230230
; YAML-NEXT: - String: 'Vectorized horizontal reduction with cost '
231-
; YAML-NEXT: - Cost: '-36'
231+
; YAML-NEXT: - Cost: '-41'
232232
; YAML-NEXT: - String: ' and with tree size '
233233
; YAML-NEXT: - TreeSize: '10'
234234

@@ -246,15 +246,17 @@ define i32 @test_unrolled_select(ptr noalias nocapture readonly %blk1, ptr noali
246246
; CHECK-NEXT: [[P2_045:%.*]] = phi ptr [ [[BLK2:%.*]], [[FOR_BODY_LR_PH]] ], [ [[ADD_PTR88:%.*]], [[IF_END_86]] ]
247247
; CHECK-NEXT: [[P1_044:%.*]] = phi ptr [ [[BLK1:%.*]], [[FOR_BODY_LR_PH]] ], [ [[ADD_PTR:%.*]], [[IF_END_86]] ]
248248
; CHECK-NEXT: [[TMP0:%.*]] = load <8 x i8>, ptr [[P1_044]], align 1
249-
; CHECK-NEXT: [[TMP1:%.*]] = zext <8 x i8> [[TMP0]] to <8 x i32>
249+
; CHECK-NEXT: [[TMP1:%.*]] = zext <8 x i8> [[TMP0]] to <8 x i16>
250250
; CHECK-NEXT: [[TMP2:%.*]] = load <8 x i8>, ptr [[P2_045]], align 1
251-
; CHECK-NEXT: [[TMP3:%.*]] = zext <8 x i8> [[TMP2]] to <8 x i32>
252-
; CHECK-NEXT: [[TMP4:%.*]] = sub nsw <8 x i32> [[TMP1]], [[TMP3]]
253-
; CHECK-NEXT: [[TMP5:%.*]] = icmp slt <8 x i32> [[TMP4]], zeroinitializer
254-
; CHECK-NEXT: [[TMP6:%.*]] = sub nsw <8 x i32> zeroinitializer, [[TMP4]]
255-
; CHECK-NEXT: [[TMP7:%.*]] = select <8 x i1> [[TMP5]], <8 x i32> [[TMP6]], <8 x i32> [[TMP4]]
256-
; CHECK-NEXT: [[TMP8:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP7]])
257-
; CHECK-NEXT: [[OP_RDX]] = add i32 [[TMP8]], [[S_047]]
251+
; CHECK-NEXT: [[TMP3:%.*]] = zext <8 x i8> [[TMP2]] to <8 x i16>
252+
; CHECK-NEXT: [[TMP4:%.*]] = sub <8 x i16> [[TMP1]], [[TMP3]]
253+
; CHECK-NEXT: [[TMP5:%.*]] = sext <8 x i16> [[TMP4]] to <8 x i32>
254+
; CHECK-NEXT: [[TMP6:%.*]] = icmp slt <8 x i32> [[TMP5]], zeroinitializer
255+
; CHECK-NEXT: [[TMP7:%.*]] = sub <8 x i16> zeroinitializer, [[TMP4]]
256+
; CHECK-NEXT: [[TMP8:%.*]] = select <8 x i1> [[TMP6]], <8 x i16> [[TMP7]], <8 x i16> [[TMP4]]
257+
; CHECK-NEXT: [[TMP9:%.*]] = sext <8 x i16> [[TMP8]] to <8 x i32>
258+
; CHECK-NEXT: [[TMP10:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP9]])
259+
; CHECK-NEXT: [[OP_RDX]] = add i32 [[TMP10]], [[S_047]]
258260
; CHECK-NEXT: [[CMP83:%.*]] = icmp slt i32 [[OP_RDX]], [[LIM:%.*]]
259261
; CHECK-NEXT: br i1 [[CMP83]], label [[IF_END_86]], label [[FOR_END_LOOPEXIT:%.*]]
260262
; CHECK: if.end.86:
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
2+
; RUN: opt --passes=slp-vectorizer -S -mtriple=x86_64-unknown-linux-gnu < %s | FileCheck %s
3+
4+
define void @test() {
5+
; CHECK-LABEL: define void @test() {
6+
; CHECK-NEXT: entry:
7+
; CHECK-NEXT: store <8 x i16> zeroinitializer, ptr null, align 2
8+
; CHECK-NEXT: ret void
9+
;
10+
entry:
11+
%arrayidx8 = getelementptr i8, ptr null, i64 2
12+
%shr10 = ashr i32 0, 0
13+
%shr19 = lshr i32 0, 0
14+
%sub20 = or i32 %shr19, %shr10
15+
%xor21 = xor i32 %sub20, 0
16+
%conv22 = trunc i32 %xor21 to i16
17+
store i16 %conv22, ptr %arrayidx8, align 2
18+
%arrayidx28 = getelementptr i8, ptr null, i64 4
19+
%shr34 = lshr i32 0, 0
20+
%sub35 = or i32 %shr34, %shr10
21+
%xor36 = xor i32 %sub35, 0
22+
%conv37 = trunc i32 %xor36 to i16
23+
store i16 %conv37, ptr %arrayidx28, align 2
24+
%arrayidx43 = getelementptr i8, ptr null, i64 6
25+
%shr49 = lshr i32 0, 0
26+
%sub50 = or i32 %shr49, %shr10
27+
%xor51 = xor i32 %sub50, 0
28+
%conv52 = trunc i32 %xor51 to i16
29+
store i16 %conv52, ptr %arrayidx43, align 2
30+
%arrayidx.1 = getelementptr i8, ptr null, i64 8
31+
%shr.1 = lshr i32 0, 0
32+
%xor2.1 = xor i32 %shr.1, %shr10
33+
%sub3.1 = or i32 %xor2.1, 0
34+
%conv4.1 = trunc i32 %sub3.1 to i16
35+
store i16 %conv4.1, ptr %arrayidx.1, align 2
36+
%arrayidx8.1 = getelementptr i8, ptr null, i64 10
37+
%shr10.1 = ashr i32 0, 0
38+
%shr19.1 = lshr i32 0, 0
39+
%sub20.1 = or i32 %shr19.1, %shr10.1
40+
%xor21.1 = xor i32 %sub20.1, 0
41+
%conv22.1 = trunc i32 %xor21.1 to i16
42+
store i16 %conv22.1, ptr %arrayidx8.1, align 2
43+
%arrayidx28.1 = getelementptr i8, ptr null, i64 12
44+
%shr34.1 = lshr i32 0, 0
45+
%sub35.1 = or i32 %shr34.1, %shr10.1
46+
%xor36.1 = xor i32 %sub35.1, 0
47+
%conv37.1 = trunc i32 %xor36.1 to i16
48+
store i16 %conv37.1, ptr %arrayidx28.1, align 2
49+
%arrayidx43.1 = getelementptr i8, ptr null, i64 14
50+
%shr49.1 = lshr i32 0, 0
51+
%sub50.1 = or i32 %shr49.1, %shr10.1
52+
%xor51.1 = xor i32 %sub50.1, 0
53+
%conv52.1 = trunc i32 %xor51.1 to i16
54+
store i16 %conv52.1, ptr %arrayidx43.1, align 2
55+
%shr.2 = lshr i32 0, 0
56+
%xor2.2 = xor i32 %shr.2, %shr10.1
57+
%sub3.2 = or i32 %xor2.2, 0
58+
%conv4.2 = trunc i32 %sub3.2 to i16
59+
store i16 %conv4.2, ptr null, align 2
60+
ret void
61+
}

0 commit comments

Comments
 (0)