
Commit 67e726a

[SLP]Transform stores + reverse to strided stores with stride -1, if profitable.

Adds a transformation of consecutive vector store + reverse to strided stores with stride -1, if it is profitable.

Reviewers: RKSimon, preames

Reviewed By: RKSimon

Pull Request: #90464
1 parent 803e03f commit 67e726a
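
A minimal LLVM IR sketch of the equivalence the transformation relies on; the function and value names here are illustrative and not taken from the commit. Storing a reversed vector consecutively writes the same bytes as storing the un-reversed vector with a stride of minus the element size, anchored at the highest-addressed lane:

; Sketch only, not part of the commit: @consecutive_reverse_store and
; @strided_store are hypothetical names. Both functions write the same bytes:
; lane i of %v ends up at %base + 8 * (3 - i).
define void @consecutive_reverse_store(<4 x i64> %v, ptr %base) {
  ; Reverse the vector, then store it consecutively (the store + reverse shape
  ; the cost model compares against).
  %rev = shufflevector <4 x i64> %v, <4 x i64> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
  store <4 x i64> %rev, ptr %base, align 8
  ret void
}

define void @strided_store(<4 x i64> %v, ptr %base) {
  ; Same effect with a single strided store: stride is -8 (minus the i64 size),
  ; so the lanes walk backwards from the highest-addressed element.
  %last = getelementptr inbounds i64, ptr %base, i64 3
  call void @llvm.experimental.vp.strided.store.v4i64.p0.i64(<4 x i64> %v, ptr align 8 %last, i64 -8, <4 x i1> <i1 true, i1 true, i1 true, i1 true>, i32 4)
  ret void
}

declare void @llvm.experimental.vp.strided.store.v4i64.p0.i64(<4 x i64>, ptr, i64, <4 x i1>, i32)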

2 files changed (+71, -34 lines)


llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 66 additions & 8 deletions
@@ -7934,6 +7934,33 @@ void BoUpSLP::transformNodes() {
       }
       break;
     }
+    case Instruction::Store: {
+      Type *ScalarTy =
+          cast<StoreInst>(E.getMainOp())->getValueOperand()->getType();
+      auto *VecTy = FixedVectorType::get(ScalarTy, E.Scalars.size());
+      Align CommonAlignment = computeCommonAlignment<StoreInst>(E.Scalars);
+      // Check if profitable to represent consecutive store + reverse as
+      // strided store with stride -1.
+      if (isReverseOrder(E.ReorderIndices) &&
+          TTI->isLegalStridedLoadStore(VecTy, CommonAlignment)) {
+        SmallVector<int> Mask;
+        inversePermutation(E.ReorderIndices, Mask);
+        auto *BaseSI = cast<StoreInst>(E.Scalars.back());
+        InstructionCost OriginalVecCost =
+            TTI->getMemoryOpCost(Instruction::Store, VecTy, BaseSI->getAlign(),
+                                 BaseSI->getPointerAddressSpace(), CostKind,
+                                 TTI::OperandValueInfo()) +
+            ::getShuffleCost(*TTI, TTI::SK_Reverse, VecTy, Mask, CostKind);
+        InstructionCost StridedCost = TTI->getStridedMemoryOpCost(
+            Instruction::Store, VecTy, BaseSI->getPointerOperand(),
+            /*VariableMask=*/false, CommonAlignment, CostKind, BaseSI);
+        if (StridedCost < OriginalVecCost)
+          // Strided store is more profitable than consecutive store + reverse
+          // - transform the node to strided store.
+          E.State = TreeEntry::StridedVectorize;
+      }
+      break;
+    }
     default:
       break;
     }
@@ -9466,11 +9493,22 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
         cast<StoreInst>(IsReorder ? VL[E->ReorderIndices.front()] : VL0);
     auto GetVectorCost = [=](InstructionCost CommonCost) {
       // We know that we can merge the stores. Calculate the cost.
-      TTI::OperandValueInfo OpInfo = getOperandInfo(E->getOperand(0));
-      return TTI->getMemoryOpCost(Instruction::Store, VecTy, BaseSI->getAlign(),
-                                  BaseSI->getPointerAddressSpace(), CostKind,
-                                  OpInfo) +
-             CommonCost;
+      InstructionCost VecStCost;
+      if (E->State == TreeEntry::StridedVectorize) {
+        Align CommonAlignment =
+            computeCommonAlignment<StoreInst>(UniqueValues.getArrayRef());
+        VecStCost = TTI->getStridedMemoryOpCost(
+            Instruction::Store, VecTy, BaseSI->getPointerOperand(),
+            /*VariableMask=*/false, CommonAlignment, CostKind);
+      } else {
+        assert(E->State == TreeEntry::Vectorize &&
+               "Expected either strided or consecutive stores.");
+        TTI::OperandValueInfo OpInfo = getOperandInfo(E->getOperand(0));
+        VecStCost = TTI->getMemoryOpCost(
+            Instruction::Store, VecTy, BaseSI->getAlign(),
+            BaseSI->getPointerAddressSpace(), CostKind, OpInfo);
+      }
+      return VecStCost + CommonCost;
     };
     SmallVector<Value *> PointerOps(VL.size());
     for (auto [I, V] : enumerate(VL)) {
@@ -12398,7 +12436,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
   bool IsReverseOrder = isReverseOrder(E->ReorderIndices);
   auto FinalShuffle = [&](Value *V, const TreeEntry *E, VectorType *VecTy) {
     ShuffleInstructionBuilder ShuffleBuilder(ScalarTy, Builder, *this);
-    if (E->getOpcode() == Instruction::Store) {
+    if (E->getOpcode() == Instruction::Store &&
+        E->State == TreeEntry::Vectorize) {
       ArrayRef<int> Mask =
           ArrayRef(reinterpret_cast<const int *>(E->ReorderIndices.begin()),
                    E->ReorderIndices.size());
@@ -12986,8 +13025,27 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
       VecValue = FinalShuffle(VecValue, E, VecTy);

       Value *Ptr = SI->getPointerOperand();
-      StoreInst *ST =
-          Builder.CreateAlignedStore(VecValue, Ptr, SI->getAlign());
+      Instruction *ST;
+      if (E->State == TreeEntry::Vectorize) {
+        ST = Builder.CreateAlignedStore(VecValue, Ptr, SI->getAlign());
+      } else {
+        assert(E->State == TreeEntry::StridedVectorize &&
+               "Expected either strided or consecutive stores.");
+        Align CommonAlignment = computeCommonAlignment<StoreInst>(E->Scalars);
+        Type *StrideTy = DL->getIndexType(SI->getPointerOperandType());
+        auto *Inst = Builder.CreateIntrinsic(
+            Intrinsic::experimental_vp_strided_store,
+            {VecTy, Ptr->getType(), StrideTy},
+            {VecValue, Ptr,
+             ConstantInt::get(
+                 StrideTy, -static_cast<int>(DL->getTypeAllocSize(ScalarTy))),
+             Builder.getAllOnesMask(VecTy->getElementCount()),
+             Builder.getInt32(E->Scalars.size())});
+        Inst->addParamAttr(
+            /*ArgNo=*/1,
+            Attribute::getWithAlignment(Inst->getContext(), CommonAlignment));
+        ST = Inst;
+      }

       Value *V = propagateMetadata(ST, E->Scalars);

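For reference, the call built by the new StridedVectorize store path above has the following shape for a group of four i64 stores; %vec and %ptr are placeholder names, and the inline comments map each operand back to the C++ in the hunk:

; Sketch with placeholder names; operand values mirror the CreateIntrinsic call
; above for a group of four i64 stores.
define void @emitted_shape(<4 x i64> %vec, ptr %ptr) {
  call void @llvm.experimental.vp.strided.store.v4i64.p0.i64(
      <4 x i64> %vec,                                ; VecValue
      ptr align 8 %ptr,                              ; Ptr, common alignment attached as a parameter attribute
      i64 -8,                                        ; -DL->getTypeAllocSize(ScalarTy) for i64
      <4 x i1> <i1 true, i1 true, i1 true, i1 true>, ; Builder.getAllOnesMask(VecTy->getElementCount())
      i32 4)                                         ; Builder.getInt32(E->Scalars.size())
  ret void
}

declare void @llvm.experimental.vp.strided.store.v4i64.p0.i64(<4 x i64>, ptr, i64, <4 x i1>, i32)
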
llvm/test/Transforms/SLPVectorizer/RISCV/strided-stores-vectorized.ll

Lines changed: 5 additions & 26 deletions
@@ -4,33 +4,12 @@
 define void @store_reverse(ptr %p3) {
 ; CHECK-LABEL: @store_reverse(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = load i64, ptr [[P3:%.*]], align 8
-; CHECK-NEXT:    [[ARRAYIDX1:%.*]] = getelementptr inbounds i64, ptr [[P3]], i64 8
-; CHECK-NEXT:    [[TMP1:%.*]] = load i64, ptr [[ARRAYIDX1]], align 8
-; CHECK-NEXT:    [[SHL:%.*]] = shl i64 [[TMP0]], [[TMP1]]
-; CHECK-NEXT:    [[ARRAYIDX2:%.*]] = getelementptr inbounds i64, ptr [[P3]], i64 7
-; CHECK-NEXT:    store i64 [[SHL]], ptr [[ARRAYIDX2]], align 8
-; CHECK-NEXT:    [[ARRAYIDX3:%.*]] = getelementptr inbounds i64, ptr [[P3]], i64 1
-; CHECK-NEXT:    [[TMP2:%.*]] = load i64, ptr [[ARRAYIDX3]], align 8
-; CHECK-NEXT:    [[ARRAYIDX4:%.*]] = getelementptr inbounds i64, ptr [[P3]], i64 9
-; CHECK-NEXT:    [[TMP3:%.*]] = load i64, ptr [[ARRAYIDX4]], align 8
-; CHECK-NEXT:    [[SHL5:%.*]] = shl i64 [[TMP2]], [[TMP3]]
-; CHECK-NEXT:    [[ARRAYIDX6:%.*]] = getelementptr inbounds i64, ptr [[P3]], i64 6
-; CHECK-NEXT:    store i64 [[SHL5]], ptr [[ARRAYIDX6]], align 8
-; CHECK-NEXT:    [[ARRAYIDX7:%.*]] = getelementptr inbounds i64, ptr [[P3]], i64 2
-; CHECK-NEXT:    [[TMP4:%.*]] = load i64, ptr [[ARRAYIDX7]], align 8
-; CHECK-NEXT:    [[ARRAYIDX8:%.*]] = getelementptr inbounds i64, ptr [[P3]], i64 10
-; CHECK-NEXT:    [[TMP5:%.*]] = load i64, ptr [[ARRAYIDX8]], align 8
-; CHECK-NEXT:    [[SHL9:%.*]] = shl i64 [[TMP4]], [[TMP5]]
-; CHECK-NEXT:    [[ARRAYIDX10:%.*]] = getelementptr inbounds i64, ptr [[P3]], i64 5
-; CHECK-NEXT:    store i64 [[SHL9]], ptr [[ARRAYIDX10]], align 8
-; CHECK-NEXT:    [[ARRAYIDX11:%.*]] = getelementptr inbounds i64, ptr [[P3]], i64 3
-; CHECK-NEXT:    [[TMP6:%.*]] = load i64, ptr [[ARRAYIDX11]], align 8
-; CHECK-NEXT:    [[ARRAYIDX12:%.*]] = getelementptr inbounds i64, ptr [[P3]], i64 11
-; CHECK-NEXT:    [[TMP7:%.*]] = load i64, ptr [[ARRAYIDX12]], align 8
-; CHECK-NEXT:    [[SHL13:%.*]] = shl i64 [[TMP6]], [[TMP7]]
+; CHECK-NEXT:    [[ARRAYIDX1:%.*]] = getelementptr inbounds i64, ptr [[P3:%.*]], i64 8
 ; CHECK-NEXT:    [[ARRAYIDX14:%.*]] = getelementptr inbounds i64, ptr [[P3]], i64 4
-; CHECK-NEXT:    store i64 [[SHL13]], ptr [[ARRAYIDX14]], align 8
+; CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i64>, ptr [[P3]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x i64>, ptr [[ARRAYIDX1]], align 8
+; CHECK-NEXT:    [[TMP2:%.*]] = shl <4 x i64> [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    call void @llvm.experimental.vp.strided.store.v4i64.p0.i64(<4 x i64> [[TMP2]], ptr align 8 [[ARRAYIDX14]], i64 -8, <4 x i1> <i1 true, i1 true, i1 true, i1 true>, i32 4)
 ; CHECK-NEXT:    ret void
 ;
 entry:
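
The hunk begins at line 4 of the test file, so its RUN line is not visible here; a plausible shape for a RISC-V SLP test of this kind (an assumption, not taken from the diff) is:

; Assumed RUN line; the actual one in the test file may differ.
; RUN: opt -S -mtriple=riscv64-unknown-linux-gnu -mattr=+v -passes=slp-vectorizer < %s | FileCheck %s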
