Skip to content

[VectorCombine] Enable transform 'scalarizeLoadExtract' for scalable vector types #65443

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 18, 2023
Merged

[VectorCombine] Enable transform 'scalarizeLoadExtract' for scalable vector types #65443

merged 1 commit into from
Sep 18, 2023

Conversation

benshi001
Copy link
Member

@benshi001 benshi001 commented Sep 6, 2023

The transform 'scalarizeLoadExtract' can be applied to scalable
vector types if the index is less than the minimum number of elements.

The transform 'scalarizeLoadExtract' can be applied to scalable
vector types if the index is less than the minimum number of elements.

The check whether the index is less than the minimum number of elements
locates at line 1175~1180. 'scalarizeLoadExtract' will call 'canScalarizeAccess'
and check the returned result if this transform is safe.

At the beginning of the function 'canScalarizeAccess', the index will be checked

  1. if it is less than the number of elements of a fixed vector type.
  2. if it is less than the minimum number of elements of a scalable vector type.

Otherwise 'canScalarizeAccess' will return unsafe and this transform will be
prevented.

This PR is stacked #65442

@fhahn
Copy link
Contributor

fhahn commented Sep 6, 2023

The transform 'scalarizeLoadExtract' can be applied to scalable vector types if the index is less than the minimum number of elements.

Would be good to clarify this in the commit message; we need to prove that the index is < # of vector elements independent of whether the vector is scalable or not. IIUC the restriction on minimum number of elements is only because it makes things work well with the existing reasoning, but this restriction could be weakened in the future, i.e. there may be scenarios where we could prove an index < vscale * min number of elements.

@benshi001
Copy link
Member Author

benshi001 commented Sep 7, 2023

The transform 'scalarizeLoadExtract' can be applied to scalable vector types if the index is less than the minimum number of elements.

Would be good to clarify this in the commit message; we need to prove that the index is < # of vector elements independent of whether the vector is scalable or not. IIUC the restriction on minimum number of elements is only because it makes things work well with the existing reasoning, but this restriction could be weakened in the future, i.e. there may be scenarios where we could prove an index < vscale * min number of elements.

I have updated my commit message db60c01 for the clarification as you suggested.

@benshi001
Copy link
Member Author

ping ...

…vector types

The transform 'scalarizeLoadExtract' can be applied to scalable
vector types if the index is less than the minimum number of elements.

The check whether the index is less than the minimum number of elements
locates at line 1175~1180. 'scalarizeLoadExtract' will call 'canScalarizeAccess'
and check the returned result if this transform is safe.

At the beginning of the function 'canScalarizeAccess', the index will be checked
1. if it is less than the number of elements of a fixed vector type.
2. if it is less than the minimum number of elements of a scalable vector type.

Otherwise 'canScalarizeAccess' will return unsafe and this transform will be
prevented.
@llvmbot
Copy link
Member

llvmbot commented Sep 13, 2023

@llvm/pr-subscribers-vectorizers

@llvm/pr-subscribers-llvm-transforms

Changes The transform 'scalarizeLoadExtract' can be applied to scalable vector types if the index is less than the minimum number of elements.

The transform 'scalarizeLoadExtract' can be applied to scalable
vector types if the index is less than the minimum number of elements.

The check whether the index is less than the minimum number of elements
locates at line 1175~1180. 'scalarizeLoadExtract' will call 'canScalarizeAccess'
and check the returned result if this transform is safe.

At the beginning of the function 'canScalarizeAccess', the index will be checked

  1. if it is less than the number of elements of a fixed vector type.
  2. if it is less than the minimum number of elements of a scalable vector type.

Otherwise 'canScalarizeAccess' will return unsafe and this transform will be
prevented.

This PR is stacked #65442

Full diff: https://github.com/llvm/llvm-project/pull/65443.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/VectorCombine.cpp (+12-13)
  • (modified) llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll (+6-6)
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 66e3bcaac0adb2e..4f95eaba8de7bd2 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1134,14 +1134,14 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   if (!match(&I, m_Load(m_Value(Ptr))))
     return false;
 
-  auto *FixedVT = cast<FixedVectorType>(I.getType());
+  auto *VecTy = cast<VectorType>(I.getType());
   auto *LI = cast<LoadInst>(&I);
   const DataLayout &DL = I.getModule()->getDataLayout();
-  if (LI->isVolatile() || !DL.typeSizeEqualsStoreSize(FixedVT))
+  if (LI->isVolatile() || !DL.typeSizeEqualsStoreSize(VecTy))
     return false;
 
   InstructionCost OriginalCost =
-      TTI.getMemoryOpCost(Instruction::Load, FixedVT, LI->getAlign(),
+      TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
                           LI->getPointerAddressSpace());
   InstructionCost ScalarizedCost = 0;
 
@@ -1172,7 +1172,7 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
       LastCheckedInst = UI;
     }
 
-    auto ScalarIdx = canScalarizeAccess(FixedVT, UI->getOperand(1), &I, AC, DT);
+    auto ScalarIdx = canScalarizeAccess(VecTy, UI->getOperand(1), &I, AC, DT);
     if (!ScalarIdx.isSafe()) {
       // TODO: Freeze index if it is safe to do so.
       ScalarIdx.discard();
@@ -1182,12 +1182,12 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
     auto *Index = dyn_cast<ConstantInt>(UI->getOperand(1));
     TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
     OriginalCost +=
-        TTI.getVectorInstrCost(Instruction::ExtractElement, FixedVT, CostKind,
+        TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, CostKind,
                                Index ? Index->getZExtValue() : -1);
     ScalarizedCost +=
-        TTI.getMemoryOpCost(Instruction::Load, FixedVT->getElementType(),
+        TTI.getMemoryOpCost(Instruction::Load, VecTy->getElementType(),
                             Align(1), LI->getPointerAddressSpace());
-    ScalarizedCost += TTI.getAddressComputationCost(FixedVT->getElementType());
+    ScalarizedCost += TTI.getAddressComputationCost(VecTy->getElementType());
   }
 
   if (ScalarizedCost >= OriginalCost)
@@ -1200,12 +1200,12 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
 
     Value *Idx = EI->getOperand(1);
     Value *GEP =
-        Builder.CreateInBoundsGEP(FixedVT, Ptr, {Builder.getInt32(0), Idx});
+        Builder.CreateInBoundsGEP(VecTy, Ptr, {Builder.getInt32(0), Idx});
     auto *NewLoad = cast<LoadInst>(Builder.CreateLoad(
-        FixedVT->getElementType(), GEP, EI->getName() + ".scalar"));
+        VecTy->getElementType(), GEP, EI->getName() + ".scalar"));
 
     Align ScalarOpAlignment = computeAlignmentAfterScalarization(
-        LI->getAlign(), FixedVT->getElementType(), Idx, DL);
+        LI->getAlign(), VecTy->getElementType(), Idx, DL);
     NewLoad->setAlignment(ScalarOpAlignment);
 
     replaceValue(*EI, *NewLoad);
@@ -1727,9 +1727,6 @@ bool VectorCombine::run() {
       case Instruction::ShuffleVector:
         MadeChange |= widenSubvectorLoad(I);
         break;
-      case Instruction::Load:
-        MadeChange |= scalarizeLoadExtract(I);
-        break;
       default:
         break;
       }
@@ -1743,6 +1740,8 @@ bool VectorCombine::run() {
     if (Opcode == Instruction::Store)
       MadeChange |= foldSingleElementStore(I);
 
+    if (isa<VectorType>(I.getType()) && Opcode == Instruction::Load)
+      MadeChange |= scalarizeLoadExtract(I);
 
     // If this is an early pipeline invocation of this pass, we are done.
     if (TryEarlyFoldsOnly)
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
index 7df4f49e095c96c..c7e5979aa9e7bd9 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
@@ -15,8 +15,8 @@ define i32 @load_extract_idx_0(ptr %x) {
 
 define i32 @vscale_load_extract_idx_0(ptr %x) {
 ; CHECK-LABEL: @vscale_load_extract_idx_0(
-; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i32 0
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i32 0
+; CHECK-NEXT:    [[R:%.*]] = load i32, ptr [[TMP1]], align 16
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %lv = load <vscale x 4 x i32>, ptr %x
@@ -61,8 +61,8 @@ define i32 @load_extract_idx_2(ptr %x) {
 
 define i32 @vscale_load_extract_idx_2(ptr %x) {
 ; CHECK-LABEL: @vscale_load_extract_idx_2(
-; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i32 2
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i32 2
+; CHECK-NEXT:    [[R:%.*]] = load i32, ptr [[TMP1]], align 8
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %lv = load <vscale x 4 x i32>, ptr %x
@@ -142,9 +142,9 @@ define i32 @vscale_load_extract_idx_var_i64_known_valid_by_assume(ptr %x, i64 %i
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[IDX:%.*]], 4
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[CMP]])
-; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
 ; CHECK-NEXT:    call void @maythrow()
-; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX]]
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX]]
+; CHECK-NEXT:    [[R:%.*]] = load i32, ptr [[TMP0]], align 4
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
 entry:

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@benshi001 benshi001 merged commit 068357d into llvm:main Sep 18, 2023
@benshi001 benshi001 deleted the veccom-load-extract-1 branch September 18, 2023 02:49
@kda
Copy link
Contributor

kda commented Sep 18, 2023

@aartbik
Copy link
Contributor

aartbik commented Sep 18, 2023 via email

@aartbik
Copy link
Contributor

aartbik commented Sep 18, 2023 via email

@benshi001
Copy link
Member Author

ZijunZhaoCCK pushed a commit to ZijunZhaoCCK/llvm-project that referenced this pull request Sep 19, 2023
…vector types (llvm#65443)

The transform 'scalarizeLoadExtract' can be applied to scalable
vector types if the index is less than the minimum number of elements.

The check whether the index is less than the minimum number of elements
locates at line 1175~1180. 'scalarizeLoadExtract' will call
'canScalarizeAccess' and check the returned result if this transform is safe.

At the beginning of the function 'canScalarizeAccess', the index will be
checked
1. If it is less than the number of elements of a fixed vector type.
2. If it is less than the minimum number of elements of a scalable vector type.

Otherwise 'canScalarizeAccess' will return unsafe and this transform
will be prevented.
zahiraam pushed a commit to tahonermann/llvm-project that referenced this pull request Oct 24, 2023
…vector types (llvm#65443)

The transform 'scalarizeLoadExtract' can be applied to scalable
vector types if the index is less than the minimum number of elements.

The check whether the index is less than the minimum number of elements
locates at line 1175~1180. 'scalarizeLoadExtract' will call
'canScalarizeAccess' and check the returned result if this transform is safe.

At the beginning of the function 'canScalarizeAccess', the index will be
checked
1. If it is less than the number of elements of a fixed vector type.
2. If it is less than the minimum number of elements of a scalable vector type.

Otherwise 'canScalarizeAccess' will return unsafe and this transform
will be prevented.
zahiraam pushed a commit to tahonermann/llvm-project that referenced this pull request Oct 24, 2023
…vector types (llvm#65443)

The transform 'scalarizeLoadExtract' can be applied to scalable
vector types if the index is less than the minimum number of elements.

The check whether the index is less than the minimum number of elements
locates at line 1175~1180. 'scalarizeLoadExtract' will call
'canScalarizeAccess' and check the returned result if this transform is safe.

At the beginning of the function 'canScalarizeAccess', the index will be
checked
1. If it is less than the number of elements of a fixed vector type.
2. If it is less than the minimum number of elements of a scalable vector type.

Otherwise 'canScalarizeAccess' will return unsafe and this transform
will be prevented.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants