Skip to content

[TTI] Add alignment argument to TTI for compress/expand support #83516

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

nikolaypanchenko
Copy link
Contributor

Since llvm.compressstore and llvm.expandload do require memory access, it's essential for some target to check if alignment is good to be able to lower them to target-specific instructions

Since `llvm.compressstore` and `llvm.expandload` do require memory
access, it's essential for some target to check if alignment is good to
be able to lower them to target-specific instructions
@llvmbot
Copy link
Member

llvmbot commented Mar 1, 2024

@llvm/pr-subscribers-llvm-transforms
@llvm/pr-subscribers-backend-risc-v
@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-backend-x86

Author: Kolya Panchenko (nikolaypanchenko)

Changes

Since llvm.compressstore and llvm.expandload do require memory access, it's essential for some target to check if alignment is good to be able to lower them to target-specific instructions


Full diff: https://github.com/llvm/llvm-project/pull/83516.diff

8 Files Affected:

  • (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+8-8)
  • (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+6-2)
  • (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+6-4)
  • (modified) llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp (+25)
  • (modified) llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h (+2)
  • (modified) llvm/lib/Target/X86/X86TargetTransformInfo.cpp (+3-3)
  • (modified) llvm/lib/Target/X86/X86TargetTransformInfo.h (+2-2)
  • (modified) llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp (+6-2)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 58577a6b6eb5c0..4eab357f1b33b6 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -777,9 +777,9 @@ class TargetTransformInfo {
   bool forceScalarizeMaskedScatter(VectorType *Type, Align Alignment) const;
 
   /// Return true if the target supports masked compress store.
-  bool isLegalMaskedCompressStore(Type *DataType) const;
+  bool isLegalMaskedCompressStore(Type *DataType, Align Alignment) const;
   /// Return true if the target supports masked expand load.
-  bool isLegalMaskedExpandLoad(Type *DataType) const;
+  bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment) const;
 
   /// Return true if the target supports strided load.
   bool isLegalStridedLoadStore(Type *DataType, Align Alignment) const;
@@ -1863,8 +1863,8 @@ class TargetTransformInfo::Concept {
                                           Align Alignment) = 0;
   virtual bool forceScalarizeMaskedScatter(VectorType *DataType,
                                            Align Alignment) = 0;
-  virtual bool isLegalMaskedCompressStore(Type *DataType) = 0;
-  virtual bool isLegalMaskedExpandLoad(Type *DataType) = 0;
+  virtual bool isLegalMaskedCompressStore(Type *DataType, Align Alignment) = 0;
+  virtual bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment) = 0;
   virtual bool isLegalStridedLoadStore(Type *DataType, Align Alignment) = 0;
   virtual bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0,
                                unsigned Opcode1,
@@ -2358,11 +2358,11 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
                                    Align Alignment) override {
     return Impl.forceScalarizeMaskedScatter(DataType, Alignment);
   }
-  bool isLegalMaskedCompressStore(Type *DataType) override {
-    return Impl.isLegalMaskedCompressStore(DataType);
+  bool isLegalMaskedCompressStore(Type *DataType, Align Alignment) override {
+    return Impl.isLegalMaskedCompressStore(DataType, Alignment);
   }
-  bool isLegalMaskedExpandLoad(Type *DataType) override {
-    return Impl.isLegalMaskedExpandLoad(DataType);
+  bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment) override {
+    return Impl.isLegalMaskedExpandLoad(DataType, Alignment);
   }
   bool isLegalStridedLoadStore(Type *DataType, Align Alignment) override {
     return Impl.isLegalStridedLoadStore(DataType, Alignment);
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 13379cc126a40c..95fb13d1c97154 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -295,14 +295,18 @@ class TargetTransformInfoImplBase {
     return false;
   }
 
-  bool isLegalMaskedCompressStore(Type *DataType) const { return false; }
+  bool isLegalMaskedCompressStore(Type *DataType, Align Alignment) const {
+    return false;
+  }
 
   bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0, unsigned Opcode1,
                        const SmallBitVector &OpcodeMask) const {
     return false;
   }
 
-  bool isLegalMaskedExpandLoad(Type *DataType) const { return false; }
+  bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment) const {
+    return false;
+  }
 
   bool isLegalStridedLoadStore(Type *DataType, Align Alignment) const {
     return false;
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 1f11f0d7dd620e..15311be4dba277 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -492,12 +492,14 @@ bool TargetTransformInfo::forceScalarizeMaskedScatter(VectorType *DataType,
   return TTIImpl->forceScalarizeMaskedScatter(DataType, Alignment);
 }
 
-bool TargetTransformInfo::isLegalMaskedCompressStore(Type *DataType) const {
-  return TTIImpl->isLegalMaskedCompressStore(DataType);
+bool TargetTransformInfo::isLegalMaskedCompressStore(Type *DataType,
+                                                     Align Alignment) const {
+  return TTIImpl->isLegalMaskedCompressStore(DataType, Alignment);
 }
 
-bool TargetTransformInfo::isLegalMaskedExpandLoad(Type *DataType) const {
-  return TTIImpl->isLegalMaskedExpandLoad(DataType);
+bool TargetTransformInfo::isLegalMaskedExpandLoad(Type *DataType,
+                                                  Align Alignment) const {
+  return TTIImpl->isLegalMaskedExpandLoad(DataType, Alignment);
 }
 
 bool TargetTransformInfo::isLegalStridedLoadStore(Type *DataType,
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index 2e4e69fb4f920f..0bd623e1196e1f 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -1609,3 +1609,28 @@ bool RISCVTTIImpl::isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
                   C2.NumIVMuls, C2.NumBaseAdds,
                   C2.ScaleCost, C2.ImmCost, C2.SetupCost);
 }
+
+bool RISCVTTIImpl::isLegalMaskedCompressStore(Type *DataTy, Align Alignment) {
+  auto *VTy = dyn_cast<VectorType>(DataTy);
+  if (!VTy || VTy->isScalableTy() || !ST->hasVInstructions())
+    return false;
+
+  Type *ScalarTy = VTy->getScalarType();
+  if (ScalarTy->isFloatTy() || ScalarTy->isDoubleTy())
+    return true;
+
+  if (!ScalarTy->isIntegerTy())
+    return false;
+
+  switch (ScalarTy->getIntegerBitWidth()) {
+  case 8:
+  case 16:
+  case 32:
+  case 64:
+    break;
+  default:
+    return false;
+  }
+
+  return getRegUsageForType(VTy) <= 8;
+}
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index af36e9d5d5e886..8daf6845dc8bc9 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -261,6 +261,8 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
     return TLI->isLegalStridedLoadStore(DataTypeVT, Alignment);
   }
 
+  bool isLegalMaskedCompressStore(Type *DataTy, Align Alignment);
+
   bool isVScaleKnownToBeAPowerOfTwo() const {
     return TLI->isVScaleKnownToBeAPowerOfTwo();
   }
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index 18bf32fe1acaad..9c1e4b2f83ab7f 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -5938,7 +5938,7 @@ bool X86TTIImpl::isLegalBroadcastLoad(Type *ElementTy,
          ElementTy == Type::getDoubleTy(ElementTy->getContext());
 }
 
-bool X86TTIImpl::isLegalMaskedExpandLoad(Type *DataTy) {
+bool X86TTIImpl::isLegalMaskedExpandLoad(Type *DataTy, Align Alignment) {
   if (!isa<VectorType>(DataTy))
     return false;
 
@@ -5962,8 +5962,8 @@ bool X86TTIImpl::isLegalMaskedExpandLoad(Type *DataTy) {
          ((IntWidth == 8 || IntWidth == 16) && ST->hasVBMI2());
 }
 
-bool X86TTIImpl::isLegalMaskedCompressStore(Type *DataTy) {
-  return isLegalMaskedExpandLoad(DataTy);
+bool X86TTIImpl::isLegalMaskedCompressStore(Type *DataTy, Align Alignment) {
+  return isLegalMaskedExpandLoad(DataTy, Alignment);
 }
 
 bool X86TTIImpl::supportsGather() const {
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index 07a3fff4f84b3e..1a5e6bc886aa67 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -269,8 +269,8 @@ class X86TTIImpl : public BasicTTIImplBase<X86TTIImpl> {
   bool isLegalMaskedGatherScatter(Type *DataType, Align Alignment);
   bool isLegalMaskedGather(Type *DataType, Align Alignment);
   bool isLegalMaskedScatter(Type *DataType, Align Alignment);
-  bool isLegalMaskedExpandLoad(Type *DataType);
-  bool isLegalMaskedCompressStore(Type *DataType);
+  bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment);
+  bool isLegalMaskedCompressStore(Type *DataType, Align Alignment);
   bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0, unsigned Opcode1,
                        const SmallBitVector &OpcodeMask) const;
   bool hasDivRemOp(Type *DataType, bool IsSigned);
diff --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
index c01d03f6447240..d545c0ae49f5a1 100644
--- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
+++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
@@ -969,12 +969,16 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
       return true;
     }
     case Intrinsic::masked_expandload:
-      if (TTI.isLegalMaskedExpandLoad(CI->getType()))
+      if (TTI.isLegalMaskedExpandLoad(
+              CI->getType(),
+              CI->getAttributes().getParamAttrs(0).getAlignment().valueOrOne()))
         return false;
       scalarizeMaskedExpandLoad(DL, CI, DTU, ModifiedDT);
       return true;
     case Intrinsic::masked_compressstore:
-      if (TTI.isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
+      if (TTI.isLegalMaskedCompressStore(
+              CI->getArgOperand(0)->getType(),
+              CI->getAttributes().getParamAttrs(1).getAlignment().valueOrOne()))
         return false;
       scalarizeMaskedCompressStore(DL, CI, DTU, ModifiedDT);
       return true;

@@ -1609,3 +1609,28 @@ bool RISCVTTIImpl::isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
C2.NumIVMuls, C2.NumBaseAdds,
C2.ScaleCost, C2.ImmCost, C2.SetupCost);
}

bool RISCVTTIImpl::isLegalMaskedCompressStore(Type *DataTy, Align Alignment) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this belongs in the patch that adds the RISCVISelLowering support. It's not legal for RISC-V without that.

return false;

Type *ScalarTy = VTy->getScalarType();
if (ScalarTy->isFloatTy() || ScalarTy->isDoubleTy())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is incorrect for Zve32f which doesn't support double or Zve64x and Zve32x that don't support float. There should be a function that handles this correctly in the gather/scatter TTI hook

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks. I was looking for this before, but didn't know where I can find it. Will move the change to #83457

Copy link
Collaborator

@preames preames left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

yetingk pushed a commit to yetingk/llvm-project that referenced this pull request Mar 4, 2024
…ndload sections.

align attribute is used for masked.compress/expandload in commit llvm#83519, llvm#83763, llvm#83516.
yetingk added a commit that referenced this pull request Mar 5, 2024
…ndload sections. (#83808)

Align attribute has already been used for masked.compress/expandload in
commit #83519, #83763 and #83516.
Copy link
Collaborator

@topperc topperc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@npanchen npanchen merged commit 889d99a into llvm:main Mar 6, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:RISC-V backend:X86 llvm:analysis Includes value tracking, cost tables and constant folding llvm:transforms
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants