Skip to content

[VPlan] Make VPReductionRecipe a VPRecipeWithIRFlags. NFC #130881

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

ElvisWang123
Copy link
Contributor

This patch change the parent of the VPReductionRecipe from
VPSingleDefRecipe to VPRecipeWithIRFlags and also print/get/drop/control
flags by the VPRecipeWithIRFlags. This will remove the dependency of the
underlying instruction.

This patch also add a new function setFastMathFlags() to the
VPRecipeWithIRFlags because the entire reduction chain may contains
multiple instructions. And the underlying instruction may not contains
the corresponding flags for this reduction.

Based on #130880. Split from #113903.

@llvmbot
Copy link
Member

llvmbot commented Mar 12, 2025

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-vectorizers

Author: Elvis Wang (ElvisWang123)

Changes

This patch change the parent of the VPReductionRecipe from
VPSingleDefRecipe to VPRecipeWithIRFlags and also print/get/drop/control
flags by the VPRecipeWithIRFlags. This will remove the dependency of the
underlying instruction.

This patch also add a new function setFastMathFlags() to the
VPRecipeWithIRFlags because the entire reduction chain may contains
multiple instructions. And the underlying instruction may not contains
the corresponding flags for this reduction.

Based on #130880. Split from #113903.


Full diff: https://github.com/llvm/llvm-project/pull/130881.diff

6 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+3-3)
  • (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+20-8)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+16-10)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/widen-call-with-intrinsic-or-libfunc.ll (+2-2)
  • (modified) llvm/test/Transforms/LoopVectorize/vplan-printing.ll (+3-3)
  • (modified) llvm/unittests/Transforms/Vectorize/VPlanTest.cpp (+16-4)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index cb860a472d8f7..131f2e245c461 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9799,9 +9799,9 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
       if (CM.blockNeedsPredicationForAnyReason(BB))
         CondOp = RecipeBuilder.getBlockInMask(BB);
 
-      auto *RedRecipe = new VPReductionRecipe(
-          RdxDesc, CurrentLinkI, PreviousLink, VecOp, CondOp,
-          CM.useOrderedReductions(RdxDesc), CurrentLinkI->getDebugLoc());
+      auto *RedRecipe =
+          new VPReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink, VecOp,
+                                CondOp, CM.useOrderedReductions(RdxDesc));
       // Append the recipe to the end of the VPBasicBlock because we need to
       // ensure that it comes after all of it's inputs, including CondOp.
       // Delete CurrentLink as it will be invalid if its operand is replaced
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index b1288c42b20f2..46597137a5668 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -614,6 +614,7 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
     char AllowReciprocal : 1;
     char AllowContract : 1;
     char ApproxFunc : 1;
+    char Fast : 1;
 
     FastMathFlagsTy(const FastMathFlags &FMF);
   };
@@ -713,6 +714,8 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
            R->getVPDefID() == VPRecipeBase::VPWidenGEPSC ||
            R->getVPDefID() == VPRecipeBase::VPWidenCastSC ||
            R->getVPDefID() == VPRecipeBase::VPWidenIntrinsicSC ||
+           R->getVPDefID() == VPRecipeBase::VPReductionSC ||
+           R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
            R->getVPDefID() == VPRecipeBase::VPReplicateSC ||
            R->getVPDefID() == VPRecipeBase::VPReverseVectorPointerSC ||
            R->getVPDefID() == VPRecipeBase::VPVectorPointerSC;
@@ -788,6 +791,12 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
     }
   }
 
+  /// Set fast-math flags for this recipe.
+  void setFastMathFlags(FastMathFlags FMFs) {
+    OpType = OperationType::FPMathOp;
+    this->FMFs = FMFs;
+  }
+
   CmpInst::Predicate getPredicate() const {
     assert(OpType == OperationType::Cmp &&
            "recipe doesn't have a compare predicate");
@@ -2286,7 +2295,7 @@ class VPInterleaveRecipe : public VPRecipeBase {
 /// A recipe to represent inloop reduction operations, performing a reduction on
 /// a vector operand into a scalar value, and adding the result to a chain.
 /// The Operands are {ChainOp, VecOp, [Condition]}.
-class VPReductionRecipe : public VPSingleDefRecipe {
+class VPReductionRecipe : public VPRecipeWithIRFlags {
   /// The recurrence decriptor for the reduction in question.
   const RecurrenceDescriptor &RdxDesc;
   bool IsOrdered;
@@ -2296,29 +2305,32 @@ class VPReductionRecipe : public VPSingleDefRecipe {
 protected:
   VPReductionRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
                     Instruction *I, ArrayRef<VPValue *> Operands,
-                    VPValue *CondOp, bool IsOrdered, DebugLoc DL)
-      : VPSingleDefRecipe(SC, Operands, I, DL), RdxDesc(R),
+                    VPValue *CondOp, bool IsOrdered)
+      : VPRecipeWithIRFlags(SC, Operands, *I), RdxDesc(R),
         IsOrdered(IsOrdered) {
     if (CondOp) {
       IsConditional = true;
       addOperand(CondOp);
     }
+    // The inloop reduction may across multiple scalar instruction and the
+    // underlying instruction may not contains the corresponding flags. Set the
+    // flags explicit from the redurrence descriptor.
+    setFastMathFlags(R.getFastMathFlags());
   }
 
 public:
   VPReductionRecipe(const RecurrenceDescriptor &R, Instruction *I,
                     VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
-                    bool IsOrdered, DebugLoc DL = {})
+                    bool IsOrdered)
       : VPReductionRecipe(VPDef::VPReductionSC, R, I,
                           ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
-                          IsOrdered, DL) {}
+                          IsOrdered) {}
 
   ~VPReductionRecipe() override = default;
 
   VPReductionRecipe *clone() override {
     return new VPReductionRecipe(RdxDesc, getUnderlyingInstr(), getChainOp(),
-                                 getVecOp(), getCondOp(), IsOrdered,
-                                 getDebugLoc());
+                                 getVecOp(), getCondOp(), IsOrdered);
   }
 
   static inline bool classof(const VPRecipeBase *R) {
@@ -2373,7 +2385,7 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
             VPDef::VPReductionEVLSC, R.getRecurrenceDescriptor(),
             cast_or_null<Instruction>(R.getUnderlyingValue()),
             ArrayRef<VPValue *>({R.getChainOp(), R.getVecOp(), &EVL}), CondOp,
-            R.isOrdered(), R.getDebugLoc()) {}
+            R.isOrdered()) {}
 
   ~VPReductionEVLRecipe() override = default;
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index d154d54c37862..009bc7a23f793 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -360,6 +360,7 @@ FastMathFlags VPRecipeWithIRFlags::getFastMathFlags() const {
   assert(OpType == OperationType::FPMathOp &&
          "recipe doesn't have fast math flags");
   FastMathFlags Res;
+  Res.setFast(FMFs.Fast);
   Res.setAllowReassoc(FMFs.AllowReassoc);
   Res.setNoNaNs(FMFs.NoNaNs);
   Res.setNoInfs(FMFs.NoInfs);
@@ -1393,6 +1394,7 @@ VPRecipeWithIRFlags::FastMathFlagsTy::FastMathFlagsTy(
   AllowReciprocal = FMF.allowReciprocal();
   AllowContract = FMF.allowContract();
   ApproxFunc = FMF.approxFunc();
+  Fast = FMF.isFast();
 }
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -2224,7 +2226,8 @@ void VPReductionRecipe::execute(VPTransformState &State) {
   RecurKind Kind = RdxDesc.getRecurrenceKind();
   // Propagate the fast-math flags carried by the underlying instruction.
   IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder);
-  State.Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
+  if (hasFastMathFlags())
+    State.Builder.setFastMathFlags(getFastMathFlags());
   State.setDebugLocFrom(getDebugLoc());
   Value *NewVecOp = State.get(getVecOp());
   if (VPValue *Cond = getCondOp()) {
@@ -2275,7 +2278,8 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
   // Propagate the fast-math flags carried by the underlying instruction.
   IRBuilderBase::FastMathFlagGuard FMFGuard(Builder);
   const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
-  Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
+  if (hasFastMathFlags())
+    Builder.setFastMathFlags(getFastMathFlags());
 
   RecurKind Kind = RdxDesc.getRecurrenceKind();
   Value *Prev = State.get(getChainOp(), /*IsScalar*/ true);
@@ -2312,6 +2316,8 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
   Type *ElementTy = Ctx.Types.inferScalarType(this);
   auto *VectorTy = cast<VectorType>(toVectorTy(ElementTy, VF));
   unsigned Opcode = RdxDesc.getOpcode();
+  FastMathFlags FMFs =
+      hasFastMathFlags() ? getFastMathFlags() : FastMathFlags();
 
   // TODO: Support any-of and in-loop reductions.
   assert(
@@ -2331,12 +2337,12 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
       Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, Ctx.CostKind);
   if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
     Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
-    return Cost + Ctx.TTI.getMinMaxReductionCost(
-                      Id, VectorTy, RdxDesc.getFastMathFlags(), Ctx.CostKind);
+    return Cost +
+           Ctx.TTI.getMinMaxReductionCost(Id, VectorTy, FMFs, Ctx.CostKind);
   }
 
-  return Cost + Ctx.TTI.getArithmeticReductionCost(
-                    Opcode, VectorTy, RdxDesc.getFastMathFlags(), Ctx.CostKind);
+  return Cost + Ctx.TTI.getArithmeticReductionCost(Opcode, VectorTy, FMFs,
+                                                   Ctx.CostKind);
 }
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -2347,8 +2353,8 @@ void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
   O << " = ";
   getChainOp()->printAsOperand(O, SlotTracker);
   O << " +";
-  if (isa<FPMathOperator>(getUnderlyingInstr()))
-    O << getUnderlyingInstr()->getFastMathFlags();
+  if (isa_and_nonnull<FPMathOperator>(getUnderlyingValue()))
+    printFlags(O);
   O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
   getVecOp()->printAsOperand(O, SlotTracker);
   if (isConditional()) {
@@ -2369,8 +2375,8 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
   O << " = ";
   getChainOp()->printAsOperand(O, SlotTracker);
   O << " +";
-  if (isa<FPMathOperator>(getUnderlyingInstr()))
-    O << getUnderlyingInstr()->getFastMathFlags();
+  if (isa_and_nonnull<FPMathOperator>(getUnderlyingValue()))
+    printFlags(O);
   O << " vp.reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
   getVecOp()->printAsOperand(O, SlotTracker);
   O << ", ";
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/widen-call-with-intrinsic-or-libfunc.ll b/llvm/test/Transforms/LoopVectorize/AArch64/widen-call-with-intrinsic-or-libfunc.ll
index a119707bed120..89bb495045e41 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/widen-call-with-intrinsic-or-libfunc.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/widen-call-with-intrinsic-or-libfunc.ll
@@ -26,7 +26,7 @@ target triple = "arm64-apple-ios"
 ; CHECK-NEXT:     vp<[[VEC_PTR:%.+]]> = vector-pointer ir<%gep.src>
 ; CHECK-NEXT:     WIDEN ir<%l> = load vp<[[VEC_PTR]]>
 ; CHECK-NEXT:     WIDEN-CAST ir<%conv> = fpext ir<%l> to double
-; CHECK-NEXT:     WIDEN-CALL ir<%s> = call reassoc nnan ninf nsz arcp contract afn @llvm.sin.f64(ir<%conv>) (using library function: __simd_sin_v2f64)
+; CHECK-NEXT:     WIDEN-CALL ir<%s> = call fast @llvm.sin.f64(ir<%conv>) (using library function: __simd_sin_v2f64)
 ; CHECK-NEXT:     REPLICATE ir<%gep.dst> = getelementptr inbounds ir<%dst>, vp<[[STEPS]]>
 ; CHECK-NEXT:     REPLICATE store ir<%s>, ir<%gep.dst>
 ; CHECK-NEXT:     EMIT vp<[[CAN_IV_NEXT:%.+]]> = add nuw vp<[[CAN_IV]]>, vp<[[VFxUF]]>
@@ -72,7 +72,7 @@ target triple = "arm64-apple-ios"
 ; CHECK-NEXT:     vp<[[VEC_PTR:%.+]]> = vector-pointer ir<%gep.src>
 ; CHECK-NEXT:     WIDEN ir<%l> = load vp<[[VEC_PTR]]>
 ; CHECK-NEXT:     WIDEN-CAST ir<%conv> = fpext ir<%l> to double
-; CHECK-NEXT:     WIDEN-INTRINSIC ir<%s> = call reassoc nnan ninf nsz arcp contract afn llvm.sin(ir<%conv>)
+; CHECK-NEXT:     WIDEN-INTRINSIC ir<%s> = call fast llvm.sin(ir<%conv>)
 ; CHECK-NEXT:     REPLICATE ir<%gep.dst> = getelementptr inbounds ir<%dst>, vp<[[STEPS]]>
 ; CHECK-NEXT:     REPLICATE store ir<%s>, ir<%gep.dst>
 ; CHECK-NEXT:     EMIT vp<[[CAN_IV_NEXT:%.+]]> = add nuw vp<[[CAN_IV]]>, vp<[[VFxUF]]>
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
index 00d8de67a3b40..207cb8b4a0d30 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-printing.ll
@@ -800,7 +800,7 @@ define void @print_fast_math_flags(i64 %n, ptr noalias %y, ptr noalias %x, ptr %
 ; CHECK-NEXT:   vp<[[VEC_PTR:%.+]]> = vector-pointer ir<%gep.y>
 ; CHECK-NEXT:   WIDEN ir<%lv> = load vp<[[VEC_PTR]]>
 ; CHECK-NEXT:   WIDEN ir<%add> = fadd nnan ir<%lv>, ir<1.000000e+00>
-; CHECK-NEXT:   WIDEN ir<%mul> = fmul reassoc nnan ninf nsz arcp contract afn ir<%add>, ir<2.000000e+00>
+; CHECK-NEXT:   WIDEN ir<%mul> = fmul fast ir<%add>, ir<2.000000e+00>
 ; CHECK-NEXT:   WIDEN ir<%div> = fdiv reassoc nsz contract ir<%mul>, ir<2.000000e+00>
 ; CHECK-NEXT:   CLONE ir<%gep.x> = getelementptr inbounds ir<%x>, vp<[[STEPS]]>
 ; CHECK-NEXT:   vp<[[VEC_PTR:%.+]]> = vector-pointer ir<%gep.x>
@@ -1224,8 +1224,8 @@ define void @print_select_with_fastmath_flags(ptr noalias %a, ptr noalias %b, pt
 ; CHECK-NEXT:     vp<[[PTR2:%.+]]> = vector-pointer ir<[[GEP2]]>
 ; CHECK-NEXT:     WIDEN ir<[[LD2:%.+]]> = load vp<[[PTR2]]>
 ; CHECK-NEXT:     WIDEN ir<[[FCMP:%.+]]> = fcmp ogt ir<[[LD1]]>, ir<[[LD2]]>
-; CHECK-NEXT:     WIDEN ir<[[FADD:%.+]]> = fadd reassoc nnan ninf nsz arcp contract afn ir<[[LD1]]>, ir<1.000000e+01>
-; CHECK-NEXT:     WIDEN-SELECT ir<[[SELECT:%.+]]> = select reassoc nnan ninf nsz arcp contract afn ir<[[FCMP]]>, ir<[[FADD]]>, ir<[[LD2]]>
+; CHECK-NEXT:     WIDEN ir<[[FADD:%.+]]> = fadd fast ir<[[LD1]]>, ir<1.000000e+01>
+; CHECK-NEXT:     WIDEN-SELECT ir<[[SELECT:%.+]]> = select fast ir<[[FCMP]]>, ir<[[FADD]]>, ir<[[LD2]]>
 ; CHECK-NEXT:     CLONE ir<[[GEP3:%.+]]> = getelementptr inbounds nuw ir<%a>, vp<[[ST]]>
 ; CHECK-NEXT:     vp<[[PTR3:%.+]]> = vector-pointer ir<[[GEP3]]>
 ; CHECK-NEXT:     WIDEN store vp<[[PTR3]]>, ir<[[SELECT]]>
diff --git a/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp b/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
index f9a85869e3142..da6c490d7af6a 100644
--- a/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
@@ -1165,22 +1165,27 @@ TEST_F(VPRecipeTest, MayHaveSideEffectsAndMayReadWriteMemory) {
   }
 
   {
+    auto *Add = BinaryOperator::CreateAdd(PoisonValue::get(Int32),
+                                          PoisonValue::get(Int32));
     VPValue *ChainOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 1));
     VPValue *VecOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 2));
     VPValue *CondOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 3));
-    VPReductionRecipe Recipe(RecurrenceDescriptor(), nullptr, ChainOp, CondOp,
+    VPReductionRecipe Recipe(RecurrenceDescriptor(), Add, ChainOp, CondOp,
                              VecOp, false);
     EXPECT_FALSE(Recipe.mayHaveSideEffects());
     EXPECT_FALSE(Recipe.mayReadFromMemory());
     EXPECT_FALSE(Recipe.mayWriteToMemory());
     EXPECT_FALSE(Recipe.mayReadOrWriteMemory());
+    delete Add;
   }
 
   {
+    auto *Add = BinaryOperator::CreateAdd(PoisonValue::get(Int32),
+                                          PoisonValue::get(Int32));
     VPValue *ChainOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 1));
     VPValue *VecOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 2));
     VPValue *CondOp = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 3));
-    VPReductionRecipe Recipe(RecurrenceDescriptor(), nullptr, ChainOp, CondOp,
+    VPReductionRecipe Recipe(RecurrenceDescriptor(), Add, ChainOp, CondOp,
                              VecOp, false);
     VPValue *EVL = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 4));
     VPReductionEVLRecipe EVLRecipe(Recipe, *EVL, CondOp);
@@ -1188,6 +1193,7 @@ TEST_F(VPRecipeTest, MayHaveSideEffectsAndMayReadWriteMemory) {
     EXPECT_FALSE(EVLRecipe.mayReadFromMemory());
     EXPECT_FALSE(EVLRecipe.mayWriteToMemory());
     EXPECT_FALSE(EVLRecipe.mayReadOrWriteMemory());
+    delete Add;
   }
 
   {
@@ -1529,28 +1535,34 @@ TEST_F(VPRecipeTest, dumpRecipeUnnamedVPValuesNotInPlanOrBlock) {
 
 TEST_F(VPRecipeTest, CastVPReductionRecipeToVPUser) {
   IntegerType *Int32 = IntegerType::get(C, 32);
+  auto *Add = BinaryOperator::CreateAdd(PoisonValue::get(Int32),
+                                        PoisonValue::get(Int32));
   VPValue *ChainOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 1));
   VPValue *VecOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 2));
   VPValue *CondOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 3));
-  VPReductionRecipe Recipe(RecurrenceDescriptor(), nullptr, ChainOp, CondOp,
+  VPReductionRecipe Recipe(RecurrenceDescriptor(), Add, ChainOp, CondOp,
                            VecOp, false);
   EXPECT_TRUE(isa<VPUser>(&Recipe));
   VPRecipeBase *BaseR = &Recipe;
   EXPECT_TRUE(isa<VPUser>(BaseR));
+  delete Add;
 }
 
 TEST_F(VPRecipeTest, CastVPReductionEVLRecipeToVPUser) {
   IntegerType *Int32 = IntegerType::get(C, 32);
+  auto *Add = BinaryOperator::CreateAdd(PoisonValue::get(Int32),
+                                        PoisonValue::get(Int32));
   VPValue *ChainOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 1));
   VPValue *VecOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 2));
   VPValue *CondOp = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 3));
-  VPReductionRecipe Recipe(RecurrenceDescriptor(), nullptr, ChainOp, CondOp,
+  VPReductionRecipe Recipe(RecurrenceDescriptor(), Add, ChainOp, CondOp,
                            VecOp, false);
   VPValue *EVL = getPlan().getOrAddLiveIn(ConstantInt::get(Int32, 0));
   VPReductionEVLRecipe EVLRecipe(Recipe, *EVL, CondOp);
   EXPECT_TRUE(isa<VPUser>(&EVLRecipe));
   VPRecipeBase *BaseR = &EVLRecipe;
   EXPECT_TRUE(isa<VPUser>(BaseR));
+  delete Add;
 }
 } // namespace
 

Copy link

github-actions bot commented Mar 12, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@alexey-bataev
Copy link
Member

Looks good to me, if I'm not missing anything. @fhahn ?

@@ -2347,8 +2353,8 @@ void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
O << " = ";
getChainOp()->printAsOperand(O, SlotTracker);
O << " +";
if (isa<FPMathOperator>(getUnderlyingInstr()))
O << getUnderlyingInstr()->getFastMathFlags();
if (isa_and_nonnull<FPMathOperator>(getUnderlyingValue()))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn't be needed, printFlags will handle this automatically?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed, thanks.

Comment on lines 794 to 798
/// Set fast-math flags for this recipe.
void setFastMathFlags(FastMathFlags FMFs) {
OpType = OperationType::FPMathOp;
this->FMFs = FMFs;
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we assert that OpType Is already FPMathOp?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added, thnaks.

Copy link
Contributor

@fhahn fhahn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would probably be good to make this independent of #130880, even if it requires some printing test updates.

This patch change the parent of the VPReductionRecipe from
VPSingleDefRecipe to VPRecipeWithIRFlags and also print/get/control
flags by the VPRecipeWithIRFlags. This will remove the dependency of the
underlying instruction.

This patch also add a new function `setFastMathFlags()` to the
VPRecipeWithIRFlags because the entire reduction chain may contains
multiple instructions. And the underlying instruction may not contains
the corresponding flags for this reduction.
@ElvisWang123 ElvisWang123 force-pushed the change-VPReductionRecipe-inherit-from-VPRecipeWithIRFlags branch from f3235db to 07b5790 Compare March 14, 2025 13:44
lukel97
lukel97 previously approved these changes Mar 14, 2025
Copy link
Contributor

@lukel97 lukel97 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@lukel97 lukel97 dismissed their stale review March 14, 2025 14:35

It might be better to always set the FMFs from RecurrenceDescriptor

@@ -2369,8 +2372,7 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
O << " = ";
getChainOp()->printAsOperand(O, SlotTracker);
O << " +";
if (isa<FPMathOperator>(getUnderlyingInstr()))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not major but if we add back the isa check then we remove the diff for the integer reduction test

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add it back the prevent integer reduction test changes.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How can we end up with integer reductions with FMFs set?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From what I understand an integer RecurrenceDescriptor has all fast math flags set:

// Start with all flags set because we will intersect this with the reduction
// flags from all the reduction operations.
FastMathFlags FMF = FastMathFlags::getFast();

So currently today an integer VPReductionRecipe will use a builder with all fast math flags set:

  IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder);
  State.Builder.setFastMathFlags(RdxDesc.getFastMathFlags());

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, but we shouldn't further propagate incorrect flags to VPReductionRecipe. We should probably check if it is a FP reduction on construction, and if it isn't set empty (fast-math) flags.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense. I think @ElvisWang123 had already done something similar in a previous version of this PR like

     if (isa<FPMathOperator>(I))
       setFastMathFlags(R.getFastMathFlags());

in the constructor. Should that be done as a part of this PR or in a follow up?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be done straight away, I added a suggested edit

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed checks, thanks!

Copy link
Contributor

@lukel97 lukel97 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@@ -2369,8 +2372,7 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
O << " = ";
getChainOp()->printAsOperand(O, SlotTracker);
O << " +";
if (isa<FPMathOperator>(getUnderlyingInstr()))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How can we end up with integer reductions with FMFs set?

@@ -167,7 +167,7 @@ define float @print_reduction(i64 %n, ptr noalias %y) {
; CHECK-NEXT: CLONE ir<%arrayidx> = getelementptr inbounds ir<%y>, vp<[[STEPS]]>
; CHECK-NEXT: vp<[[VEC_PTR:%.+]]> = vector-pointer ir<%arrayidx>
; CHECK-NEXT: WIDEN ir<%lv> = load vp<[[VEC_PTR]]>
; CHECK-NEXT: REDUCE ir<%red.next> = ir<%red> + fast reduce.fadd (ir<%lv>)
; CHECK-NEXT: REDUCE ir<%red.next> = ir<%red> + reassoc nnan ninf nsz arcp contract afn reduce.fadd (ir<%lv>)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you rebase to latest main? This change shouldn't be needed any longer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rebase, fixed.

@@ -2297,12 +2299,13 @@ class VPReductionRecipe : public VPSingleDefRecipe {
VPReductionRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
Instruction *I, ArrayRef<VPValue *> Operands,
VPValue *CondOp, bool IsOrdered, DebugLoc DL)
: VPSingleDefRecipe(SC, Operands, I, DL), RdxDesc(R),
: VPRecipeWithIRFlags(SC, Operands, R.getFastMathFlags(), DL), RdxDesc(R),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
: VPRecipeWithIRFlags(SC, Operands, R.getFastMathFlags(), DL), RdxDesc(R),
: VPRecipeWithIRFlags(SC, Operands,isa<FPMathOperator>(I) ? R.getFastMathFlags() : FastMathFlags(), DL), RdxDesc(R),

to avoid setting FMFs for non FP operators?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated, thanks!

@@ -2369,8 +2372,7 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
O << " = ";
getChainOp()->printAsOperand(O, SlotTracker);
O << " +";
if (isa<FPMathOperator>(getUnderlyingInstr()))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be done straight away, I added a suggested edit

Copy link
Contributor

@fhahn fhahn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

For the title, something like [VPlan] Make VPReductionRecipe a VPRecipeWithIRFlags. (NFC) may be slightly more compact

@ElvisWang123 ElvisWang123 changed the title [VPlan] Change parent of VPReductionRecipe to VPRecipeWithIRFlags. NFC [VPlan] Make VPReductionRecipe a VPRecipeWithIRFlags. NFC Mar 17, 2025
@ElvisWang123 ElvisWang123 merged commit ed19620 into llvm:main Mar 18, 2025
11 checks passed
@ElvisWang123 ElvisWang123 deleted the change-VPReductionRecipe-inherit-from-VPRecipeWithIRFlags branch March 18, 2025 02:08
pawosm-arm pushed a commit to pawosm-arm/llvm-project that referenced this pull request Apr 10, 2025
This patch change the parent of the VPReductionRecipe from
VPSingleDefRecipe to VPRecipeWithIRFlags and also print/get/drop/control
flags by the VPRecipeWithIRFlags. This will remove the dependency of the
underlying instruction.

This patch also add a new function `setFastMathFlags()` to the
VPRecipeWithIRFlags because the entire reduction chain may contains
multiple instructions. And the underlying instruction may not contains
the corresponding flags for this reduction.

Split from llvm#113903.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants