Skip to content

Commit 63a0a81

Browse files
authored
[NFC][Scalarizer][TargetTransformInfo] Add isTargetIntrinsicWithScalarOpAtArg api (#111441)
This change allows target intrinsics can have scalar args fixes [111440](#111440) This change will let us add scalarization for WaveReadLaneAt: #111010
1 parent 39ac121 commit 63a0a81

File tree

7 files changed

+41
-2
lines changed

7 files changed

+41
-2
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -884,6 +884,9 @@ class TargetTransformInfo {
884884

885885
bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const;
886886

887+
bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
888+
unsigned ScalarOpdIdx) const;
889+
887890
/// Estimate the overhead of scalarizing an instruction. Insert and Extract
888891
/// are set if the demanded result elements need to be inserted and/or
889892
/// extracted from vectors.
@@ -1935,6 +1938,8 @@ class TargetTransformInfo::Concept {
19351938
virtual bool shouldBuildRelLookupTables() = 0;
19361939
virtual bool useColdCCForColdCall(Function &F) = 0;
19371940
virtual bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) = 0;
1941+
virtual bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
1942+
unsigned ScalarOpdIdx) = 0;
19381943
virtual InstructionCost getScalarizationOverhead(VectorType *Ty,
19391944
const APInt &DemandedElts,
19401945
bool Insert, bool Extract,
@@ -2477,6 +2482,12 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
24772482
bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) override {
24782483
return Impl.isTargetIntrinsicTriviallyScalarizable(ID);
24792484
}
2485+
2486+
bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
2487+
unsigned ScalarOpdIdx) override {
2488+
return Impl.isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
2489+
}
2490+
24802491
InstructionCost getScalarizationOverhead(VectorType *Ty,
24812492
const APInt &DemandedElts,
24822493
bool Insert, bool Extract,

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,11 @@ class TargetTransformInfoImplBase {
377377
return false;
378378
}
379379

380+
bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
381+
unsigned ScalarOpdIdx) const {
382+
return false;
383+
}
384+
380385
InstructionCost getScalarizationOverhead(VectorType *Ty,
381386
const APInt &DemandedElts,
382387
bool Insert, bool Extract,

llvm/include/llvm/CodeGen/BasicTTIImpl.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -793,6 +793,11 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
793793
return false;
794794
}
795795

796+
bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
797+
unsigned ScalarOpdIdx) const {
798+
return false;
799+
}
800+
796801
/// Helper wrapper for the DemandedElts variant of getScalarizationOverhead.
797802
InstructionCost getScalarizationOverhead(VectorType *InTy, bool Insert,
798803
bool Extract,

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,6 +592,11 @@ bool TargetTransformInfo::isTargetIntrinsicTriviallyScalarizable(
592592
return TTIImpl->isTargetIntrinsicTriviallyScalarizable(ID);
593593
}
594594

595+
bool TargetTransformInfo::isTargetIntrinsicWithScalarOpAtArg(
596+
Intrinsic::ID ID, unsigned ScalarOpdIdx) const {
597+
return TTIImpl->isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
598+
}
599+
595600
InstructionCost TargetTransformInfo::getScalarizationOverhead(
596601
VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
597602
TTI::TargetCostKind CostKind) const {

llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,17 @@
1313
#include "llvm/IR/Intrinsics.h"
1414
#include "llvm/IR/IntrinsicsDirectX.h"
1515

16-
bool llvm::DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
16+
using namespace llvm;
17+
18+
bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
19+
unsigned ScalarOpdIdx) {
20+
switch (ID) {
21+
default:
22+
return false;
23+
}
24+
}
25+
26+
bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
1727
Intrinsic::ID ID) const {
1828
switch (ID) {
1929
case Intrinsic::dx_frac:

llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ class DirectXTTIImpl : public BasicTTIImplBase<DirectXTTIImpl> {
3535
TLI(ST->getTargetLowering()) {}
3636
unsigned getMinVectorRegisterBitWidth() const { return 32; }
3737
bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const;
38+
bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
39+
unsigned ScalarOpdIdx);
3840
};
3941
} // namespace llvm
4042

llvm/lib/Transforms/Scalar/Scalarizer.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,8 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
745745
Tys[0] = VS->RemainderTy;
746746

747747
for (unsigned J = 0; J != NumArgs; ++J) {
748-
if (isVectorIntrinsicWithScalarOpAtArg(ID, J)) {
748+
if (isVectorIntrinsicWithScalarOpAtArg(ID, J) ||
749+
TTI->isTargetIntrinsicWithScalarOpAtArg(ID, J)) {
749750
ScalarCallOps.push_back(ScalarOperands[J]);
750751
} else {
751752
ScalarCallOps.push_back(Scattered[J][I]);

0 commit comments

Comments
 (0)