Skip to content

Commit a87a803

Browse files
committed
Init implement TU select
1 parent 770e55a commit a87a803

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1262,6 +1262,9 @@ class VPInstruction : public VPRecipeWithIRFlags {
12621262
// operand). Only generates scalar values (either for the first lane only or
12631263
// for all lanes, depending on its uses).
12641264
PtrAdd,
1265+
// Selects elements from two vectors (second and third operand) based on a
1266+
// condition vector (first operand) and a pivot index (fourth operand).
1267+
MergeUntilPivot,
12651268
};
12661269

12671270
private:

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ bool VPRecipeBase::mayHaveSideEffects() const {
144144
case VPInstruction::FirstOrderRecurrenceSplice:
145145
case VPInstruction::LogicalAnd:
146146
case VPInstruction::PtrAdd:
147+
case VPInstruction::MergeUntilPivot:
147148
return false;
148149
default:
149150
return true;
@@ -673,7 +674,18 @@ Value *VPInstruction::generatePerPart(VPTransformState &State, unsigned Part) {
673674
}
674675
return NewPhi;
675676
}
677+
case VPInstruction::MergeUntilPivot: {
678+
assert(Part == 0 && "No unrolling expected for predicated vectorization.");
679+
Value *Cond = State.get(getOperand(0), Part);
680+
Value *OnTrue = State.get(getOperand(1), Part);
681+
Value *OnFalse = State.get(getOperand(2), Part);
682+
Value *Pivot = State.get(getOperand(3), VPIteration(0, 0));
683+
assert(Pivot->getType()->isIntegerTy() && "Pivot should be an integer.");
676684

685+
return Builder.CreateIntrinsic(Intrinsic::vp_merge, {OnTrue->getType()},
686+
{Cond, OnTrue, OnFalse, Pivot}, nullptr,
687+
Name);
688+
}
677689
default:
678690
llvm_unreachable("Unsupported opcode for instruction");
679691
}
@@ -764,6 +776,9 @@ bool VPInstruction::onlyFirstLaneUsed(const VPValue *Op) const {
764776
case VPInstruction::BranchOnCond:
765777
case VPInstruction::ResumePhi:
766778
return true;
779+
case VPInstruction::MergeUntilPivot:
780+
// Pivot must be an integer.
781+
return Op == getOperand(3);
767782
};
768783
llvm_unreachable("switch should return");
769784
}
@@ -782,6 +797,7 @@ bool VPInstruction::onlyFirstPartUsed(const VPValue *Op) const {
782797
case VPInstruction::BranchOnCount:
783798
case VPInstruction::BranchOnCond:
784799
case VPInstruction::CanonicalIVIncrementForPart:
800+
case VPInstruction::MergeUntilPivot:
785801
return true;
786802
};
787803
llvm_unreachable("switch should return");
@@ -848,6 +864,9 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
848864
case VPInstruction::PtrAdd:
849865
O << "ptradd";
850866
break;
867+
case VPInstruction::MergeUntilPivot:
868+
O << "merge-until-pivot";
869+
break;
851870
default:
852871
O << Instruction::getOpcodeName(getOpcode());
853872
}

0 commit comments

Comments
 (0)