Skip to content

Commit 315c02a

Browse files
authored
[VPlan] Fix crash with inloop fmuladd reductions with blend (#131154)
When visiting in-loop reduction links, we previously crashed if we had an fmuladd with a blend after it in the chain. This fixes it by lifting the existing blend folding to also handle fmuladd. This also simplifies the code structure slightly for an upcoming patch I want to post to handle in-loop AnyOf reductions. I removed the PhiR->isInLoop() check since it's already guarded at the top of the parent Header->Phis() loop.
1 parent f23bbf6 commit 315c02a

File tree

2 files changed

+102
-20
lines changed

2 files changed

+102
-20
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9713,6 +9713,19 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
97139713
// condition directly.
97149714
VPSingleDefRecipe *PreviousLink = PhiR; // Aka Worklist[0].
97159715
for (VPSingleDefRecipe *CurrentLink : Worklist.getArrayRef().drop_front()) {
9716+
if (auto *Blend = dyn_cast<VPBlendRecipe>(CurrentLink)) {
9717+
assert(Blend->getNumIncomingValues() == 2 &&
9718+
"Blend must have 2 incoming values");
9719+
if (Blend->getIncomingValue(0) == PhiR) {
9720+
Blend->replaceAllUsesWith(Blend->getIncomingValue(1));
9721+
} else {
9722+
assert(Blend->getIncomingValue(1) == PhiR &&
9723+
"PhiR must be an operand of the blend");
9724+
Blend->replaceAllUsesWith(Blend->getIncomingValue(0));
9725+
}
9726+
continue;
9727+
}
9728+
97169729
Instruction *CurrentLinkI = CurrentLink->getUnderlyingInstr();
97179730

97189731
// Index of the first operand which holds a non-mask vector operand.
@@ -9741,20 +9754,6 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
97419754
LinkVPBB->insert(FMulRecipe, CurrentLink->getIterator());
97429755
VecOp = FMulRecipe;
97439756
} else {
9744-
auto *Blend = dyn_cast<VPBlendRecipe>(CurrentLink);
9745-
if (PhiR->isInLoop() && Blend) {
9746-
assert(Blend->getNumIncomingValues() == 2 &&
9747-
"Blend must have 2 incoming values");
9748-
if (Blend->getIncomingValue(0) == PhiR)
9749-
Blend->replaceAllUsesWith(Blend->getIncomingValue(1));
9750-
else {
9751-
assert(Blend->getIncomingValue(1) == PhiR &&
9752-
"PhiR must be an operand of the blend");
9753-
Blend->replaceAllUsesWith(Blend->getIncomingValue(0));
9754-
}
9755-
continue;
9756-
}
9757-
97589757
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind)) {
97599758
if (isa<VPWidenRecipe>(CurrentLink)) {
97609759
assert(isa<CmpInst>(CurrentLinkI) &&

llvm/test/Transforms/LoopVectorize/reduction-inloop.ll

Lines changed: 89 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1106,6 +1106,89 @@ for.end:
11061106
ret float %muladd
11071107
}
11081108

1109+
define float @reduction_fmuladd_blend(ptr %a, ptr %b, i64 %n, i1 %c) {
1110+
; CHECK-LABEL: @reduction_fmuladd_blend(
1111+
; CHECK-NEXT: entry:
1112+
; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[N:%.*]], 4
1113+
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
1114+
; CHECK: vector.ph:
1115+
; CHECK-NEXT: [[N_VEC:%.*]] = and i64 [[N]], -4
1116+
; CHECK-NEXT: [[BROADCAST_SPLATINSERT:%.*]] = insertelement <4 x i1> poison, i1 [[C:%.*]], i64 0
1117+
; CHECK-NEXT: [[TMP0:%.*]] = xor <4 x i1> [[BROADCAST_SPLATINSERT]], <i1 true, i1 poison, i1 poison, i1 poison>
1118+
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x i1> [[TMP0]], <4 x i1> poison, <4 x i32> zeroinitializer
1119+
; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
1120+
; CHECK: vector.body:
1121+
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
1122+
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi float [ 0.000000e+00, [[VECTOR_PH]] ], [ [[TMP7:%.*]], [[VECTOR_BODY]] ]
1123+
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds float, ptr [[A:%.*]], i64 [[INDEX]]
1124+
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x float>, ptr [[TMP2]], align 4
1125+
; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds float, ptr [[B:%.*]], i64 [[INDEX]]
1126+
; CHECK-NEXT: [[WIDE_LOAD1:%.*]] = load <4 x float>, ptr [[TMP3]], align 4
1127+
; CHECK-NEXT: [[TMP4:%.*]] = fmul <4 x float> [[WIDE_LOAD]], [[WIDE_LOAD1]]
1128+
; CHECK-NEXT: [[TMP5:%.*]] = select <4 x i1> [[TMP1]], <4 x float> [[TMP4]], <4 x float> splat (float -0.000000e+00)
1129+
; CHECK-NEXT: [[TMP6:%.*]] = call float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP5]])
1130+
; CHECK-NEXT: [[TMP7]] = fadd float [[TMP6]], [[VEC_PHI]]
1131+
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
1132+
; CHECK-NEXT: [[TMP8:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
1133+
; CHECK-NEXT: br i1 [[TMP8]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP38:![0-9]+]]
1134+
; CHECK: middle.block:
1135+
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[N]], [[N_VEC]]
1136+
; CHECK-NEXT: br i1 [[CMP_N]], label [[FOR_END:%.*]], label [[SCALAR_PH]]
1137+
; CHECK: scalar.ph:
1138+
; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY:%.*]] ]
1139+
; CHECK-NEXT: [[BC_MERGE_RDX:%.*]] = phi float [ [[TMP7]], [[MIDDLE_BLOCK]] ], [ 0.000000e+00, [[ENTRY]] ]
1140+
; CHECK-NEXT: br label [[FOR_BODY:%.*]]
1141+
; CHECK: loop.header:
1142+
; CHECK-NEXT: [[IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ], [ [[IV_NEXT:%.*]], [[LATCH:%.*]] ]
1143+
; CHECK-NEXT: [[SUM:%.*]] = phi float [ [[BC_MERGE_RDX]], [[SCALAR_PH]] ], [ [[SUM_NEXT:%.*]], [[LATCH]] ]
1144+
; CHECK-NEXT: br i1 [[C]], label [[FOO:%.*]], label [[BAR:%.*]]
1145+
; CHECK: if:
1146+
; CHECK-NEXT: br label [[LATCH]]
1147+
; CHECK: else:
1148+
; CHECK-NEXT: [[ARRAYIDX2:%.*]] = getelementptr inbounds float, ptr [[B]], i64 [[IV]]
1149+
; CHECK-NEXT: [[TMP9:%.*]] = load float, ptr [[ARRAYIDX2]], align 4
1150+
; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds float, ptr [[A]], i64 [[IV]]
1151+
; CHECK-NEXT: [[TMP10:%.*]] = load float, ptr [[ARRAYIDX]], align 4
1152+
; CHECK-NEXT: [[MULADD:%.*]] = tail call float @llvm.fmuladd.f32(float [[TMP10]], float [[TMP9]], float [[SUM]])
1153+
; CHECK-NEXT: br label [[LATCH]]
1154+
; CHECK: latch:
1155+
; CHECK-NEXT: [[SUM_NEXT]] = phi float [ [[SUM]], [[FOO]] ], [ [[MULADD]], [[BAR]] ]
1156+
; CHECK-NEXT: [[IV_NEXT]] = add nuw nsw i64 [[IV]], 1
1157+
; CHECK-NEXT: [[EXITCOND_NOT:%.*]] = icmp eq i64 [[IV_NEXT]], [[N]]
1158+
; CHECK-NEXT: br i1 [[EXITCOND_NOT]], label [[FOR_END]], label [[FOR_BODY]], !llvm.loop [[LOOP39:![0-9]+]]
1159+
; CHECK: exit:
1160+
; CHECK-NEXT: [[SUM_NEXT_LCSSA:%.*]] = phi float [ [[SUM_NEXT]], [[LATCH]] ], [ [[TMP7]], [[MIDDLE_BLOCK]] ]
1161+
; CHECK-NEXT: ret float [[SUM_NEXT_LCSSA]]
1162+
;
1163+
entry:
1164+
br label %loop.header
1165+
1166+
loop.header:
1167+
%iv = phi i64 [ 0, %entry ], [ %iv.next, %latch ]
1168+
%sum = phi float [ 0.000000e+00, %entry ], [ %sum.next, %latch ]
1169+
%arrayidx = getelementptr inbounds float, ptr %a, i64 %iv
1170+
%0 = load float, ptr %arrayidx, align 4
1171+
%arrayidx2 = getelementptr inbounds float, ptr %b, i64 %iv
1172+
%1 = load float, ptr %arrayidx2, align 4
1173+
br i1 %c, label %if, label %else
1174+
1175+
if:
1176+
br label %latch
1177+
1178+
else:
1179+
%muladd = tail call float @llvm.fmuladd.f32(float %0, float %1, float %sum)
1180+
br label %latch
1181+
1182+
latch:
1183+
%sum.next = phi float [ %sum, %if ], [ %muladd, %else ]
1184+
%iv.next = add nuw nsw i64 %iv, 1
1185+
%exitcond.not = icmp eq i64 %iv.next, %n
1186+
br i1 %exitcond.not, label %exit, label %loop.header
1187+
1188+
exit:
1189+
ret float %sum.next
1190+
}
1191+
11091192
; This case was previously failing verification due to the mask for the
11101193
; reduction being created after the reduction.
11111194
define i32 @predicated_not_dominates_reduction(ptr nocapture noundef readonly %h, i32 noundef %i) {
@@ -1130,7 +1213,7 @@ define i32 @predicated_not_dominates_reduction(ptr nocapture noundef readonly %h
11301213
; CHECK-NEXT: [[TMP7]] = add i32 [[TMP6]], [[VEC_PHI]]
11311214
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
11321215
; CHECK-NEXT: [[TMP8:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
1133-
; CHECK-NEXT: br i1 [[TMP8]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP38:![0-9]+]]
1216+
; CHECK-NEXT: br i1 [[TMP8]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP40:![0-9]+]]
11341217
; CHECK: middle.block:
11351218
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i32 [[I]], [[N_VEC]]
11361219
; CHECK-NEXT: br i1 [[CMP_N]], label [[FOR_END7:%.*]], label [[SCALAR_PH]]
@@ -1157,7 +1240,7 @@ define i32 @predicated_not_dominates_reduction(ptr nocapture noundef readonly %h
11571240
; CHECK-NEXT: [[G_1]] = phi i32 [ [[ADD]], [[IF_THEN]] ], [ [[G_016]], [[FOR_BODY2]] ]
11581241
; CHECK-NEXT: [[INC6]] = add nuw nsw i32 [[A_117]], 1
11591242
; CHECK-NEXT: [[EXITCOND_NOT:%.*]] = icmp eq i32 [[INC6]], [[I]]
1160-
; CHECK-NEXT: br i1 [[EXITCOND_NOT]], label [[FOR_END7]], label [[FOR_BODY2]], !llvm.loop [[LOOP39:![0-9]+]]
1243+
; CHECK-NEXT: br i1 [[EXITCOND_NOT]], label [[FOR_END7]], label [[FOR_BODY2]], !llvm.loop [[LOOP41:![0-9]+]]
11611244
; CHECK: for.end7:
11621245
; CHECK-NEXT: [[G_1_LCSSA:%.*]] = phi i32 [ [[G_1]], [[FOR_INC5]] ], [ [[TMP7]], [[MIDDLE_BLOCK]] ]
11631246
; CHECK-NEXT: ret i32 [[G_1_LCSSA]]
@@ -1219,7 +1302,7 @@ define i32 @predicated_not_dominates_reduction_twoadd(ptr nocapture noundef read
12191302
; CHECK-NEXT: [[TMP11]] = add i32 [[TMP10]], [[TMP8]]
12201303
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
12211304
; CHECK-NEXT: [[TMP12:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
1222-
; CHECK-NEXT: br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP40:![0-9]+]]
1305+
; CHECK-NEXT: br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP42:![0-9]+]]
12231306
; CHECK: middle.block:
12241307
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i32 [[I]], [[N_VEC]]
12251308
; CHECK-NEXT: br i1 [[CMP_N]], label [[FOR_END7:%.*]], label [[SCALAR_PH]]
@@ -1247,7 +1330,7 @@ define i32 @predicated_not_dominates_reduction_twoadd(ptr nocapture noundef read
12471330
; CHECK-NEXT: [[G_1]] = phi i32 [ [[ADD]], [[IF_THEN]] ], [ [[G_016]], [[FOR_BODY2]] ]
12481331
; CHECK-NEXT: [[INC6]] = add nuw nsw i32 [[A_117]], 1
12491332
; CHECK-NEXT: [[EXITCOND_NOT:%.*]] = icmp eq i32 [[INC6]], [[I]]
1250-
; CHECK-NEXT: br i1 [[EXITCOND_NOT]], label [[FOR_END7]], label [[FOR_BODY2]], !llvm.loop [[LOOP41:![0-9]+]]
1333+
; CHECK-NEXT: br i1 [[EXITCOND_NOT]], label [[FOR_END7]], label [[FOR_BODY2]], !llvm.loop [[LOOP43:![0-9]+]]
12511334
; CHECK: for.end7:
12521335
; CHECK-NEXT: [[G_1_LCSSA:%.*]] = phi i32 [ [[G_1]], [[FOR_INC5]] ], [ [[TMP11]], [[MIDDLE_BLOCK]] ]
12531336
; CHECK-NEXT: ret i32 [[G_1_LCSSA]]
@@ -1362,7 +1445,7 @@ define i32 @predicated_or_dominates_reduction(ptr %b) {
13621445
; CHECK-NEXT: [[TMP48]] = add i32 [[VEC_PHI]], [[TMP47]]
13631446
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
13641447
; CHECK-NEXT: [[TMP49:%.*]] = icmp eq i32 [[INDEX_NEXT]], 1000
1365-
; CHECK-NEXT: br i1 [[TMP49]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP42:![0-9]+]]
1448+
; CHECK-NEXT: br i1 [[TMP49]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP44:![0-9]+]]
13661449
; CHECK: middle.block:
13671450
; CHECK-NEXT: br i1 true, label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
13681451
; CHECK: scalar.ph:
@@ -1377,7 +1460,7 @@ define i32 @predicated_or_dominates_reduction(ptr %b) {
13771460
; CHECK: if.then:
13781461
; CHECK-NEXT: br label [[FOR_INC]]
13791462
; CHECK: for.inc:
1380-
; CHECK-NEXT: br i1 poison, label [[FOR_COND_CLEANUP]], label [[FOR_BODY]], !llvm.loop [[LOOP43:![0-9]+]]
1463+
; CHECK-NEXT: br i1 poison, label [[FOR_COND_CLEANUP]], label [[FOR_BODY]], !llvm.loop [[LOOP45:![0-9]+]]
13811464
;
13821465
entry:
13831466
br label %for.body

0 commit comments

Comments
 (0)