Commit f8e1337

[SLP] Support internal users of splat loads
Until now we would accept a broadcast load pattern only if it was used by a single vector of instructions. This patch relaxes this and allows the broadcast load to have more than one user vector, as long as all of its uses are internal to the SLP graph and vectorized.

Differential Revision: https://reviews.llvm.org/D121940
1 parent: 222adf3
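For illustration, here is a minimal IR sketch of the pattern this change newly accepts (the function and value names are hypothetical, not taken from the patch): the scalar load %ld has more uses than there are lanes, but every use sits in an instruction pair the SLP vectorizer can form a tree entry for, so treating %ld as a broadcast load requires no extraction.

; Minimal sketch (hypothetical test, assuming a 2-lane SLP graph): %ld has
; four uses, more than the two lanes, but all of them are internal to the
; vectorizable pairs {%mul0, %mul1} and {%add0, %add1}.
define double @splat_load_internal_users(double* %a, double* %b) {
entry:
  %gep_a_0 = getelementptr inbounds double, double* %a, i64 0
  %gep_a_1 = getelementptr inbounds double, double* %a, i64 1
  %a0 = load double, double* %gep_a_0, align 8
  %a1 = load double, double* %gep_a_1, align 8
  %ld = load double, double* %b, align 8          ; splat-load candidate
  %mul0 = fmul double %a0, %ld                    ; first user pair
  %mul1 = fmul double %a1, %ld
  %add0 = fadd double %mul0, %ld                  ; second user pair, also uses %ld
  %add1 = fadd double %mul1, %ld
  %res = fadd double %add0, %add1
  ret double %res
}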

2 files changed: 73 additions & 34 deletions

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 29 additions & 12 deletions
@@ -1167,16 +1167,30 @@ class BoUpSLP {
   /// \returns the score of placing \p V1 and \p V2 in consecutive lanes.
   /// Also, checks if \p V1 and \p V2 are compatible with instructions in \p
   /// MainAltOps.
-  static int getShallowScore(Value *V1, Value *V2, const DataLayout &DL,
-                             ScalarEvolution &SE, int NumLanes,
-                             ArrayRef<Value *> MainAltOps,
-                             const TargetTransformInfo *TTI) {
+  int getShallowScore(Value *V1, Value *V2, Instruction *U1, Instruction *U2,
+                      const DataLayout &DL, ScalarEvolution &SE, int NumLanes,
+                      ArrayRef<Value *> MainAltOps) {
     if (V1 == V2) {
       if (isa<LoadInst>(V1)) {
+        // Returns true if the users of V1 and V2 won't need to be extracted.
+        auto AllUsersAreInternal = [NumLanes, U1, U2, this](Value *V1,
+                                                            Value *V2) {
+          // Bail out if we have too many uses to save compilation time.
+          static constexpr unsigned Limit = 8;
+          if (V1->hasNUsesOrMore(Limit) || V2->hasNUsesOrMore(Limit))
+            return false;
+
+          auto AllUsersVectorized = [U1, U2, this](Value *V) {
+            return llvm::all_of(V->users(), [U1, U2, this](Value *U) {
+              return U == U1 || U == U2 || R.getTreeEntry(U) != nullptr;
+            });
+          };
+          return AllUsersVectorized(V1) && AllUsersVectorized(V2);
+        };
         // A broadcast of a load can be cheaper on some targets.
-        // TODO: For now accept a broadcast load with no other internal uses.
-        if (TTI->isLegalBroadcastLoad(V1->getType(), NumLanes) &&
-            (int)V1->getNumUses() == NumLanes)
+        if (R.TTI->isLegalBroadcastLoad(V1->getType(), NumLanes) &&
+            ((int)V1->getNumUses() == NumLanes ||
+             AllUsersAreInternal(V1, V2)))
           return VLOperands::ScoreSplatLoads;
       }
       return VLOperands::ScoreSplat;
@@ -1354,12 +1368,13 @@ class BoUpSLP {
   /// Look-ahead SLP: Auto-vectorization in the presence of commutative
   /// operations, CGO 2018 by Vasileios Porpodas, Rodrigo C. O. Rocha,
   /// Luís F. W. Góes
-  int getScoreAtLevelRec(Value *LHS, Value *RHS, int CurrLevel, int MaxLevel,
+  int getScoreAtLevelRec(Value *LHS, Value *RHS, Instruction *U1,
+                         Instruction *U2, int CurrLevel, int MaxLevel,
                          ArrayRef<Value *> MainAltOps) {
 
     // Get the shallow score of V1 and V2.
     int ShallowScoreAtThisLevel =
-        getShallowScore(LHS, RHS, DL, SE, getNumLanes(), MainAltOps, R.TTI);
+        getShallowScore(LHS, RHS, U1, U2, DL, SE, getNumLanes(), MainAltOps);
 
     // If reached MaxLevel,
     // or if V1 and V2 are not instructions,
@@ -1402,7 +1417,7 @@ class BoUpSLP {
         // Recursively calculate the cost at each level
         int TmpScore =
             getScoreAtLevelRec(I1->getOperand(OpIdx1), I2->getOperand(OpIdx2),
-                               CurrLevel + 1, MaxLevel, None);
+                               I1, I2, CurrLevel + 1, MaxLevel, None);
         // Look for the best score.
         if (TmpScore > VLOperands::ScoreFail && TmpScore > MaxTmpScore) {
           MaxTmpScore = TmpScore;
@@ -1432,8 +1447,10 @@ class BoUpSLP {
   int getLookAheadScore(Value *LHS, Value *RHS, ArrayRef<Value *> MainAltOps,
                         int Lane, unsigned OpIdx, unsigned Idx,
                         bool &IsUsed) {
-    int Score =
-        getScoreAtLevelRec(LHS, RHS, 1, LookAheadMaxDepth, MainAltOps);
+    // Keep track of the instruction stack as we recurse into the operands
+    // during the look-ahead score exploration.
+    int Score = getScoreAtLevelRec(LHS, RHS, /*U1=*/nullptr, /*U2=*/nullptr,
+                                   1, LookAheadMaxDepth, MainAltOps);
     if (Score) {
       int SplatScore = getSplatScore(Lane, OpIdx, Idx);
       if (Score <= -SplatScore) {

llvm/test/Transforms/SLPVectorizer/X86/lookahead.ll

Lines changed: 44 additions & 22 deletions
@@ -781,28 +781,50 @@ entry:
 
 ; Same as splat_loads() but the splat load has internal uses in the slp graph.
 define double @splat_loads_with_internal_uses(double *%array1, double *%array2, double *%ptrA, double *%ptrB) {
-; CHECK-LABEL: @splat_loads_with_internal_uses(
-; CHECK-NEXT: entry:
-; CHECK-NEXT: [[GEP_1_0:%.*]] = getelementptr inbounds double, double* [[ARRAY1:%.*]], i64 0
-; CHECK-NEXT: [[GEP_2_0:%.*]] = getelementptr inbounds double, double* [[ARRAY2:%.*]], i64 0
-; CHECK-NEXT: [[TMP0:%.*]] = bitcast double* [[GEP_1_0]] to <2 x double>*
-; CHECK-NEXT: [[TMP1:%.*]] = load <2 x double>, <2 x double>* [[TMP0]], align 8
-; CHECK-NEXT: [[TMP2:%.*]] = bitcast double* [[GEP_2_0]] to <2 x double>*
-; CHECK-NEXT: [[TMP3:%.*]] = load <2 x double>, <2 x double>* [[TMP2]], align 8
-; CHECK-NEXT: [[SHUFFLE:%.*]] = shufflevector <2 x double> [[TMP3]], <2 x double> poison, <2 x i32> <i32 1, i32 0>
-; CHECK-NEXT: [[TMP4:%.*]] = fmul <2 x double> [[TMP1]], [[SHUFFLE]]
-; CHECK-NEXT: [[TMP5:%.*]] = extractelement <2 x double> [[SHUFFLE]], i32 1
-; CHECK-NEXT: [[TMP6:%.*]] = insertelement <2 x double> poison, double [[TMP5]], i32 0
-; CHECK-NEXT: [[TMP7:%.*]] = extractelement <2 x double> [[SHUFFLE]], i32 0
-; CHECK-NEXT: [[TMP8:%.*]] = insertelement <2 x double> [[TMP6]], double [[TMP7]], i32 1
-; CHECK-NEXT: [[TMP9:%.*]] = fmul <2 x double> [[TMP1]], [[TMP8]]
-; CHECK-NEXT: [[TMP10:%.*]] = fadd <2 x double> [[TMP4]], [[TMP9]]
-; CHECK-NEXT: [[TMP11:%.*]] = insertelement <2 x double> [[TMP6]], double [[TMP5]], i32 1
-; CHECK-NEXT: [[TMP12:%.*]] = fsub <2 x double> [[TMP10]], [[TMP11]]
-; CHECK-NEXT: [[TMP13:%.*]] = extractelement <2 x double> [[TMP12]], i32 0
-; CHECK-NEXT: [[TMP14:%.*]] = extractelement <2 x double> [[TMP12]], i32 1
-; CHECK-NEXT: [[RES:%.*]] = fadd double [[TMP13]], [[TMP14]]
-; CHECK-NEXT: ret double [[RES]]
+; SSE-LABEL: @splat_loads_with_internal_uses(
+; SSE-NEXT: entry:
+; SSE-NEXT: [[GEP_1_0:%.*]] = getelementptr inbounds double, double* [[ARRAY1:%.*]], i64 0
+; SSE-NEXT: [[GEP_2_0:%.*]] = getelementptr inbounds double, double* [[ARRAY2:%.*]], i64 0
+; SSE-NEXT: [[TMP0:%.*]] = bitcast double* [[GEP_1_0]] to <2 x double>*
+; SSE-NEXT: [[TMP1:%.*]] = load <2 x double>, <2 x double>* [[TMP0]], align 8
+; SSE-NEXT: [[TMP2:%.*]] = bitcast double* [[GEP_2_0]] to <2 x double>*
+; SSE-NEXT: [[TMP3:%.*]] = load <2 x double>, <2 x double>* [[TMP2]], align 8
+; SSE-NEXT: [[SHUFFLE:%.*]] = shufflevector <2 x double> [[TMP3]], <2 x double> poison, <2 x i32> <i32 1, i32 0>
+; SSE-NEXT: [[TMP4:%.*]] = fmul <2 x double> [[TMP1]], [[SHUFFLE]]
+; SSE-NEXT: [[TMP5:%.*]] = extractelement <2 x double> [[SHUFFLE]], i32 1
+; SSE-NEXT: [[TMP6:%.*]] = insertelement <2 x double> poison, double [[TMP5]], i32 0
+; SSE-NEXT: [[TMP7:%.*]] = extractelement <2 x double> [[SHUFFLE]], i32 0
+; SSE-NEXT: [[TMP8:%.*]] = insertelement <2 x double> [[TMP6]], double [[TMP7]], i32 1
+; SSE-NEXT: [[TMP9:%.*]] = fmul <2 x double> [[TMP1]], [[TMP8]]
+; SSE-NEXT: [[TMP10:%.*]] = fadd <2 x double> [[TMP4]], [[TMP9]]
+; SSE-NEXT: [[TMP11:%.*]] = insertelement <2 x double> [[TMP6]], double [[TMP5]], i32 1
+; SSE-NEXT: [[TMP12:%.*]] = fsub <2 x double> [[TMP10]], [[TMP11]]
+; SSE-NEXT: [[TMP13:%.*]] = extractelement <2 x double> [[TMP12]], i32 0
+; SSE-NEXT: [[TMP14:%.*]] = extractelement <2 x double> [[TMP12]], i32 1
+; SSE-NEXT: [[RES:%.*]] = fadd double [[TMP13]], [[TMP14]]
+; SSE-NEXT: ret double [[RES]]
+;
+; AVX-LABEL: @splat_loads_with_internal_uses(
+; AVX-NEXT: entry:
+; AVX-NEXT: [[GEP_1_0:%.*]] = getelementptr inbounds double, double* [[ARRAY1:%.*]], i64 0
+; AVX-NEXT: [[GEP_2_0:%.*]] = getelementptr inbounds double, double* [[ARRAY2:%.*]], i64 0
+; AVX-NEXT: [[GEP_2_1:%.*]] = getelementptr inbounds double, double* [[ARRAY2]], i64 1
+; AVX-NEXT: [[LD_2_0:%.*]] = load double, double* [[GEP_2_0]], align 8
+; AVX-NEXT: [[LD_2_1:%.*]] = load double, double* [[GEP_2_1]], align 8
+; AVX-NEXT: [[TMP0:%.*]] = bitcast double* [[GEP_1_0]] to <2 x double>*
+; AVX-NEXT: [[TMP1:%.*]] = load <2 x double>, <2 x double>* [[TMP0]], align 8
+; AVX-NEXT: [[TMP2:%.*]] = insertelement <2 x double> poison, double [[LD_2_0]], i32 0
+; AVX-NEXT: [[TMP3:%.*]] = insertelement <2 x double> [[TMP2]], double [[LD_2_0]], i32 1
+; AVX-NEXT: [[TMP4:%.*]] = fmul <2 x double> [[TMP1]], [[TMP3]]
+; AVX-NEXT: [[TMP5:%.*]] = insertelement <2 x double> poison, double [[LD_2_1]], i32 0
+; AVX-NEXT: [[TMP6:%.*]] = insertelement <2 x double> [[TMP5]], double [[LD_2_1]], i32 1
+; AVX-NEXT: [[TMP7:%.*]] = fmul <2 x double> [[TMP1]], [[TMP6]]
+; AVX-NEXT: [[TMP8:%.*]] = fadd <2 x double> [[TMP4]], [[TMP7]]
+; AVX-NEXT: [[TMP9:%.*]] = fsub <2 x double> [[TMP8]], [[TMP3]]
+; AVX-NEXT: [[TMP10:%.*]] = extractelement <2 x double> [[TMP9]], i32 0
+; AVX-NEXT: [[TMP11:%.*]] = extractelement <2 x double> [[TMP9]], i32 1
+; AVX-NEXT: [[RES:%.*]] = fadd double [[TMP10]], [[TMP11]]
+; AVX-NEXT: ret double [[RES]]
 ;
 entry:
   %gep_1_0 = getelementptr inbounds double, double* %array1, i64 0
