Skip to content

Commit ab3a301

Browse files
committed
Utils support
1 parent 5082fea commit ab3a301

File tree

4 files changed

+203
-0
lines changed

4 files changed

+203
-0
lines changed

llvm/include/llvm/IR/IRBuilder.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -746,49 +746,68 @@ class IRBuilderBase {
746746
private:
747747
CallInst *getReductionIntrinsic(Intrinsic::ID ID, Value *Src);
748748

749+
// Helper function for creating VP reduce intrinsic call.
750+
CallInst *getReductionIntrinsic(Intrinsic::ID ID, Value *Acc, Value *Src,
751+
Value *Mask, Value *EVL);
752+
749753
public:
750754
/// Create a sequential vector fadd reduction intrinsic of the source vector.
751755
/// The first parameter is a scalar accumulator value. An unordered reduction
752756
/// can be created by adding the reassoc fast-math flag to the resulting
753757
/// sequential reduction.
754758
CallInst *CreateFAddReduce(Value *Acc, Value *Src);
759+
CallInst *CreateFAddReduce(Value *Acc, Value *Src, Value *EVL,
760+
Value *Mask = nullptr);
755761

756762
/// Create a sequential vector fmul reduction intrinsic of the source vector.
757763
/// The first parameter is a scalar accumulator value. An unordered reduction
758764
/// can be created by adding the reassoc fast-math flag to the resulting
759765
/// sequential reduction.
760766
CallInst *CreateFMulReduce(Value *Acc, Value *Src);
767+
CallInst *CreateFMulReduce(Value *Acc, Value *Src, Value *EVL,
768+
Value *Mask = nullptr);
761769

762770
/// Create a vector int add reduction intrinsic of the source vector.
763771
CallInst *CreateAddReduce(Value *Src);
772+
CallInst *CreateAddReduce(Value *Src, Value *EVL, Value *Mask = nullptr);
764773

765774
/// Create a vector int mul reduction intrinsic of the source vector.
766775
CallInst *CreateMulReduce(Value *Src);
776+
CallInst *CreateMulReduce(Value *Src, Value *EVL, Value *Mask = nullptr);
767777

768778
/// Create a vector int AND reduction intrinsic of the source vector.
769779
CallInst *CreateAndReduce(Value *Src);
780+
CallInst *CreateAndReduce(Value *Src, Value *EVL, Value *Mask = nullptr);
770781

771782
/// Create a vector int OR reduction intrinsic of the source vector.
772783
CallInst *CreateOrReduce(Value *Src);
784+
CallInst *CreateOrReduce(Value *Src, Value *EVL, Value *Mask = nullptr);
773785

774786
/// Create a vector int XOR reduction intrinsic of the source vector.
775787
CallInst *CreateXorReduce(Value *Src);
788+
CallInst *CreateXorReduce(Value *Src, Value *EVL, Value *Mask = nullptr);
776789

777790
/// Create a vector integer max reduction intrinsic of the source
778791
/// vector.
779792
CallInst *CreateIntMaxReduce(Value *Src, bool IsSigned = false);
793+
CallInst *CreateIntMaxReduce(Value *Src, Value *EVL, bool IsSigned = false,
794+
Value *Mask = nullptr);
780795

781796
/// Create a vector integer min reduction intrinsic of the source
782797
/// vector.
783798
CallInst *CreateIntMinReduce(Value *Src, bool IsSigned = false);
799+
CallInst *CreateIntMinReduce(Value *Src, Value *EVL, bool IsSigned = false,
800+
Value *Mask = nullptr);
784801

785802
/// Create a vector float max reduction intrinsic of the source
786803
/// vector.
787804
CallInst *CreateFPMaxReduce(Value *Src);
805+
CallInst *CreateFPMaxReduce(Value *Src, Value *EVL, Value *Mask = nullptr);
788806

789807
/// Create a vector float min reduction intrinsic of the source
790808
/// vector.
791809
CallInst *CreateFPMinReduce(Value *Src);
810+
CallInst *CreateFPMinReduce(Value *Src, Value *EVL, Value *Mask = nullptr);
792811

793812
/// Create a vector float maximum reduction intrinsic of the source
794813
/// vector. This variant follows the NaN and signed zero semantic of

llvm/include/llvm/Transforms/Utils/LoopUtils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,9 @@ Value *getShuffleReduction(IRBuilderBase &Builder, Value *Src, unsigned Op,
394394
/// Fast-math-flags are propagated using the IRBuilder's setting.
395395
Value *createSimpleTargetReduction(IRBuilderBase &B, Value *Src,
396396
RecurKind RdxKind);
397+
Value *createSimpleTargetReduction(IRBuilderBase &B, Value *Src,
398+
RecurKind RdxKind, Value *EVL,
399+
Value *Mask = nullptr);
397400

398401
/// Create a target reduction of the given vector \p Src for a reduction of the
399402
/// kind RecurKind::IAnyOf or RecurKind::FAnyOf. The reduction operation is
@@ -414,6 +417,9 @@ Value *createTargetReduction(IRBuilderBase &B, const RecurrenceDescriptor &Desc,
414417
Value *createOrderedReduction(IRBuilderBase &B,
415418
const RecurrenceDescriptor &Desc, Value *Src,
416419
Value *Start);
420+
Value *createOrderedReduction(IRBuilderBase &B,
421+
const RecurrenceDescriptor &Desc, Value *Src,
422+
Value *Start, Value *EVL, Value *Mask = nullptr);
417423

418424
/// Get the intersection (logical and) of all of the potential IR flags
419425
/// of each scalar operation (VL) that will be converted into a vector (I).

llvm/lib/IR/IRBuilder.cpp

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,20 @@ CallInst *IRBuilderBase::getReductionIntrinsic(Intrinsic::ID ID, Value *Src) {
414414
return CreateCall(Decl, Ops);
415415
}
416416

417+
/// Shared emitter for the vector-predicated (VP) reduction intrinsics.
///
/// Builds a call to the intrinsic \p ID, overloaded on the vector type of
/// \p Src, with the operand list {Acc, Src, Mask, EVL}.
///
/// \param ID   The VP reduction intrinsic to call.
/// \param Acc  Scalar start value passed as the first operand.
/// \param Src  Vector being reduced; must have a vector type.
/// \param Mask Lane mask; when null an all-true splat with the element count
///             of \p Src is materialized instead.
/// \param EVL  Vector-length operand; converted to i32 as an unsigned value.
/// \returns the created call instruction.
CallInst *IRBuilderBase::getReductionIntrinsic(Intrinsic::ID ID, Value *Acc,
                                               Value *Src, Value *Mask,
                                               Value *EVL) {
  Module *ParentModule = GetInsertBlock()->getParent()->getParent();
  auto *VecTy = cast<VectorType>(Src->getType());

  // The intrinsic operand list carries the vector length as an i32.
  Value *EVL32 = CreateIntCast(EVL, getInt32Ty(), /*isSigned=*/false);

  // A missing mask means "all lanes enabled".
  Value *LaneMask = Mask;
  if (LaneMask == nullptr)
    LaneMask = CreateVectorSplat(VecTy->getElementCount(), getTrue());

  Type *OverloadTys[] = {VecTy};
  Function *Callee = Intrinsic::getDeclaration(ParentModule, ID, OverloadTys);

  Value *CallArgs[] = {Acc, Src, LaneMask, EVL32};
  return CreateCall(Callee, CallArgs);
}
430+
417431
CallInst *IRBuilderBase::CreateFAddReduce(Value *Acc, Value *Src) {
418432
Module *M = GetInsertBlock()->getParent()->getParent();
419433
Value *Ops[] = {Acc, Src};
@@ -422,6 +436,11 @@ CallInst *IRBuilderBase::CreateFAddReduce(Value *Acc, Value *Src) {
422436
return CreateCall(Decl, Ops);
423437
}
424438

439+
/// Create a vector-predicated fadd reduction of \p Src starting from the
/// scalar accumulator \p Acc, using the vp_reduce_fadd intrinsic.
///
/// Note the argument order: this public entry point takes (EVL, Mask) while
/// the private helper expects (Mask, EVL). A null \p Mask is replaced with an
/// all-true mask by the helper.
CallInst *IRBuilderBase::CreateFAddReduce(Value *Acc, Value *Src, Value *EVL,
                                          Value *Mask) {
  const Intrinsic::ID RdxID = Intrinsic::vp_reduce_fadd;
  return getReductionIntrinsic(RdxID, Acc, Src, Mask, EVL);
}
443+
425444
CallInst *IRBuilderBase::CreateFMulReduce(Value *Acc, Value *Src) {
426445
Module *M = GetInsertBlock()->getParent()->getParent();
427446
Value *Ops[] = {Acc, Src};
@@ -430,46 +449,149 @@ CallInst *IRBuilderBase::CreateFMulReduce(Value *Acc, Value *Src) {
430449
return CreateCall(Decl, Ops);
431450
}
432451

452+
/// Create a vector-predicated fmul reduction of \p Src starting from the
/// scalar accumulator \p Acc, using the vp_reduce_fmul intrinsic.
///
/// Note the argument order: this public entry point takes (EVL, Mask) while
/// the private helper expects (Mask, EVL). A null \p Mask is replaced with an
/// all-true mask by the helper.
CallInst *IRBuilderBase::CreateFMulReduce(Value *Acc, Value *Src, Value *EVL,
                                          Value *Mask) {
  const Intrinsic::ID RdxID = Intrinsic::vp_reduce_fmul;
  return getReductionIntrinsic(RdxID, Acc, Src, Mask, EVL);
}
456+
433457
/// Create an integer add reduction of the source vector \p Src by calling
/// the vector_reduce_add intrinsic through the single-operand helper.
CallInst *IRBuilderBase::CreateAddReduce(Value *Src) {
  const Intrinsic::ID RdxID = Intrinsic::vector_reduce_add;
  return getReductionIntrinsic(RdxID, Src);
}
436460

461+
/// Create a vector-predicated integer add reduction of \p Src using the
/// vp_reduce_add intrinsic. \p EVL and \p Mask are forwarded to the helper;
/// a null \p Mask is replaced with an all-true mask there.
CallInst *IRBuilderBase::CreateAddReduce(Value *Src, Value *EVL, Value *Mask) {
  Type *ScalarTy = cast<VectorType>(Src->getType())->getElementType();
  // Zero is the identity of addition, so it is used as the start value.
  Constant *Start = ConstantInt::get(ScalarTy, 0);
  return getReductionIntrinsic(Intrinsic::vp_reduce_add, Start, Src, Mask,
                               EVL);
}
467+
437468
/// Create an integer mul reduction of the source vector \p Src by calling
/// the vector_reduce_mul intrinsic through the single-operand helper.
CallInst *IRBuilderBase::CreateMulReduce(Value *Src) {
  const Intrinsic::ID RdxID = Intrinsic::vector_reduce_mul;
  return getReductionIntrinsic(RdxID, Src);
}
440471

472+
/// Create a vector-predicated integer mul reduction of \p Src using the
/// vp_reduce_mul intrinsic. \p EVL and \p Mask are forwarded to the helper;
/// a null \p Mask is replaced with an all-true mask there.
CallInst *IRBuilderBase::CreateMulReduce(Value *Src, Value *EVL, Value *Mask) {
  Type *ScalarTy = cast<VectorType>(Src->getType())->getElementType();
  // One is the identity of multiplication, so it is used as the start value.
  Constant *Start = ConstantInt::get(ScalarTy, 1);
  return getReductionIntrinsic(Intrinsic::vp_reduce_mul, Start, Src, Mask,
                               EVL);
}
478+
441479
/// Create an integer AND reduction of the source vector \p Src by calling
/// the vector_reduce_and intrinsic through the single-operand helper.
CallInst *IRBuilderBase::CreateAndReduce(Value *Src) {
  const Intrinsic::ID RdxID = Intrinsic::vector_reduce_and;
  return getReductionIntrinsic(RdxID, Src);
}
444482

483+
/// Create a vector-predicated integer AND reduction of \p Src using the
/// vp_reduce_and intrinsic. \p EVL and \p Mask are forwarded to the helper;
/// a null \p Mask is replaced with an all-true mask there.
CallInst *IRBuilderBase::CreateAndReduce(Value *Src, Value *EVL, Value *Mask) {
  Type *ScalarTy = cast<VectorType>(Src->getType())->getElementType();
  // All-ones is the identity of bitwise AND, so it is used as the start value.
  Constant *Start = Constant::getAllOnesValue(ScalarTy);
  return getReductionIntrinsic(Intrinsic::vp_reduce_and, Start, Src, Mask,
                               EVL);
}
490+
445491
/// Create an integer OR reduction of the source vector \p Src by calling
/// the vector_reduce_or intrinsic through the single-operand helper.
CallInst *IRBuilderBase::CreateOrReduce(Value *Src) {
  const Intrinsic::ID RdxID = Intrinsic::vector_reduce_or;
  return getReductionIntrinsic(RdxID, Src);
}
448494

495+
/// Create a vector-predicated integer OR reduction of \p Src using the
/// vp_reduce_or intrinsic. \p EVL and \p Mask are forwarded to the helper;
/// a null \p Mask is replaced with an all-true mask there.
CallInst *IRBuilderBase::CreateOrReduce(Value *Src, Value *EVL, Value *Mask) {
  Type *ScalarTy = cast<VectorType>(Src->getType())->getElementType();
  // Zero is the identity of bitwise OR, so it is used as the start value.
  Constant *Start = ConstantInt::get(ScalarTy, 0);
  return getReductionIntrinsic(Intrinsic::vp_reduce_or, Start, Src, Mask, EVL);
}
501+
449502
/// Create an integer XOR reduction of the source vector \p Src by calling
/// the vector_reduce_xor intrinsic through the single-operand helper.
CallInst *IRBuilderBase::CreateXorReduce(Value *Src) {
  const Intrinsic::ID RdxID = Intrinsic::vector_reduce_xor;
  return getReductionIntrinsic(RdxID, Src);
}
452505

506+
/// Create a vector-predicated integer XOR reduction of \p Src using the
/// vp_reduce_xor intrinsic. \p EVL and \p Mask are forwarded to the helper;
/// a null \p Mask is replaced with an all-true mask there.
CallInst *IRBuilderBase::CreateXorReduce(Value *Src, Value *EVL, Value *Mask) {
  Type *ScalarTy = cast<VectorType>(Src->getType())->getElementType();
  // Zero is the identity of bitwise XOR, so it is used as the start value.
  Constant *Start = ConstantInt::get(ScalarTy, 0);
  return getReductionIntrinsic(Intrinsic::vp_reduce_xor, Start, Src, Mask,
                               EVL);
}
512+
453513
/// Create an integer max reduction of the source vector \p Src.
/// \p IsSigned selects vector_reduce_smax; otherwise vector_reduce_umax is
/// used.
CallInst *IRBuilderBase::CreateIntMaxReduce(Value *Src, bool IsSigned) {
  if (IsSigned)
    return getReductionIntrinsic(Intrinsic::vector_reduce_smax, Src);
  return getReductionIntrinsic(Intrinsic::vector_reduce_umax, Src);
}
458518

519+
/// Create a vector-predicated integer max reduction of \p Src.
///
/// \p IsSigned selects vp_reduce_smax with the signed minimum of the element
/// width as start value; otherwise vp_reduce_umax with zero is used. These
/// are the smallest values of their respective orderings, so they never win
/// the max. \p EVL and \p Mask are forwarded to the helper; a null \p Mask is
/// replaced with an all-true mask there.
CallInst *IRBuilderBase::CreateIntMaxReduce(Value *Src, Value *EVL,
                                            bool IsSigned, Value *Mask) {
  Type *ScalarTy = cast<VectorType>(Src->getType())->getElementType();

  if (IsSigned) {
    Constant *Start = ConstantInt::get(
        ScalarTy, APInt::getSignedMinValue(ScalarTy->getIntegerBitWidth()));
    return getReductionIntrinsic(Intrinsic::vp_reduce_smax, Start, Src, Mask,
                                 EVL);
  }

  Constant *Start = ConstantInt::get(ScalarTy, 0);
  return getReductionIntrinsic(Intrinsic::vp_reduce_umax, Start, Src, Mask,
                               EVL);
}
530+
459531
/// Create an integer min reduction of the source vector \p Src.
/// \p IsSigned selects vector_reduce_smin; otherwise vector_reduce_umin is
/// used.
CallInst *IRBuilderBase::CreateIntMinReduce(Value *Src, bool IsSigned) {
  if (IsSigned)
    return getReductionIntrinsic(Intrinsic::vector_reduce_smin, Src);
  return getReductionIntrinsic(Intrinsic::vector_reduce_umin, Src);
}
464536

537+
/// Create a vector-predicated integer min reduction of \p Src.
///
/// \p IsSigned selects vp_reduce_smin with the signed maximum of the element
/// width as start value; otherwise vp_reduce_umin with an all-ones value is
/// used. These are the largest values of their respective orderings, so they
/// never win the min. \p EVL and \p Mask are forwarded to the helper; a null
/// \p Mask is replaced with an all-true mask there.
CallInst *IRBuilderBase::CreateIntMinReduce(Value *Src, Value *EVL,
                                            bool IsSigned, Value *Mask) {
  Type *ScalarTy = cast<VectorType>(Src->getType())->getElementType();

  if (IsSigned) {
    Constant *Start = ConstantInt::get(
        ScalarTy, APInt::getSignedMaxValue(ScalarTy->getIntegerBitWidth()));
    return getReductionIntrinsic(Intrinsic::vp_reduce_smin, Start, Src, Mask,
                                 EVL);
  }

  Constant *Start = Constant::getAllOnesValue(ScalarTy);
  return getReductionIntrinsic(Intrinsic::vp_reduce_umin, Start, Src, Mask,
                               EVL);
}
548+
465549
/// Create a floating-point max reduction of the source vector \p Src by
/// calling the vector_reduce_fmax intrinsic through the single-operand helper.
CallInst *IRBuilderBase::CreateFPMaxReduce(Value *Src) {
  const Intrinsic::ID RdxID = Intrinsic::vector_reduce_fmax;
  return getReductionIntrinsic(RdxID, Src);
}
468552

553+
/// Create a vector-predicated floating-point max reduction of \p Src using
/// the vp_reduce_fmax intrinsic.
///
/// The start value is chosen from the builder's current fast-math flags:
///   - NaNs possible:            a negative quiet NaN;
///   - no NaNs, infs possible:   negative infinity;
///   - no NaNs and no infs:      the largest-magnitude negative finite value.
/// \p EVL and \p Mask are forwarded to the helper; a null \p Mask is replaced
/// with an all-true mask there.
CallInst *IRBuilderBase::CreateFPMaxReduce(Value *Src, Value *EVL,
                                           Value *Mask) {
  Type *ScalarTy = cast<VectorType>(Src->getType())->getElementType();
  const FastMathFlags Flags = getFastMathFlags();

  Value *Start = nullptr;
  if (!Flags.noNaNs())
    Start = ConstantFP::getQNaN(ScalarTy, /*Negative=*/true);
  else if (!Flags.noInfs())
    Start = ConstantFP::getInfinity(ScalarTy, true);
  else
    Start = ConstantFP::get(ScalarTy,
                            APFloat::getLargest(ScalarTy->getFltSemantics(),
                                                /*Negative=*/true));

  return getReductionIntrinsic(Intrinsic::vp_reduce_fmax, Start, Src, Mask,
                               EVL);
}
571+
469572
/// Create a floating-point min reduction of the source vector \p Src by
/// calling the vector_reduce_fmin intrinsic through the single-operand helper.
CallInst *IRBuilderBase::CreateFPMinReduce(Value *Src) {
  const Intrinsic::ID RdxID = Intrinsic::vector_reduce_fmin;
  return getReductionIntrinsic(RdxID, Src);
}
472575

576+
/// Create a vector-predicated floating-point min reduction of \p Src using
/// the vp_reduce_fmin intrinsic.
///
/// The start value is chosen from the builder's current fast-math flags:
///   - NaNs possible:            a positive quiet NaN;
///   - no NaNs, infs possible:   positive infinity;
///   - no NaNs and no infs:      the largest positive finite value.
/// \p EVL and \p Mask are forwarded to the helper; a null \p Mask is replaced
/// with an all-true mask there.
CallInst *IRBuilderBase::CreateFPMinReduce(Value *Src, Value *EVL,
                                           Value *Mask) {
  Type *ScalarTy = cast<VectorType>(Src->getType())->getElementType();
  const FastMathFlags Flags = getFastMathFlags();

  Value *Start = nullptr;
  if (!Flags.noNaNs())
    Start = ConstantFP::getQNaN(ScalarTy, /*Negative=*/false);
  else if (!Flags.noInfs())
    Start = ConstantFP::getInfinity(ScalarTy, false);
  else
    Start = ConstantFP::get(ScalarTy,
                            APFloat::getLargest(ScalarTy->getFltSemantics(),
                                                /*Negative=*/false));

  return getReductionIntrinsic(Intrinsic::vp_reduce_fmin, Start, Src, Mask,
                               EVL);
}
594+
473595
/// Create a floating-point maximum reduction of the source vector \p Src by
/// calling the vector_reduce_fmaximum intrinsic through the single-operand
/// helper.
CallInst *IRBuilderBase::CreateFPMaximumReduce(Value *Src) {
  const Intrinsic::ID RdxID = Intrinsic::vector_reduce_fmaximum;
  return getReductionIntrinsic(RdxID, Src);
}

llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1192,6 +1192,48 @@ Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder, Value *Src,
11921192
}
11931193
}
11941194

1195+
/// Create a vector-predicated target reduction of \p Src for the recurrence
/// kind \p RdxKind, dispatching to the matching EVL-taking IRBuilder
/// reduction helper.
///
/// \param Builder IR builder used to emit the reduction call.
/// \param Src     Vector value to reduce; must have a vector type.
/// \param RdxKind Recurrence kind selecting the reduction operation.
/// \param EVL     Vector-length operand forwarded to the builder helper.
/// \param Mask    Lane mask forwarded to the builder helper; may be null.
/// \returns the value produced by the reduction call.
///
/// FMinimum/FMaximum and any kind not listed below are not supported and hit
/// llvm_unreachable.
Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder, Value *Src,
                                         RecurKind RdxKind, Value *EVL,
                                         Value *Mask) {
  auto *SrcVecEltTy = cast<VectorType>(Src->getType())->getElementType();
  switch (RdxKind) {
  case RecurKind::Add:
    return Builder.CreateAddReduce(Src, EVL, Mask);
  case RecurKind::Mul:
    return Builder.CreateMulReduce(Src, EVL, Mask);
  case RecurKind::And:
    return Builder.CreateAndReduce(Src, EVL, Mask);
  case RecurKind::Or:
    return Builder.CreateOrReduce(Src, EVL, Mask);
  case RecurKind::Xor:
    return Builder.CreateXorReduce(Src, EVL, Mask);
  case RecurKind::FMulAdd:
  case RecurKind::FAdd:
    // -0.0 is used as the start value of the floating-point add reduction.
    return Builder.CreateFAddReduce(ConstantFP::getNegativeZero(SrcVecEltTy),
                                    Src, EVL, Mask);
  case RecurKind::FMul:
    return Builder.CreateFMulReduce(ConstantFP::get(SrcVecEltTy, 1.0), Src, EVL,
                                    Mask);
  case RecurKind::SMax:
    return Builder.CreateIntMaxReduce(Src, EVL, /*IsSigned=*/true, Mask);
  case RecurKind::SMin:
    return Builder.CreateIntMinReduce(Src, EVL, /*IsSigned=*/true, Mask);
  case RecurKind::UMax:
    return Builder.CreateIntMaxReduce(Src, EVL, /*IsSigned=*/false, Mask);
  case RecurKind::UMin:
    return Builder.CreateIntMinReduce(Src, EVL, /*IsSigned=*/false, Mask);
  case RecurKind::FMax:
    return Builder.CreateFPMaxReduce(Src, EVL, Mask);
  case RecurKind::FMin:
    return Builder.CreateFPMinReduce(Src, EVL, Mask);
  case RecurKind::FMinimum:
  case RecurKind::FMaximum:
    // Use llvm_unreachable rather than assert(0): it keeps the specific
    // diagnostic in every build mode and avoids an implicit fallthrough into
    // the default label when assertions are compiled out.
    llvm_unreachable(
        "FMaximum/FMinimum reduction VP intrinsic is not supported.");
  default:
    llvm_unreachable("Unhandled opcode");
  }
}
1236+
11951237
Value *llvm::createTargetReduction(IRBuilderBase &B,
11961238
const RecurrenceDescriptor &Desc, Value *Src,
11971239
PHINode *OrigPhi) {
@@ -1220,6 +1262,20 @@ Value *llvm::createOrderedReduction(IRBuilderBase &B,
12201262
return B.CreateFAddReduce(Start, Src);
12211263
}
12221264

1265+
/// Create an ordered vector-predicated reduction of \p Src, seeded with the
/// scalar \p Start, by emitting the EVL-taking fadd reduction.
///
/// Only FAdd and FMulAdd recurrences are accepted (checked by assertion), as
/// are the vector-ness of \p Src, the scalar-ness of \p Start and the integer
/// type of \p EVL. \p Mask is forwarded unchanged and may be null.
Value *llvm::createOrderedReduction(IRBuilderBase &B,
                                    const RecurrenceDescriptor &Desc,
                                    Value *Src, Value *Start, Value *EVL,
                                    Value *Mask) {
  // Sanity-check the recurrence kind before looking at the operands.
  assert((Desc.getRecurrenceKind() == RecurKind::FAdd ||
          Desc.getRecurrenceKind() == RecurKind::FMulAdd) &&
         "Unexpected reduction kind");

  // Operand shape checks: vector source, scalar start, integer length.
  assert(Src->getType()->isVectorTy() && "Expected a vector type");
  assert(!Start->getType()->isVectorTy() && "Expected a scalar type");
  assert(EVL->getType()->isIntegerTy() && "Expected a integer type");

  Value *Reduced = B.CreateFAddReduce(Start, Src, EVL, Mask);
  return Reduced;
}
1278+
12231279
void llvm::propagateIRFlags(Value *I, ArrayRef<Value *> VL, Value *OpValue,
12241280
bool IncludeWrapFlags) {
12251281
auto *VecOp = dyn_cast<Instruction>(I);

0 commit comments

Comments
 (0)