Skip to content

Commit 068357d

Browse files
authored
[VectorCombine] Enable transform 'scalarizeLoadExtract' for scalable vector types (#65443)
The transform 'scalarizeLoadExtract' can be applied to scalable vector types if the index is less than the minimum number of elements. The check whether the index is less than the minimum number of elements locates at line 1175~1180. 'scalarizeLoadExtract' will call 'canScalarizeAccess' and check the returned result if this transform is safe. At the beginning of the function 'canScalarizeAccess', the index will be checked 1. If it is less than the number of elements of a fixed vector type. 2. If it is less than the minimum number of elements of a scalable vector type. Otherwise 'canScalarizeAccess' will return unsafe and this transform will be prevented.
1 parent b4d4146 commit 068357d

File tree

2 files changed

+18
-19
lines changed

2 files changed

+18
-19
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1134,14 +1134,14 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
11341134
if (!match(&I, m_Load(m_Value(Ptr))))
11351135
return false;
11361136

1137-
auto *FixedVT = cast<FixedVectorType>(I.getType());
1137+
auto *VecTy = cast<VectorType>(I.getType());
11381138
auto *LI = cast<LoadInst>(&I);
11391139
const DataLayout &DL = I.getModule()->getDataLayout();
1140-
if (LI->isVolatile() || !DL.typeSizeEqualsStoreSize(FixedVT))
1140+
if (LI->isVolatile() || !DL.typeSizeEqualsStoreSize(VecTy))
11411141
return false;
11421142

11431143
InstructionCost OriginalCost =
1144-
TTI.getMemoryOpCost(Instruction::Load, FixedVT, LI->getAlign(),
1144+
TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
11451145
LI->getPointerAddressSpace());
11461146
InstructionCost ScalarizedCost = 0;
11471147

@@ -1172,7 +1172,7 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
11721172
LastCheckedInst = UI;
11731173
}
11741174

1175-
auto ScalarIdx = canScalarizeAccess(FixedVT, UI->getOperand(1), &I, AC, DT);
1175+
auto ScalarIdx = canScalarizeAccess(VecTy, UI->getOperand(1), &I, AC, DT);
11761176
if (!ScalarIdx.isSafe()) {
11771177
// TODO: Freeze index if it is safe to do so.
11781178
ScalarIdx.discard();
@@ -1182,12 +1182,12 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
11821182
auto *Index = dyn_cast<ConstantInt>(UI->getOperand(1));
11831183
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
11841184
OriginalCost +=
1185-
TTI.getVectorInstrCost(Instruction::ExtractElement, FixedVT, CostKind,
1185+
TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, CostKind,
11861186
Index ? Index->getZExtValue() : -1);
11871187
ScalarizedCost +=
1188-
TTI.getMemoryOpCost(Instruction::Load, FixedVT->getElementType(),
1188+
TTI.getMemoryOpCost(Instruction::Load, VecTy->getElementType(),
11891189
Align(1), LI->getPointerAddressSpace());
1190-
ScalarizedCost += TTI.getAddressComputationCost(FixedVT->getElementType());
1190+
ScalarizedCost += TTI.getAddressComputationCost(VecTy->getElementType());
11911191
}
11921192

11931193
if (ScalarizedCost >= OriginalCost)
@@ -1200,12 +1200,12 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
12001200

12011201
Value *Idx = EI->getOperand(1);
12021202
Value *GEP =
1203-
Builder.CreateInBoundsGEP(FixedVT, Ptr, {Builder.getInt32(0), Idx});
1203+
Builder.CreateInBoundsGEP(VecTy, Ptr, {Builder.getInt32(0), Idx});
12041204
auto *NewLoad = cast<LoadInst>(Builder.CreateLoad(
1205-
FixedVT->getElementType(), GEP, EI->getName() + ".scalar"));
1205+
VecTy->getElementType(), GEP, EI->getName() + ".scalar"));
12061206

12071207
Align ScalarOpAlignment = computeAlignmentAfterScalarization(
1208-
LI->getAlign(), FixedVT->getElementType(), Idx, DL);
1208+
LI->getAlign(), VecTy->getElementType(), Idx, DL);
12091209
NewLoad->setAlignment(ScalarOpAlignment);
12101210

12111211
replaceValue(*EI, *NewLoad);
@@ -1727,9 +1727,6 @@ bool VectorCombine::run() {
17271727
case Instruction::ShuffleVector:
17281728
MadeChange |= widenSubvectorLoad(I);
17291729
break;
1730-
case Instruction::Load:
1731-
MadeChange |= scalarizeLoadExtract(I);
1732-
break;
17331730
default:
17341731
break;
17351732
}
@@ -1743,6 +1740,8 @@ bool VectorCombine::run() {
17431740
if (Opcode == Instruction::Store)
17441741
MadeChange |= foldSingleElementStore(I);
17451742

1743+
if (isa<VectorType>(I.getType()) && Opcode == Instruction::Load)
1744+
MadeChange |= scalarizeLoadExtract(I);
17461745

17471746
// If this is an early pipeline invocation of this pass, we are done.
17481747
if (TryEarlyFoldsOnly)

llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ define i32 @load_extract_idx_0(ptr %x) {
1515

1616
define i32 @vscale_load_extract_idx_0(ptr %x) {
1717
; CHECK-LABEL: @vscale_load_extract_idx_0(
18-
; CHECK-NEXT: [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
19-
; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i32 0
18+
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i32 0
19+
; CHECK-NEXT: [[R:%.*]] = load i32, ptr [[TMP1]], align 16
2020
; CHECK-NEXT: ret i32 [[R]]
2121
;
2222
%lv = load <vscale x 4 x i32>, ptr %x
@@ -61,8 +61,8 @@ define i32 @load_extract_idx_2(ptr %x) {
6161

6262
define i32 @vscale_load_extract_idx_2(ptr %x) {
6363
; CHECK-LABEL: @vscale_load_extract_idx_2(
64-
; CHECK-NEXT: [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
65-
; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i32 2
64+
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i32 2
65+
; CHECK-NEXT: [[R:%.*]] = load i32, ptr [[TMP1]], align 8
6666
; CHECK-NEXT: ret i32 [[R]]
6767
;
6868
%lv = load <vscale x 4 x i32>, ptr %x
@@ -142,9 +142,9 @@ define i32 @vscale_load_extract_idx_var_i64_known_valid_by_assume(ptr %x, i64 %i
142142
; CHECK-NEXT: entry:
143143
; CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[IDX:%.*]], 4
144144
; CHECK-NEXT: call void @llvm.assume(i1 [[CMP]])
145-
; CHECK-NEXT: [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
146145
; CHECK-NEXT: call void @maythrow()
147-
; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX]]
146+
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX]]
147+
; CHECK-NEXT: [[R:%.*]] = load i32, ptr [[TMP0]], align 4
148148
; CHECK-NEXT: ret i32 [[R]]
149149
;
150150
entry:

0 commit comments

Comments
 (0)