Skip to content

Commit 3d9abfc

Browse files
committed
Consolidate all IR logic for getting the identity value of a reduction [nfc]
This change merges the three different places (at the IR layer) for finding the identity value of a reduction into a single copy. This depends on several prior commits which fix omissions and bugs in the distinct copies, but this patch itself should be a purely non-functional change (NFC). As the new comments and naming try to make clear, the identity value is a property of the @llvm.vector.reduce.* intrinsic, not of e.g. the recurrence descriptor. (We still provide an interface for clients using recurrence descriptors, but the implementation simply translates each recurrence kind to the intrinsic it corresponds to.) As a note, the getIntrinsicIdentity API does not support fminnum/fmaxnum or fminimum/fmaximum, which is why we still need manual logic (but at least only one copy of manual logic) for those cases.
1 parent c1a8283 commit 3d9abfc

File tree

7 files changed

+70
-113
lines changed

7 files changed

+70
-113
lines changed

llvm/include/llvm/Analysis/IVDescriptors.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,6 @@ class RecurrenceDescriptor {
155155
/// Select(FCmp(X, Y), (Z = X op PHINode), PHINode) instruction pattern.
156156
static InstDesc isConditionalRdxPattern(RecurKind Kind, Instruction *I);
157157

158-
/// Returns identity corresponding to the RecurrenceKind.
159-
static Value *getRecurrenceIdentity(RecurKind K, Type *Tp, FastMathFlags FMF);
160-
161158
/// Returns the opcode corresponding to the RecurrenceKind.
162159
static unsigned getOpcode(RecurKind Kind);
163160

llvm/include/llvm/Transforms/Utils/LoopUtils.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,14 @@ RecurKind getMinMaxReductionRecurKind(Intrinsic::ID RdxID);
378378
/// Returns the comparison predicate used when expanding a min/max reduction.
379379
CmpInst::Predicate getMinMaxReductionPredicate(RecurKind RK);
380380

381+
/// Given information about an @llvm.vector.reduce.* intrinsic, return
382+
/// the identity value for the reduction.
383+
Value *getReductionIdentity(Intrinsic::ID RdxID, Type *Ty, FastMathFlags FMF);
384+
385+
/// Given information about a recurrence kind, return the identity
386+
/// for the @llvm.vector.reduce.* used to generate it.
387+
Value *getRecurrenceIdentity(RecurKind K, Type *Tp, FastMathFlags FMF);
388+
381389
/// Returns a Min/Max operation corresponding to MinMaxRecurrenceKind.
382390
/// The Builder's fast-math-flags must be set to propagate the expected values.
383391
Value *createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left,

llvm/lib/Analysis/IVDescriptors.cpp

Lines changed: 0 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,52 +1031,6 @@ bool RecurrenceDescriptor::isFixedOrderRecurrence(PHINode *Phi, Loop *TheLoop,
10311031
return true;
10321032
}
10331033

1034-
/// This function returns the identity element (or neutral element) for
1035-
/// the operation K.
1036-
Value *RecurrenceDescriptor::getRecurrenceIdentity(RecurKind K, Type *Tp,
1037-
FastMathFlags FMF) {
1038-
switch (K) {
1039-
case RecurKind::Xor:
1040-
case RecurKind::Add:
1041-
case RecurKind::Or:
1042-
case RecurKind::Mul:
1043-
case RecurKind::And:
1044-
case RecurKind::FMul:
1045-
case RecurKind::FAdd:
1046-
return ConstantExpr::getBinOpIdentity(getOpcode(K), Tp, false, FMF.noSignedZeros());
1047-
case RecurKind::FMulAdd:
1048-
return ConstantExpr::getBinOpIdentity(Instruction::FAdd, Tp, false, FMF.noSignedZeros());
1049-
case RecurKind::UMin:
1050-
return ConstantInt::get(Tp, -1, true);
1051-
case RecurKind::UMax:
1052-
return ConstantInt::get(Tp, 0);
1053-
case RecurKind::SMin:
1054-
return ConstantInt::get(Tp,
1055-
APInt::getSignedMaxValue(Tp->getIntegerBitWidth()));
1056-
case RecurKind::SMax:
1057-
return ConstantInt::get(Tp,
1058-
APInt::getSignedMinValue(Tp->getIntegerBitWidth()));
1059-
case RecurKind::FMin:
1060-
case RecurKind::FMax:
1061-
assert((FMF.noNaNs() && FMF.noSignedZeros()) &&
1062-
"nnan, nsz is expected to be set for FP min/max reduction.");
1063-
[[fallthrough]];
1064-
case RecurKind::FMinimum:
1065-
case RecurKind::FMaximum: {
1066-
bool Negative = K == RecurKind::FMax || K == RecurKind::FMaximum;
1067-
const fltSemantics &Semantics = Tp->getFltSemantics();
1068-
return !FMF.noInfs()
1069-
? ConstantFP::getInfinity(Tp, Negative)
1070-
: ConstantFP::get(Tp, APFloat::getLargest(Semantics, Negative));
1071-
}
1072-
case RecurKind::IAnyOf:
1073-
case RecurKind::FAnyOf:
1074-
llvm_unreachable("No meaningful identity for recurrence kind");
1075-
default:
1076-
llvm_unreachable("Unknown recurrence kind");
1077-
}
1078-
}
1079-
10801034
unsigned RecurrenceDescriptor::getOpcode(RecurKind Kind) {
10811035
switch (Kind) {
10821036
case RecurKind::Add:

llvm/lib/CodeGen/ExpandVectorPredication.cpp

Lines changed: 5 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -368,52 +368,11 @@ Value *CachingVPExpander::expandPredicationToFPCall(
368368

369369
static Value *getNeutralReductionElement(const VPReductionIntrinsic &VPI,
370370
Type *EltTy) {
371-
bool Negative = false;
372-
unsigned EltBits = EltTy->getScalarSizeInBits();
373-
Intrinsic::ID VID = VPI.getIntrinsicID();
374-
switch (VID) {
375-
default:
376-
llvm_unreachable("Expecting a VP reduction intrinsic");
377-
case Intrinsic::vp_reduce_add:
378-
case Intrinsic::vp_reduce_or:
379-
case Intrinsic::vp_reduce_xor:
380-
case Intrinsic::vp_reduce_umax:
381-
return Constant::getNullValue(EltTy);
382-
case Intrinsic::vp_reduce_mul:
383-
return ConstantInt::get(EltTy, 1, /*IsSigned*/ false);
384-
case Intrinsic::vp_reduce_and:
385-
case Intrinsic::vp_reduce_umin:
386-
return ConstantInt::getAllOnesValue(EltTy);
387-
case Intrinsic::vp_reduce_smin:
388-
return ConstantInt::get(EltTy->getContext(),
389-
APInt::getSignedMaxValue(EltBits));
390-
case Intrinsic::vp_reduce_smax:
391-
return ConstantInt::get(EltTy->getContext(),
392-
APInt::getSignedMinValue(EltBits));
393-
case Intrinsic::vp_reduce_fmax:
394-
case Intrinsic::vp_reduce_fmaximum:
395-
Negative = true;
396-
[[fallthrough]];
397-
case Intrinsic::vp_reduce_fmin:
398-
case Intrinsic::vp_reduce_fminimum: {
399-
bool PropagatesNaN = VID == Intrinsic::vp_reduce_fminimum ||
400-
VID == Intrinsic::vp_reduce_fmaximum;
401-
FastMathFlags Flags = VPI.getFastMathFlags();
402-
const fltSemantics &Semantics = EltTy->getFltSemantics();
403-
return (!Flags.noNaNs() && !PropagatesNaN)
404-
? ConstantFP::getQNaN(EltTy, Negative)
405-
: !Flags.noInfs()
406-
? ConstantFP::getInfinity(EltTy, Negative)
407-
: ConstantFP::get(EltTy,
408-
APFloat::getLargest(Semantics, Negative));
409-
}
410-
case Intrinsic::vp_reduce_fadd:
411-
return ConstantExpr::getBinOpIdentity(
412-
Instruction::FAdd, EltTy, false,
413-
VPI.getFastMathFlags().noSignedZeros());
414-
case Intrinsic::vp_reduce_fmul:
415-
return ConstantFP::get(EltTy, 1.0);
416-
}
371+
Intrinsic::ID RdxID = *VPI.getFunctionalIntrinsicID();
372+
FastMathFlags FMF;
373+
if (isa<FPMathOperator>(VPI))
374+
FMF = VPI.getFastMathFlags();
375+
return getReductionIdentity(RdxID, EltTy, FMF);
417376
}
418377

419378
Value *

llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1207,14 +1207,62 @@ Value *llvm::createAnyOfReduction(IRBuilderBase &Builder, Value *Src,
12071207
return Builder.CreateSelect(AnyOf, NewVal, InitVal, "rdx.select");
12081208
}
12091209

1210+
Value *llvm::getReductionIdentity(Intrinsic::ID RdxID, Type *Ty,
1211+
FastMathFlags Flags) {
1212+
bool Negative = false;
1213+
switch (RdxID) {
1214+
default:
1215+
llvm_unreachable("Expecting a reduction intrinsic");
1216+
case Intrinsic::vector_reduce_add:
1217+
case Intrinsic::vector_reduce_mul:
1218+
case Intrinsic::vector_reduce_or:
1219+
case Intrinsic::vector_reduce_xor:
1220+
case Intrinsic::vector_reduce_and:
1221+
case Intrinsic::vector_reduce_fadd:
1222+
case Intrinsic::vector_reduce_fmul: {
1223+
unsigned Opc = getArithmeticReductionInstruction(RdxID);
1224+
return ConstantExpr::getBinOpIdentity(Opc, Ty, false,
1225+
Flags.noSignedZeros());
1226+
}
1227+
case Intrinsic::vector_reduce_umax:
1228+
case Intrinsic::vector_reduce_umin:
1229+
case Intrinsic::vector_reduce_smin:
1230+
case Intrinsic::vector_reduce_smax: {
1231+
Intrinsic::ID ScalarID = getMinMaxReductionIntrinsicOp(RdxID);
1232+
return ConstantExpr::getIntrinsicIdentity(ScalarID, Ty);
1233+
}
1234+
case Intrinsic::vector_reduce_fmax:
1235+
case Intrinsic::vector_reduce_fmaximum:
1236+
Negative = true;
1237+
[[fallthrough]];
1238+
case Intrinsic::vector_reduce_fmin:
1239+
case Intrinsic::vector_reduce_fminimum: {
1240+
bool PropagatesNaN = RdxID == Intrinsic::vector_reduce_fminimum ||
1241+
RdxID == Intrinsic::vector_reduce_fmaximum;
1242+
const fltSemantics &Semantics = Ty->getFltSemantics();
1243+
return (!Flags.noNaNs() && !PropagatesNaN)
1244+
? ConstantFP::getQNaN(Ty, Negative)
1245+
: !Flags.noInfs()
1246+
? ConstantFP::getInfinity(Ty, Negative)
1247+
: ConstantFP::get(Ty, APFloat::getLargest(Semantics, Negative));
1248+
}
1249+
}
1250+
}
1251+
1252+
Value *llvm::getRecurrenceIdentity(RecurKind K, Type *Tp, FastMathFlags FMF) {
1253+
assert((!(K == RecurKind::FMin || K == RecurKind::FMax) ||
1254+
(FMF.noNaNs() && FMF.noSignedZeros())) &&
1255+
"nnan, nsz is expected to be set for FP min/max reduction.");
1256+
Intrinsic::ID RdxID = getReductionIntrinsicID(K);
1257+
return getReductionIdentity(RdxID, Tp, FMF);
1258+
}
1259+
12101260
Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
12111261
RecurKind RdxKind) {
12121262
auto *SrcVecEltTy = cast<VectorType>(Src->getType())->getElementType();
12131263
auto getIdentity = [&]() {
1214-
Intrinsic::ID ID = getReductionIntrinsicID(RdxKind);
1215-
unsigned Opc = getArithmeticReductionInstruction(ID);
1216-
bool NSZ = Builder.getFastMathFlags().noSignedZeros();
1217-
return ConstantExpr::getBinOpIdentity(Opc, SrcVecEltTy, false, NSZ);
1264+
return getRecurrenceIdentity(RdxKind, SrcVecEltTy,
1265+
Builder.getFastMathFlags());
12181266
};
12191267
switch (RdxKind) {
12201268
case RecurKind::Add:
@@ -1249,8 +1297,7 @@ Value *llvm::createSimpleReduction(VectorBuilder &VBuilder, Value *Src,
12491297
Intrinsic::ID Id = getReductionIntrinsicID(Kind);
12501298
auto *SrcTy = cast<VectorType>(Src->getType());
12511299
Type *SrcEltTy = SrcTy->getElementType();
1252-
Value *Iden =
1253-
Desc.getRecurrenceIdentity(Kind, SrcEltTy, Desc.getFastMathFlags());
1300+
Value *Iden = getRecurrenceIdentity(Kind, SrcEltTy, Desc.getFastMathFlags());
12541301
Value *Ops[] = {Iden, Src};
12551302
return VBuilder.createSimpleReduction(Id, SrcTy, Ops);
12561303
}

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1840,8 +1840,8 @@ void VPReductionRecipe::execute(VPTransformState &State) {
18401840
if (RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind))
18411841
Start = RdxDesc.getRecurrenceStartValue();
18421842
else
1843-
Start = RdxDesc.getRecurrenceIdentity(Kind, ElementTy,
1844-
RdxDesc.getFastMathFlags());
1843+
Start = llvm::getRecurrenceIdentity(Kind, ElementTy,
1844+
RdxDesc.getFastMathFlags());
18451845
if (State.VF.isVector())
18461846
Start = State.Builder.CreateVectorSplat(VecTy->getElementCount(),
18471847
Start);
@@ -3010,8 +3010,8 @@ void VPReductionPHIRecipe::execute(VPTransformState &State) {
30103010
Builder.CreateVectorSplat(State.VF, StartV, "minmax.ident");
30113011
}
30123012
} else {
3013-
Iden = RdxDesc.getRecurrenceIdentity(RK, VecTy->getScalarType(),
3014-
RdxDesc.getFastMathFlags());
3013+
Iden = llvm::getRecurrenceIdentity(RK, VecTy->getScalarType(),
3014+
RdxDesc.getFastMathFlags());
30153015

30163016
if (!ScalarPHI) {
30173017
Iden = Builder.CreateVectorSplat(State.VF, Iden);

llvm/unittests/Analysis/IVDescriptorsTest.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -209,10 +209,6 @@ for.end:
209209
EXPECT_TRUE(IsRdxPhi);
210210
RecurKind Kind = Rdx.getRecurrenceKind();
211211
EXPECT_EQ(Kind, RecurKind::FMin);
212-
Type *Ty = Phi->getType();
213-
Value *Id = Rdx.getRecurrenceIdentity(Kind, Ty, Rdx.getFastMathFlags());
214-
// Identity value for FP min reduction is +Inf.
215-
EXPECT_EQ(Id, ConstantFP::getInfinity(Ty, false /*Negative*/));
216212
});
217213
}
218214

@@ -261,9 +257,5 @@ for.end:
261257
EXPECT_TRUE(IsRdxPhi);
262258
RecurKind Kind = Rdx.getRecurrenceKind();
263259
EXPECT_EQ(Kind, RecurKind::FMax);
264-
Type *Ty = Phi->getType();
265-
Value *Id = Rdx.getRecurrenceIdentity(Kind, Ty, Rdx.getFastMathFlags());
266-
// Identity value for FP max reduction is -Inf.
267-
EXPECT_EQ(Id, ConstantFP::getInfinity(Ty, true /*Negative*/));
268260
});
269261
}

0 commit comments

Comments
 (0)