@@ -731,7 +731,7 @@ bool VectorCombine::foldBitcastShuf(Instruction &I) {
731
731
}
732
732
733
733
// / VP Intrinsics whose vector operands are both splat values may be simplified
734
- // / into the scalar version of the operation and the result is splatted. This
734
+ // / into the scalar version of the operation and the result splatted. This
735
735
// / can lead to scalarization down the line.
736
736
bool VectorCombine::scalarizeVPIntrinsic (Instruction &I) {
737
737
if (!isa<VPIntrinsic>(I))
@@ -758,15 +758,8 @@ bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
758
758
return false ;
759
759
760
760
// Check to make sure we support scalarization of the intrinsic
761
- std::set<Intrinsic::ID> SupportedIntrinsics (
762
- {Intrinsic::vp_add, Intrinsic::vp_sub, Intrinsic::vp_mul,
763
- Intrinsic::vp_ashr, Intrinsic::vp_lshr, Intrinsic::vp_shl,
764
- Intrinsic::vp_or, Intrinsic::vp_and, Intrinsic::vp_xor,
765
- Intrinsic::vp_fadd, Intrinsic::vp_fsub, Intrinsic::vp_fmul,
766
- Intrinsic::vp_sdiv, Intrinsic::vp_udiv, Intrinsic::vp_srem,
767
- Intrinsic::vp_urem});
768
761
Intrinsic::ID IntrID = VPI.getIntrinsicID ();
769
- if (!SupportedIntrinsics. count (IntrID))
762
+ if (!VPBinOpIntrinsic::isVPBinOp (IntrID))
770
763
return false ;
771
764
772
765
// Calculate cost of splatting both operands into vectors and the vector
@@ -785,14 +778,39 @@ bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
785
778
InstructionCost VectorOpCost = TTI.getIntrinsicInstrCost (Attrs, CostKind);
786
779
InstructionCost OldCost = 2 * SplatCost + VectorOpCost;
787
780
781
+ // Determine scalar opcode
782
+ std::optional<unsigned > FunctionalOpcode =
783
+ VPI.getFunctionalOpcode ();
784
+ bool ScalarIsIntr = false ;
785
+ Intrinsic::ID ScalarIntrID;
786
+ if (!FunctionalOpcode) {
787
+ // Explicitly handle supported instructions (i.e. those that return
788
+ // isVPBinOp above, that do not have functional nor constrained opcodes due
789
+ // their intrinsic definitions.
790
+ DenseMap<Intrinsic::ID, Intrinsic::ID> VPIntrToIntr (
791
+ {{Intrinsic::vp_smax, Intrinsic::smax},
792
+ {Intrinsic::vp_smin, Intrinsic::smin},
793
+ {Intrinsic::vp_umax, Intrinsic::umax},
794
+ {Intrinsic::vp_umin, Intrinsic::umin},
795
+ {Intrinsic::vp_copysign, Intrinsic::copysign},
796
+ {Intrinsic::vp_minnum, Intrinsic::minnum},
797
+ {Intrinsic::vp_maxnum, Intrinsic::maxnum}});
798
+ ScalarIsIntr = true ;
799
+ assert (VPIntrToIntr.contains (IntrID) &&
800
+ " Unable to determine scalar opcode" );
801
+ ScalarIntrID = VPIntrToIntr[IntrID];
802
+ }
803
+
788
804
// Calculate cost of scalarizing
789
- std::optional<unsigned > ScalarOpcodeOpt =
790
- VPIntrinsic::getFunctionalOpcodeForVP (IntrID);
791
- assert (ScalarOpcodeOpt && " Unable to determine scalar opcode" );
792
- unsigned ScalarOpcode = *ScalarOpcodeOpt;
805
+ InstructionCost ScalarOpCost = 0 ;
806
+ if (ScalarIsIntr) {
807
+ IntrinsicCostAttributes Attrs (ScalarIntrID, VecTy->getScalarType (), Args);
808
+ ScalarOpCost = TTI.getIntrinsicInstrCost (Attrs, CostKind);
809
+ } else {
810
+ ScalarOpCost =
811
+ TTI.getArithmeticInstrCost (*FunctionalOpcode, VecTy->getScalarType ());
812
+ }
793
813
794
- InstructionCost ScalarOpCost =
795
- TTI.getArithmeticInstrCost (ScalarOpcode, VecTy->getScalarType ());
796
814
// The existing splats may be kept around if other instructions use them.
797
815
InstructionCost CostToKeepSplats =
798
816
(SplatCost * !Op0->hasOneUse ()) + (SplatCost * !Op1->hasOneUse ());
@@ -814,13 +832,19 @@ bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
814
832
bool IsKnownNonZeroVL = isKnownNonZero (EVL, DL, 0 , &AC, &VPI, &DT);
815
833
bool MustHaveNonZeroVL =
816
834
IntrID == Intrinsic::vp_sdiv || IntrID == Intrinsic::vp_udiv ||
817
- IntrID == Intrinsic::vp_srem || IntrID == Intrinsic::vp_urem;
835
+ IntrID == Intrinsic::vp_srem || IntrID == Intrinsic::vp_urem ||
836
+ IntrID == Intrinsic::vp_fdiv || IntrID == Intrinsic::vp_frem;
818
837
819
838
if ((MustHaveNonZeroVL && IsKnownNonZeroVL) || !MustHaveNonZeroVL) {
820
- replaceValue (VPI, *Builder.CreateVectorSplat (
821
- EC, Builder.CreateBinOp (
822
- (Instruction::BinaryOps)ScalarOpcode,
823
- getSplatValue (Op0), getSplatValue (Op1))));
839
+ Value *ScalarOp0 = getSplatValue (Op0);
840
+ Value *ScalarOp1 = getSplatValue (Op1);
841
+ Value *ScalarVal =
842
+ ScalarIsIntr
843
+ ? Builder.CreateIntrinsic (VecTy->getScalarType (), ScalarIntrID,
844
+ {ScalarOp0, ScalarOp1})
845
+ : Builder.CreateBinOp ((Instruction::BinaryOps)(*FunctionalOpcode),
846
+ ScalarOp0, ScalarOp1);
847
+ replaceValue (VPI, *Builder.CreateVectorSplat (EC, ScalarVal));
824
848
return true ;
825
849
}
826
850
return false ;
0 commit comments