Skip to content

Commit 16e244e

Browse files
committed
[VectorCombine] Support AND/UREM indices that require freezing.
38b098b limited scalarization to indices that are known non-poison. For certain patterns that restrict the range of an index, we can insert a freeze of the original value, to prevent propagation of poison. Reviewed By: lebedev.ri Differential Revision: https://reviews.llvm.org/D107580 (cherry-picked from c24fc37)
1 parent 43125d9 commit 16e244e

File tree

2 files changed

+100
-22
lines changed

2 files changed

+100
-22
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 92 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -774,21 +774,91 @@ static bool isMemModifiedBetween(BasicBlock::iterator Begin,
774774
});
775775
}
776776

777+
/// Helper class to indicate whether a vector index can be safely scalarized and
778+
/// if a freeze needs to be inserted.
779+
class ScalarizationResult {
780+
enum class StatusTy { Unsafe, Safe, SafeWithFreeze };
781+
782+
StatusTy Status;
783+
Value *ToFreeze;
784+
785+
ScalarizationResult(StatusTy Status, Value *ToFreeze = nullptr)
786+
: Status(Status), ToFreeze(ToFreeze) {}
787+
788+
public:
789+
ScalarizationResult(const ScalarizationResult &Other) = default;
790+
~ScalarizationResult() {
791+
assert(!ToFreeze && "freeze() not called with ToFreeze being set");
792+
}
793+
794+
static ScalarizationResult unsafe() { return {StatusTy::Unsafe}; }
795+
static ScalarizationResult safe() { return {StatusTy::Safe}; }
796+
static ScalarizationResult safeWithFreeze(Value *ToFreeze) {
797+
return {StatusTy::SafeWithFreeze, ToFreeze};
798+
}
799+
800+
/// Returns true if the index can be scalarize without requiring a freeze.
801+
bool isSafe() const { return Status == StatusTy::Safe; }
802+
/// Returns true if the index cannot be scalarized.
803+
bool isUnsafe() const { return Status == StatusTy::Unsafe; }
804+
/// Returns true if the index can be scalarize, but requires inserting a
805+
/// freeze.
806+
bool isSafeWithFreeze() const { return Status == StatusTy::SafeWithFreeze; }
807+
808+
/// Freeze the ToFreeze and update the use in \p User to use it.
809+
void freeze(IRBuilder<> &Builder, Instruction &UserI) {
810+
assert(isSafeWithFreeze() &&
811+
"should only be used when freezing is required");
812+
assert(is_contained(ToFreeze->users(), &UserI) &&
813+
"UserI must be a user of ToFreeze");
814+
IRBuilder<>::InsertPointGuard Guard(Builder);
815+
Builder.SetInsertPoint(cast<Instruction>(&UserI));
816+
Value *Frozen =
817+
Builder.CreateFreeze(ToFreeze, ToFreeze->getName() + ".frozen");
818+
for (Use &U : make_early_inc_range((UserI.operands())))
819+
if (U.get() == ToFreeze)
820+
U.set(Frozen);
821+
822+
ToFreeze = nullptr;
823+
}
824+
};
825+
777826
/// Check if it is legal to scalarize a memory access to \p VecTy at index \p
778827
/// Idx. \p Idx must access a valid vector element.
779-
static bool canScalarizeAccess(FixedVectorType *VecTy, Value *Idx,
780-
Instruction *CtxI, AssumptionCache &AC) {
781-
if (auto *C = dyn_cast<ConstantInt>(Idx))
782-
return C->getValue().ult(VecTy->getNumElements());
783-
784-
if (!isGuaranteedNotToBePoison(Idx, &AC))
785-
return false;
828+
static ScalarizationResult canScalarizeAccess(FixedVectorType *VecTy,
829+
Value *Idx, Instruction *CtxI,
830+
AssumptionCache &AC) {
831+
if (auto *C = dyn_cast<ConstantInt>(Idx)) {
832+
if (C->getValue().ult(VecTy->getNumElements()))
833+
return ScalarizationResult::safe();
834+
return ScalarizationResult::unsafe();
835+
}
786836

787-
APInt Zero(Idx->getType()->getScalarSizeInBits(), 0);
788-
APInt MaxElts(Idx->getType()->getScalarSizeInBits(), VecTy->getNumElements());
837+
unsigned IntWidth = Idx->getType()->getScalarSizeInBits();
838+
APInt Zero(IntWidth, 0);
839+
APInt MaxElts(IntWidth, VecTy->getNumElements());
789840
ConstantRange ValidIndices(Zero, MaxElts);
790-
ConstantRange IdxRange = computeConstantRange(Idx, true, &AC, CtxI, 0);
791-
return ValidIndices.contains(IdxRange);
841+
ConstantRange IdxRange(IntWidth, true);
842+
843+
if (isGuaranteedNotToBePoison(Idx, &AC)) {
844+
if (ValidIndices.contains(computeConstantRange(Idx, true, &AC, CtxI, 0)))
845+
return ScalarizationResult::safe();
846+
return ScalarizationResult::unsafe();
847+
}
848+
849+
// If the index may be poison, check if we can insert a freeze before the
850+
// range of the index is restricted.
851+
Value *IdxBase;
852+
ConstantInt *CI;
853+
if (match(Idx, m_And(m_Value(IdxBase), m_ConstantInt(CI)))) {
854+
IdxRange = IdxRange.binaryAnd(CI->getValue());
855+
} else if (match(Idx, m_URem(m_Value(IdxBase), m_ConstantInt(CI)))) {
856+
IdxRange = IdxRange.urem(CI->getValue());
857+
}
858+
859+
if (ValidIndices.contains(IdxRange))
860+
return ScalarizationResult::safeWithFreeze(IdxBase);
861+
return ScalarizationResult::unsafe();
792862
}
793863

794864
/// The memory operation on a vector of \p ScalarType had alignment of
@@ -836,12 +906,17 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) {
836906
// modified between, vector type matches store size, and index is inbounds.
837907
if (!Load->isSimple() || Load->getParent() != SI->getParent() ||
838908
!DL.typeSizeEqualsStoreSize(Load->getType()) ||
839-
!canScalarizeAccess(VecTy, Idx, Load, AC) ||
840-
SrcAddr != SI->getPointerOperand()->stripPointerCasts() ||
909+
SrcAddr != SI->getPointerOperand()->stripPointerCasts())
910+
return false;
911+
912+
auto ScalarizableIdx = canScalarizeAccess(VecTy, Idx, Load, AC);
913+
if (ScalarizableIdx.isUnsafe() ||
841914
isMemModifiedBetween(Load->getIterator(), SI->getIterator(),
842915
MemoryLocation::get(SI), AA))
843916
return false;
844917

918+
if (ScalarizableIdx.isSafeWithFreeze())
919+
ScalarizableIdx.freeze(Builder, *cast<Instruction>(Idx));
845920
Value *GEP = Builder.CreateInBoundsGEP(
846921
SI->getValueOperand()->getType(), SI->getPointerOperand(),
847922
{ConstantInt::get(Idx->getType(), 0), Idx});
@@ -912,8 +987,11 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
912987
else if (LastCheckedInst->comesBefore(UI))
913988
LastCheckedInst = UI;
914989

915-
if (!canScalarizeAccess(FixedVT, UI->getOperand(1), &I, AC))
990+
auto ScalarIdx = canScalarizeAccess(FixedVT, UI->getOperand(1), &I, AC);
991+
if (!ScalarIdx.isSafe()) {
992+
// TODO: Freeze index if it is safe to do so.
916993
return false;
994+
}
917995

918996
auto *Index = dyn_cast<ConstantInt>(UI->getOperand(1));
919997
OriginalCost +=

llvm/test/Transforms/VectorCombine/load-insert-store.ll

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -310,10 +310,10 @@ entry:
310310
define void @insert_store_nonconst_index_known_valid_by_and_but_may_be_poison(<16 x i8>* %q, i8 zeroext %s, i32 %idx) {
311311
; CHECK-LABEL: @insert_store_nonconst_index_known_valid_by_and_but_may_be_poison(
312312
; CHECK-NEXT: entry:
313-
; CHECK-NEXT: [[TMP0:%.*]] = load <16 x i8>, <16 x i8>* [[Q:%.*]], align 16
314-
; CHECK-NEXT: [[IDX_CLAMPED:%.*]] = and i32 [[IDX:%.*]], 7
315-
; CHECK-NEXT: [[VECINS:%.*]] = insertelement <16 x i8> [[TMP0]], i8 [[S:%.*]], i32 [[IDX_CLAMPED]]
316-
; CHECK-NEXT: store <16 x i8> [[VECINS]], <16 x i8>* [[Q]], align 16
313+
; CHECK-NEXT: [[TMP0:%.*]] = freeze i32 [[IDX:%.*]]
314+
; CHECK-NEXT: [[IDX_CLAMPED:%.*]] = and i32 [[TMP0]], 7
315+
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds <16 x i8>, <16 x i8>* [[Q:%.*]], i32 0, i32 [[IDX_CLAMPED]]
316+
; CHECK-NEXT: store i8 [[S:%.*]], i8* [[TMP1]], align 1
317317
; CHECK-NEXT: ret void
318318
;
319319
entry:
@@ -413,10 +413,10 @@ entry:
413413
define void @insert_store_nonconst_index_known_valid_by_urem_but_may_be_poison(<16 x i8>* %q, i8 zeroext %s, i32 %idx) {
414414
; CHECK-LABEL: @insert_store_nonconst_index_known_valid_by_urem_but_may_be_poison(
415415
; CHECK-NEXT: entry:
416-
; CHECK-NEXT: [[TMP0:%.*]] = load <16 x i8>, <16 x i8>* [[Q:%.*]], align 16
417-
; CHECK-NEXT: [[IDX_CLAMPED:%.*]] = urem i32 [[IDX:%.*]], 16
418-
; CHECK-NEXT: [[VECINS:%.*]] = insertelement <16 x i8> [[TMP0]], i8 [[S:%.*]], i32 [[IDX_CLAMPED]]
419-
; CHECK-NEXT: store <16 x i8> [[VECINS]], <16 x i8>* [[Q]], align 16
416+
; CHECK-NEXT: [[TMP0:%.*]] = freeze i32 [[IDX:%.*]]
417+
; CHECK-NEXT: [[IDX_CLAMPED:%.*]] = urem i32 [[TMP0]], 16
418+
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds <16 x i8>, <16 x i8>* [[Q:%.*]], i32 0, i32 [[IDX_CLAMPED]]
419+
; CHECK-NEXT: store i8 [[S:%.*]], i8* [[TMP1]], align 1
420420
; CHECK-NEXT: ret void
421421
;
422422
entry:

0 commit comments

Comments
 (0)