Skip to content

[LLVM][IR] Teach extractelement folds about constant ConstantInt/FP. #116793

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

paulwalker-arm
Copy link
Collaborator

No description provided.

@llvmbot llvmbot added llvm:instcombine Covers the InstCombine, InstSimplify and AggressiveInstCombine passes llvm:ir llvm:transforms labels Nov 19, 2024
@llvmbot
Copy link
Member

llvmbot commented Nov 19, 2024

@llvm/pr-subscribers-llvm-ir

Author: Paul Walker (paulwalker-arm)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/116793.diff

4 Files Affected:

  • (modified) llvm/lib/IR/Constants.cpp (+4)
  • (modified) llvm/lib/IR/Instructions.cpp (+10-2)
  • (modified) llvm/test/Transforms/InstCombine/extractelement.ll (+5)
  • (modified) llvm/test/Transforms/InstSimplify/extract-element.ll (+1)
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index 3d6c4ad780dc24..95832ed0b8951a 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -1699,6 +1699,10 @@ Constant *Constant::getSplatValue(bool AllowPoison) const {
   assert(this->getType()->isVectorTy() && "Only valid for vectors!");
   if (isa<ConstantAggregateZero>(this))
     return getNullValue(cast<VectorType>(getType())->getElementType());
+  if (auto *CI = dyn_cast<ConstantInt>(this))
+    return ConstantInt::get(getContext(), CI->getValue());
+  if (auto *CFP = dyn_cast<ConstantFP>(this))
+    return ConstantFP::get(getContext(), CFP->getValue());
   if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this))
     return CV->getSplatValue();
   if (const ConstantVector *CV = dyn_cast<ConstantVector>(this))
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 5b89a27126150a..697bdbcdfa943a 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -1752,8 +1752,17 @@ bool ShuffleVectorInst::isValidOperands(const Value *V1, const Value *V2,
   if (isa<UndefValue>(Mask) || isa<ConstantAggregateZero>(Mask))
     return true;
 
+  // NOTE: Through vector ConstantInt we have the potential to support more
+  // than just zero splat masks but that requires a LangRef change.
+  if (isa<ScalableVectorType>(MaskTy))
+    return false;
+
+  unsigned V1Size = cast<FixedVectorType>(V1->getType())->getNumElements();
+
+  if (const auto *CI = dyn_cast<ConstantInt>(Mask))
+    return !CI->uge(V1Size * 2);
+
   if (const auto *MV = dyn_cast<ConstantVector>(Mask)) {
-    unsigned V1Size = cast<FixedVectorType>(V1->getType())->getNumElements();
     for (Value *Op : MV->operands()) {
       if (auto *CI = dyn_cast<ConstantInt>(Op)) {
         if (CI->uge(V1Size*2))
@@ -1766,7 +1775,6 @@ bool ShuffleVectorInst::isValidOperands(const Value *V1, const Value *V2,
   }
 
   if (const auto *CDS = dyn_cast<ConstantDataSequential>(Mask)) {
-    unsigned V1Size = cast<FixedVectorType>(V1->getType())->getNumElements();
     for (unsigned i = 0, e = cast<FixedVectorType>(MaskTy)->getNumElements();
          i != e; ++i)
       if (CDS->getElementAsInteger(i) >= V1Size*2)
diff --git a/llvm/test/Transforms/InstCombine/extractelement.ll b/llvm/test/Transforms/InstCombine/extractelement.ll
index 28a4702559c46c..fe4fe99fcbeb3d 100644
--- a/llvm/test/Transforms/InstCombine/extractelement.ll
+++ b/llvm/test/Transforms/InstCombine/extractelement.ll
@@ -4,6 +4,11 @@
 ; RUN: opt < %s -passes=instcombine -S -data-layout="E-n64" | FileCheck %s --check-prefixes=ANY,ANYBE,BE64
 ; RUN: opt < %s -passes=instcombine -S -data-layout="E-n128" | FileCheck %s --check-prefixes=ANY,ANYBE,BE128
 
+; RUN: opt < %s -passes=instcombine -S -data-layout="e-n64" -use-constant-fp-for-fixed-length-splat -use-constant-int-for-fixed-length-splat | FileCheck %s --check-prefixes=ANY,ANYLE,LE64
+; RUN: opt < %s -passes=instcombine -S -data-layout="e-n128" -use-constant-fp-for-fixed-length-splat -use-constant-int-for-fixed-length-splat | FileCheck %s --check-prefixes=ANY,ANYLE,LE128
+; RUN: opt < %s -passes=instcombine -S -data-layout="E-n64" -use-constant-fp-for-fixed-length-splat -use-constant-int-for-fixed-length-splat | FileCheck %s --check-prefixes=ANY,ANYBE,BE64
+; RUN: opt < %s -passes=instcombine -S -data-layout="E-n128" -use-constant-fp-for-fixed-length-splat -use-constant-int-for-fixed-length-splat | FileCheck %s --check-prefixes=ANY,ANYBE,BE128
+
 define i32 @extractelement_out_of_range(<2 x i32> %x) {
 ; ANY-LABEL: @extractelement_out_of_range(
 ; ANY-NEXT:    ret i32 poison
diff --git a/llvm/test/Transforms/InstSimplify/extract-element.ll b/llvm/test/Transforms/InstSimplify/extract-element.ll
index 3060586b25a791..7d30805f4fdc71 100644
--- a/llvm/test/Transforms/InstSimplify/extract-element.ll
+++ b/llvm/test/Transforms/InstSimplify/extract-element.ll
@@ -1,5 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
 ; RUN: opt < %s -passes=instsimplify -S | FileCheck %s
+; RUN: opt < %s -passes=instsimplify -use-constant-fp-for-fixed-length-splat -use-constant-int-for-fixed-length-splat -S | FileCheck %s
 
 ; Weird Types
 

@llvmbot
Copy link
Member

llvmbot commented Nov 19, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Paul Walker (paulwalker-arm)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/116793.diff

4 Files Affected:

  • (modified) llvm/lib/IR/Constants.cpp (+4)
  • (modified) llvm/lib/IR/Instructions.cpp (+10-2)
  • (modified) llvm/test/Transforms/InstCombine/extractelement.ll (+5)
  • (modified) llvm/test/Transforms/InstSimplify/extract-element.ll (+1)
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index 3d6c4ad780dc24..95832ed0b8951a 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -1699,6 +1699,10 @@ Constant *Constant::getSplatValue(bool AllowPoison) const {
   assert(this->getType()->isVectorTy() && "Only valid for vectors!");
   if (isa<ConstantAggregateZero>(this))
     return getNullValue(cast<VectorType>(getType())->getElementType());
+  if (auto *CI = dyn_cast<ConstantInt>(this))
+    return ConstantInt::get(getContext(), CI->getValue());
+  if (auto *CFP = dyn_cast<ConstantFP>(this))
+    return ConstantFP::get(getContext(), CFP->getValue());
   if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this))
     return CV->getSplatValue();
   if (const ConstantVector *CV = dyn_cast<ConstantVector>(this))
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 5b89a27126150a..697bdbcdfa943a 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -1752,8 +1752,17 @@ bool ShuffleVectorInst::isValidOperands(const Value *V1, const Value *V2,
   if (isa<UndefValue>(Mask) || isa<ConstantAggregateZero>(Mask))
     return true;
 
+  // NOTE: Through vector ConstantInt we have the potential to support more
+  // than just zero splat masks but that requires a LangRef change.
+  if (isa<ScalableVectorType>(MaskTy))
+    return false;
+
+  unsigned V1Size = cast<FixedVectorType>(V1->getType())->getNumElements();
+
+  if (const auto *CI = dyn_cast<ConstantInt>(Mask))
+    return !CI->uge(V1Size * 2);
+
   if (const auto *MV = dyn_cast<ConstantVector>(Mask)) {
-    unsigned V1Size = cast<FixedVectorType>(V1->getType())->getNumElements();
     for (Value *Op : MV->operands()) {
       if (auto *CI = dyn_cast<ConstantInt>(Op)) {
         if (CI->uge(V1Size*2))
@@ -1766,7 +1775,6 @@ bool ShuffleVectorInst::isValidOperands(const Value *V1, const Value *V2,
   }
 
   if (const auto *CDS = dyn_cast<ConstantDataSequential>(Mask)) {
-    unsigned V1Size = cast<FixedVectorType>(V1->getType())->getNumElements();
     for (unsigned i = 0, e = cast<FixedVectorType>(MaskTy)->getNumElements();
          i != e; ++i)
       if (CDS->getElementAsInteger(i) >= V1Size*2)
diff --git a/llvm/test/Transforms/InstCombine/extractelement.ll b/llvm/test/Transforms/InstCombine/extractelement.ll
index 28a4702559c46c..fe4fe99fcbeb3d 100644
--- a/llvm/test/Transforms/InstCombine/extractelement.ll
+++ b/llvm/test/Transforms/InstCombine/extractelement.ll
@@ -4,6 +4,11 @@
 ; RUN: opt < %s -passes=instcombine -S -data-layout="E-n64" | FileCheck %s --check-prefixes=ANY,ANYBE,BE64
 ; RUN: opt < %s -passes=instcombine -S -data-layout="E-n128" | FileCheck %s --check-prefixes=ANY,ANYBE,BE128
 
+; RUN: opt < %s -passes=instcombine -S -data-layout="e-n64" -use-constant-fp-for-fixed-length-splat -use-constant-int-for-fixed-length-splat | FileCheck %s --check-prefixes=ANY,ANYLE,LE64
+; RUN: opt < %s -passes=instcombine -S -data-layout="e-n128" -use-constant-fp-for-fixed-length-splat -use-constant-int-for-fixed-length-splat | FileCheck %s --check-prefixes=ANY,ANYLE,LE128
+; RUN: opt < %s -passes=instcombine -S -data-layout="E-n64" -use-constant-fp-for-fixed-length-splat -use-constant-int-for-fixed-length-splat | FileCheck %s --check-prefixes=ANY,ANYBE,BE64
+; RUN: opt < %s -passes=instcombine -S -data-layout="E-n128" -use-constant-fp-for-fixed-length-splat -use-constant-int-for-fixed-length-splat | FileCheck %s --check-prefixes=ANY,ANYBE,BE128
+
 define i32 @extractelement_out_of_range(<2 x i32> %x) {
 ; ANY-LABEL: @extractelement_out_of_range(
 ; ANY-NEXT:    ret i32 poison
diff --git a/llvm/test/Transforms/InstSimplify/extract-element.ll b/llvm/test/Transforms/InstSimplify/extract-element.ll
index 3060586b25a791..7d30805f4fdc71 100644
--- a/llvm/test/Transforms/InstSimplify/extract-element.ll
+++ b/llvm/test/Transforms/InstSimplify/extract-element.ll
@@ -1,5 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
 ; RUN: opt < %s -passes=instsimplify -S | FileCheck %s
+; RUN: opt < %s -passes=instsimplify -use-constant-fp-for-fixed-length-splat -use-constant-int-for-fixed-length-splat -S | FileCheck %s
 
 ; Weird Types
 

@paulwalker-arm
Copy link
Collaborator Author

NOTE: This doesn't fix up all folds but I'm trying to make use of existing tests where possible before circling back to add new ones required for future fixes.

Copy link
Contributor

@david-arm david-arm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@paulwalker-arm paulwalker-arm merged commit 4872ecf into llvm:main Nov 21, 2024
12 checks passed
@paulwalker-arm paulwalker-arm deleted the vector-constants-ir-extractelement branch November 21, 2024 12:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
llvm:instcombine Covers the InstCombine, InstSimplify and AggressiveInstCombine passes llvm:ir llvm:transforms
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants