Skip to content

Commit 440a8ad

Browse files
committed
[VPlan] Use VPIRFlags to manage FMFs for ComputeReductionResult (NFC).
Manage fast-math flags using VPIRFlags from VPInstruction, in line with other VPInstructions. With this change, we now print the correct flags for ComputeReductionResult; other than that, NFC.
1 parent 8fb09c8 commit 440a8ad

File tree

3 files changed

+16
-11
lines changed

3 files changed

+16
-11
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9576,8 +9576,13 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
95769576
Builder.createNaryOp(VPInstruction::ComputeFindLastIVResult,
95779577
{PhiR, Start, NewExitingVPV}, ExitDL);
95789578
} else {
9579-
FinalReductionResult = Builder.createNaryOp(
9580-
VPInstruction::ComputeReductionResult, {PhiR, NewExitingVPV}, ExitDL);
9579+
VPIRFlags Flags = RecurrenceDescriptor::isFloatingPointRecurrenceKind(
9580+
RdxDesc.getRecurrenceKind())
9581+
? VPIRFlags(RdxDesc.getFastMathFlags())
9582+
: VPIRFlags();
9583+
FinalReductionResult =
9584+
Builder.createNaryOp(VPInstruction::ComputeReductionResult,
9585+
{PhiR, NewExitingVPV}, Flags, ExitDL);
95819586
}
95829587
// Update all users outside the vector region.
95839588
OrigExitingVPV->replaceUsesWithIf(

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,10 @@ Value *VPInstruction::generate(VPTransformState &State) {
648648
for (unsigned Part = 0; Part < UF; ++Part)
649649
RdxParts[Part] = State.get(getOperand(1 + Part), PhiR->isInLoop());
650650

651+
IRBuilderBase::FastMathFlagGuard FMFG(Builder);
652+
if (hasFastMathFlags())
653+
Builder.setFastMathFlags(getFastMathFlags());
654+
651655
// If the vector reduction can be performed in a smaller type, we truncate
652656
// then extend the loop exit value to enable InstCombine to evaluate the
653657
// entire expression in the smaller type.
@@ -663,8 +667,6 @@ Value *VPInstruction::generate(VPTransformState &State) {
663667
ReducedPartRdx = RdxParts[UF - 1];
664668
} else {
665669
// Floating-point operations should have some FMF to enable the reduction.
666-
IRBuilderBase::FastMathFlagGuard FMFG(Builder);
667-
Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
668670
for (unsigned Part = 1; Part < UF; ++Part) {
669671
Value *RdxPart = RdxParts[Part];
670672
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RK))
@@ -684,9 +686,6 @@ Value *VPInstruction::generate(VPTransformState &State) {
684686
// TODO: Support in-order reductions based on the recurrence descriptor.
685687
// All ops in the reduction inherit fast-math-flags from the recurrence
686688
// descriptor.
687-
IRBuilderBase::FastMathFlagGuard FMFG(Builder);
688-
Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
689-
690689
if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK))
691690
ReducedPartRdx =
692691
createAnyOfReduction(Builder, ReducedPartRdx, RdxDesc, OrigPhi);
@@ -1599,7 +1598,8 @@ bool VPIRFlags::flagsValidForOpcode(unsigned Opcode) const {
15991598
Opcode == Instruction::FSub || Opcode == Instruction::FNeg ||
16001599
Opcode == Instruction::FDiv || Opcode == Instruction::FRem ||
16011600
Opcode == Instruction::FCmp || Opcode == Instruction::Select ||
1602-
Opcode == VPInstruction::WideIVStep;
1601+
Opcode == VPInstruction::WideIVStep ||
1602+
Opcode == VPInstruction::ComputeReductionResult;
16031603
case OperationType::NonNegOp:
16041604
return Opcode == Instruction::ZExt;
16051605
break;

llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ define float @print_reduction(i64 %n, ptr noalias %y) {
3434
; CHECK-NEXT: Successor(s): middle.block
3535
; CHECK-EMPTY:
3636
; CHECK-NEXT: middle.block:
37-
; CHECK-NEXT: EMIT vp<[[RED_RES:%.+]]> = compute-reduction-result ir<%red>, ir<%red.next>
37+
; CHECK-NEXT: EMIT vp<[[RED_RES:%.+]]> = compute-reduction-result fast ir<%red>, ir<%red.next>
3838
; CHECK-NEXT: EMIT vp<[[RED_EX:%.+]]> = extract-last-element vp<[[RED_RES]]>
3939
; CHECK-NEXT: EMIT vp<[[CMP:%.+]]> = icmp eq ir<%n>, vp<[[VTC]]>
4040
; CHECK-NEXT: EMIT branch-on-cond vp<[[CMP]]>
@@ -102,7 +102,7 @@ define void @print_reduction_with_invariant_store(i64 %n, ptr noalias %y, ptr no
102102
; CHECK-NEXT: Successor(s): middle.block
103103
; CHECK-EMPTY:
104104
; CHECK-NEXT: middle.block:
105-
; CHECK-NEXT: EMIT vp<[[RED_RES:.+]]> = compute-reduction-result ir<%red>, ir<%red.next>
105+
; CHECK-NEXT: EMIT vp<[[RED_RES:.+]]> = compute-reduction-result fast ir<%red>, ir<%red.next>
106106
; CHECK-NEXT: CLONE store vp<[[RED_RES]]>, ir<%dst>
107107
; CHECK-NEXT: EMIT vp<[[CMP:%.+]]> = icmp eq ir<%n>, vp<[[VTC]]>
108108
; CHECK-NEXT: EMIT branch-on-cond vp<[[CMP]]>
@@ -175,7 +175,7 @@ define float @print_fmuladd_strict(ptr %a, ptr %b, i64 %n) {
175175
; CHECK-NEXT: Successor(s): middle.block
176176
; CHECK-EMPTY:
177177
; CHECK-NEXT: middle.block:
178-
; CHECK-NEXT: EMIT vp<[[RED_RES:%.+]]> = compute-reduction-result ir<%sum.07>, ir<[[MULADD]]>
178+
; CHECK-NEXT: EMIT vp<[[RED_RES:%.+]]> = compute-reduction-result nnan ninf nsz ir<%sum.07>, ir<[[MULADD]]>
179179
; CHECK-NEXT: EMIT vp<[[RED_EX:%.+]]> = extract-last-element vp<[[RED_RES]]>
180180
; CHECK-NEXT: EMIT vp<[[CMP:%.+]]> = icmp eq ir<%n>, vp<[[VTC]]>
181181
; CHECK-NEXT: EMIT branch-on-cond vp<[[CMP]]>

0 commit comments

Comments
 (0)