-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[TTI] Add alignment argument to TTI for compress/expand support #83516
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TTI] Add alignment argument to TTI for compress/expand support #83516
Conversation
Since `llvm.compressstore` and `llvm.expandload` do require memory access, it's essential for some target to check if alignment is good to be able to lower them to target-specific instructions
@llvm/pr-subscribers-llvm-transforms @llvm/pr-subscribers-backend-x86 Author: Kolya Panchenko (nikolaypanchenko) ChangesSince Full diff: https://github.com/llvm/llvm-project/pull/83516.diff 8 Files Affected:
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 58577a6b6eb5c0..4eab357f1b33b6 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -777,9 +777,9 @@ class TargetTransformInfo {
bool forceScalarizeMaskedScatter(VectorType *Type, Align Alignment) const;
/// Return true if the target supports masked compress store.
- bool isLegalMaskedCompressStore(Type *DataType) const;
+ bool isLegalMaskedCompressStore(Type *DataType, Align Alignment) const;
/// Return true if the target supports masked expand load.
- bool isLegalMaskedExpandLoad(Type *DataType) const;
+ bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment) const;
/// Return true if the target supports strided load.
bool isLegalStridedLoadStore(Type *DataType, Align Alignment) const;
@@ -1863,8 +1863,8 @@ class TargetTransformInfo::Concept {
Align Alignment) = 0;
virtual bool forceScalarizeMaskedScatter(VectorType *DataType,
Align Alignment) = 0;
- virtual bool isLegalMaskedCompressStore(Type *DataType) = 0;
- virtual bool isLegalMaskedExpandLoad(Type *DataType) = 0;
+ virtual bool isLegalMaskedCompressStore(Type *DataType, Align Alignment) = 0;
+ virtual bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment) = 0;
virtual bool isLegalStridedLoadStore(Type *DataType, Align Alignment) = 0;
virtual bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0,
unsigned Opcode1,
@@ -2358,11 +2358,11 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
Align Alignment) override {
return Impl.forceScalarizeMaskedScatter(DataType, Alignment);
}
- bool isLegalMaskedCompressStore(Type *DataType) override {
- return Impl.isLegalMaskedCompressStore(DataType);
+ bool isLegalMaskedCompressStore(Type *DataType, Align Alignment) override {
+ return Impl.isLegalMaskedCompressStore(DataType, Alignment);
}
- bool isLegalMaskedExpandLoad(Type *DataType) override {
- return Impl.isLegalMaskedExpandLoad(DataType);
+ bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment) override {
+ return Impl.isLegalMaskedExpandLoad(DataType, Alignment);
}
bool isLegalStridedLoadStore(Type *DataType, Align Alignment) override {
return Impl.isLegalStridedLoadStore(DataType, Alignment);
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 13379cc126a40c..95fb13d1c97154 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -295,14 +295,18 @@ class TargetTransformInfoImplBase {
return false;
}
- bool isLegalMaskedCompressStore(Type *DataType) const { return false; }
+ bool isLegalMaskedCompressStore(Type *DataType, Align Alignment) const {
+ return false;
+ }
bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0, unsigned Opcode1,
const SmallBitVector &OpcodeMask) const {
return false;
}
- bool isLegalMaskedExpandLoad(Type *DataType) const { return false; }
+ bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment) const {
+ return false;
+ }
bool isLegalStridedLoadStore(Type *DataType, Align Alignment) const {
return false;
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 1f11f0d7dd620e..15311be4dba277 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -492,12 +492,14 @@ bool TargetTransformInfo::forceScalarizeMaskedScatter(VectorType *DataType,
return TTIImpl->forceScalarizeMaskedScatter(DataType, Alignment);
}
-bool TargetTransformInfo::isLegalMaskedCompressStore(Type *DataType) const {
- return TTIImpl->isLegalMaskedCompressStore(DataType);
+bool TargetTransformInfo::isLegalMaskedCompressStore(Type *DataType,
+ Align Alignment) const {
+ return TTIImpl->isLegalMaskedCompressStore(DataType, Alignment);
}
-bool TargetTransformInfo::isLegalMaskedExpandLoad(Type *DataType) const {
- return TTIImpl->isLegalMaskedExpandLoad(DataType);
+bool TargetTransformInfo::isLegalMaskedExpandLoad(Type *DataType,
+ Align Alignment) const {
+ return TTIImpl->isLegalMaskedExpandLoad(DataType, Alignment);
}
bool TargetTransformInfo::isLegalStridedLoadStore(Type *DataType,
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index 2e4e69fb4f920f..0bd623e1196e1f 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -1609,3 +1609,28 @@ bool RISCVTTIImpl::isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
C2.NumIVMuls, C2.NumBaseAdds,
C2.ScaleCost, C2.ImmCost, C2.SetupCost);
}
+
+bool RISCVTTIImpl::isLegalMaskedCompressStore(Type *DataTy, Align Alignment) {
+ auto *VTy = dyn_cast<VectorType>(DataTy);
+ if (!VTy || VTy->isScalableTy() || !ST->hasVInstructions())
+ return false;
+
+ Type *ScalarTy = VTy->getScalarType();
+ if (ScalarTy->isFloatTy() || ScalarTy->isDoubleTy())
+ return true;
+
+ if (!ScalarTy->isIntegerTy())
+ return false;
+
+ switch (ScalarTy->getIntegerBitWidth()) {
+ case 8:
+ case 16:
+ case 32:
+ case 64:
+ break;
+ default:
+ return false;
+ }
+
+ return getRegUsageForType(VTy) <= 8;
+}
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index af36e9d5d5e886..8daf6845dc8bc9 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -261,6 +261,8 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
return TLI->isLegalStridedLoadStore(DataTypeVT, Alignment);
}
+ bool isLegalMaskedCompressStore(Type *DataTy, Align Alignment);
+
bool isVScaleKnownToBeAPowerOfTwo() const {
return TLI->isVScaleKnownToBeAPowerOfTwo();
}
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index 18bf32fe1acaad..9c1e4b2f83ab7f 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -5938,7 +5938,7 @@ bool X86TTIImpl::isLegalBroadcastLoad(Type *ElementTy,
ElementTy == Type::getDoubleTy(ElementTy->getContext());
}
-bool X86TTIImpl::isLegalMaskedExpandLoad(Type *DataTy) {
+bool X86TTIImpl::isLegalMaskedExpandLoad(Type *DataTy, Align Alignment) {
if (!isa<VectorType>(DataTy))
return false;
@@ -5962,8 +5962,8 @@ bool X86TTIImpl::isLegalMaskedExpandLoad(Type *DataTy) {
((IntWidth == 8 || IntWidth == 16) && ST->hasVBMI2());
}
-bool X86TTIImpl::isLegalMaskedCompressStore(Type *DataTy) {
- return isLegalMaskedExpandLoad(DataTy);
+bool X86TTIImpl::isLegalMaskedCompressStore(Type *DataTy, Align Alignment) {
+ return isLegalMaskedExpandLoad(DataTy, Alignment);
}
bool X86TTIImpl::supportsGather() const {
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
index 07a3fff4f84b3e..1a5e6bc886aa67 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -269,8 +269,8 @@ class X86TTIImpl : public BasicTTIImplBase<X86TTIImpl> {
bool isLegalMaskedGatherScatter(Type *DataType, Align Alignment);
bool isLegalMaskedGather(Type *DataType, Align Alignment);
bool isLegalMaskedScatter(Type *DataType, Align Alignment);
- bool isLegalMaskedExpandLoad(Type *DataType);
- bool isLegalMaskedCompressStore(Type *DataType);
+ bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment);
+ bool isLegalMaskedCompressStore(Type *DataType, Align Alignment);
bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0, unsigned Opcode1,
const SmallBitVector &OpcodeMask) const;
bool hasDivRemOp(Type *DataType, bool IsSigned);
diff --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
index c01d03f6447240..d545c0ae49f5a1 100644
--- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
+++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
@@ -969,12 +969,16 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
return true;
}
case Intrinsic::masked_expandload:
- if (TTI.isLegalMaskedExpandLoad(CI->getType()))
+ if (TTI.isLegalMaskedExpandLoad(
+ CI->getType(),
+ CI->getAttributes().getParamAttrs(0).getAlignment().valueOrOne()))
return false;
scalarizeMaskedExpandLoad(DL, CI, DTU, ModifiedDT);
return true;
case Intrinsic::masked_compressstore:
- if (TTI.isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
+ if (TTI.isLegalMaskedCompressStore(
+ CI->getArgOperand(0)->getType(),
+ CI->getAttributes().getParamAttrs(1).getAlignment().valueOrOne()))
return false;
scalarizeMaskedCompressStore(DL, CI, DTU, ModifiedDT);
return true;
|
@@ -1609,3 +1609,28 @@ bool RISCVTTIImpl::isLSRCostLess(const TargetTransformInfo::LSRCost &C1, | |||
C2.NumIVMuls, C2.NumBaseAdds, | |||
C2.ScaleCost, C2.ImmCost, C2.SetupCost); | |||
} | |||
|
|||
bool RISCVTTIImpl::isLegalMaskedCompressStore(Type *DataTy, Align Alignment) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this belongs in the patch that adds the RISCVISelLowering support. It's not legal for RISC-V without that.
return false; | ||
|
||
Type *ScalarTy = VTy->getScalarType(); | ||
if (ScalarTy->isFloatTy() || ScalarTy->isDoubleTy()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is incorrect for Zve32f which doesn't support double or Zve64x and Zve32x that don't support float. There should be a function that handles this correctly in the gather/scatter TTI hook
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks. I was looking for this before, but didn't know where I can find it. Will move the change to #83457
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
…ndload sections. align attribute is used for masked.compress/expandload in commit llvm#83519, llvm#83763, llvm#83516.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Since
llvm.compressstore
andllvm.expandload
do require memory access, it's essential for some target to check if alignment is good to be able to lower them to target-specific instructions