@@ -774,21 +774,91 @@ static bool isMemModifiedBetween(BasicBlock::iterator Begin,
774
774
});
775
775
}
776
776
777
+ // / Helper class to indicate whether a vector index can be safely scalarized and
778
+ // / if a freeze needs to be inserted.
779
+ class ScalarizationResult {
780
+ enum class StatusTy { Unsafe, Safe, SafeWithFreeze };
781
+
782
+ StatusTy Status;
783
+ Value *ToFreeze;
784
+
785
+ ScalarizationResult (StatusTy Status, Value *ToFreeze = nullptr )
786
+ : Status(Status), ToFreeze(ToFreeze) {}
787
+
788
+ public:
789
+ ScalarizationResult (const ScalarizationResult &Other) = default ;
790
+ ~ScalarizationResult () {
791
+ assert (!ToFreeze && " freeze() not called with ToFreeze being set" );
792
+ }
793
+
794
+ static ScalarizationResult unsafe () { return {StatusTy::Unsafe}; }
795
+ static ScalarizationResult safe () { return {StatusTy::Safe}; }
796
+ static ScalarizationResult safeWithFreeze (Value *ToFreeze) {
797
+ return {StatusTy::SafeWithFreeze, ToFreeze};
798
+ }
799
+
800
+ // / Returns true if the index can be scalarize without requiring a freeze.
801
+ bool isSafe () const { return Status == StatusTy::Safe; }
802
+ // / Returns true if the index cannot be scalarized.
803
+ bool isUnsafe () const { return Status == StatusTy::Unsafe; }
804
+ // / Returns true if the index can be scalarize, but requires inserting a
805
+ // / freeze.
806
+ bool isSafeWithFreeze () const { return Status == StatusTy::SafeWithFreeze; }
807
+
808
+ // / Freeze the ToFreeze and update the use in \p User to use it.
809
+ void freeze (IRBuilder<> &Builder, Instruction &UserI) {
810
+ assert (isSafeWithFreeze () &&
811
+ " should only be used when freezing is required" );
812
+ assert (is_contained (ToFreeze->users (), &UserI) &&
813
+ " UserI must be a user of ToFreeze" );
814
+ IRBuilder<>::InsertPointGuard Guard (Builder);
815
+ Builder.SetInsertPoint (cast<Instruction>(&UserI));
816
+ Value *Frozen =
817
+ Builder.CreateFreeze (ToFreeze, ToFreeze->getName () + " .frozen" );
818
+ for (Use &U : make_early_inc_range ((UserI.operands ())))
819
+ if (U.get () == ToFreeze)
820
+ U.set (Frozen);
821
+
822
+ ToFreeze = nullptr ;
823
+ }
824
+ };
825
+
777
826
// / Check if it is legal to scalarize a memory access to \p VecTy at index \p
778
827
// / Idx. \p Idx must access a valid vector element.
779
- static bool canScalarizeAccess (FixedVectorType *VecTy, Value *Idx,
780
- Instruction *CtxI, AssumptionCache &AC) {
781
- if (auto *C = dyn_cast<ConstantInt>(Idx))
782
- return C->getValue ().ult (VecTy->getNumElements ());
783
-
784
- if (!isGuaranteedNotToBePoison (Idx, &AC))
785
- return false ;
828
+ static ScalarizationResult canScalarizeAccess (FixedVectorType *VecTy,
829
+ Value *Idx, Instruction *CtxI,
830
+ AssumptionCache &AC) {
831
+ if (auto *C = dyn_cast<ConstantInt>(Idx)) {
832
+ if (C->getValue ().ult (VecTy->getNumElements ()))
833
+ return ScalarizationResult::safe ();
834
+ return ScalarizationResult::unsafe ();
835
+ }
786
836
787
- APInt Zero (Idx->getType ()->getScalarSizeInBits (), 0 );
788
- APInt MaxElts (Idx->getType ()->getScalarSizeInBits (), VecTy->getNumElements ());
837
+ unsigned IntWidth = Idx->getType ()->getScalarSizeInBits ();
838
+ APInt Zero (IntWidth, 0 );
839
+ APInt MaxElts (IntWidth, VecTy->getNumElements ());
789
840
ConstantRange ValidIndices (Zero, MaxElts);
790
- ConstantRange IdxRange = computeConstantRange (Idx, true , &AC, CtxI, 0 );
791
- return ValidIndices.contains (IdxRange);
841
+ ConstantRange IdxRange (IntWidth, true );
842
+
843
+ if (isGuaranteedNotToBePoison (Idx, &AC)) {
844
+ if (ValidIndices.contains (computeConstantRange (Idx, true , &AC, CtxI, 0 )))
845
+ return ScalarizationResult::safe ();
846
+ return ScalarizationResult::unsafe ();
847
+ }
848
+
849
+ // If the index may be poison, check if we can insert a freeze before the
850
+ // range of the index is restricted.
851
+ Value *IdxBase;
852
+ ConstantInt *CI;
853
+ if (match (Idx, m_And (m_Value (IdxBase), m_ConstantInt (CI)))) {
854
+ IdxRange = IdxRange.binaryAnd (CI->getValue ());
855
+ } else if (match (Idx, m_URem (m_Value (IdxBase), m_ConstantInt (CI)))) {
856
+ IdxRange = IdxRange.urem (CI->getValue ());
857
+ }
858
+
859
+ if (ValidIndices.contains (IdxRange))
860
+ return ScalarizationResult::safeWithFreeze (IdxBase);
861
+ return ScalarizationResult::unsafe ();
792
862
}
793
863
794
864
// / The memory operation on a vector of \p ScalarType had alignment of
@@ -836,12 +906,17 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) {
836
906
// modified between, vector type matches store size, and index is inbounds.
837
907
if (!Load->isSimple () || Load->getParent () != SI->getParent () ||
838
908
!DL.typeSizeEqualsStoreSize (Load->getType ()) ||
839
- !canScalarizeAccess (VecTy, Idx, Load, AC) ||
840
- SrcAddr != SI->getPointerOperand ()->stripPointerCasts () ||
909
+ SrcAddr != SI->getPointerOperand ()->stripPointerCasts ())
910
+ return false ;
911
+
912
+ auto ScalarizableIdx = canScalarizeAccess (VecTy, Idx, Load, AC);
913
+ if (ScalarizableIdx.isUnsafe () ||
841
914
isMemModifiedBetween (Load->getIterator (), SI->getIterator (),
842
915
MemoryLocation::get (SI), AA))
843
916
return false ;
844
917
918
+ if (ScalarizableIdx.isSafeWithFreeze ())
919
+ ScalarizableIdx.freeze (Builder, *cast<Instruction>(Idx));
845
920
Value *GEP = Builder.CreateInBoundsGEP (
846
921
SI->getValueOperand ()->getType (), SI->getPointerOperand (),
847
922
{ConstantInt::get (Idx->getType (), 0 ), Idx});
@@ -912,8 +987,11 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
912
987
else if (LastCheckedInst->comesBefore (UI))
913
988
LastCheckedInst = UI;
914
989
915
- if (!canScalarizeAccess (FixedVT, UI->getOperand (1 ), &I, AC))
990
+ auto ScalarIdx = canScalarizeAccess (FixedVT, UI->getOperand (1 ), &I, AC);
991
+ if (!ScalarIdx.isSafe ()) {
992
+ // TODO: Freeze index if it is safe to do so.
916
993
return false ;
994
+ }
917
995
918
996
auto *Index = dyn_cast<ConstantInt>(UI->getOperand (1 ));
919
997
OriginalCost +=
0 commit comments