Skip to content

Commit e1db193

Browse files
[LVL][CSA] Legalize CSA vectorization
1 parent 20c099c commit e1db193

File tree

7 files changed

+72
-4
lines changed

7 files changed

+72
-4
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1800,6 +1800,10 @@ class TargetTransformInfo {
18001800
: EVLParamStrategy(EVLParamStrategy), OpStrategy(OpStrategy) {}
18011801
};
18021802

1803+
/// \returns true if the loop vectorizer should vectorize conditional
1804+
/// scalar assignments for the target.
1805+
bool enableCSAVectorization() const;
1806+
18031807
/// \returns How the target needs this vector-predicated operation to be
18041808
/// transformed.
18051809
VPLegalization getVPLegalizationStrategy(const VPIntrinsic &PI) const;
@@ -2225,6 +2229,7 @@ class TargetTransformInfo::Concept {
22252229
SmallVectorImpl<Use *> &OpsToSink) const = 0;
22262230

22272231
virtual bool isVectorShiftByScalarCheap(Type *Ty) const = 0;
2232+
virtual bool enableCSAVectorization() const = 0;
22282233
virtual VPLegalization
22292234
getVPLegalizationStrategy(const VPIntrinsic &PI) const = 0;
22302235
virtual bool hasArmWideBranch(bool Thumb) const = 0;
@@ -3020,6 +3025,10 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
30203025
return Impl.isVectorShiftByScalarCheap(Ty);
30213026
}
30223027

3028+
bool enableCSAVectorization() const override {
3029+
return Impl.enableCSAVectorization();
3030+
}
3031+
30233032
VPLegalization
30243033
getVPLegalizationStrategy(const VPIntrinsic &PI) const override {
30253034
return Impl.getVPLegalizationStrategy(PI);

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -997,6 +997,8 @@ class TargetTransformInfoImplBase {
997997

998998
bool isVectorShiftByScalarCheap(Type *Ty) const { return false; }
999999

1000+
bool enableCSAVectorization() const { return false; }
1001+
10001002
TargetTransformInfo::VPLegalization
10011003
getVPLegalizationStrategy(const VPIntrinsic &PI) const {
10021004
return TargetTransformInfo::VPLegalization(

llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#define LLVM_TRANSFORMS_VECTORIZE_LOOPVECTORIZATIONLEGALITY_H
2828

2929
#include "llvm/ADT/MapVector.h"
30+
#include "llvm/Analysis/CSADescriptors.h"
3031
#include "llvm/Analysis/LoopAccessAnalysis.h"
3132
#include "llvm/Support/TypeSize.h"
3233
#include "llvm/Transforms/Utils/LoopUtils.h"
@@ -269,6 +270,10 @@ class LoopVectorizationLegality {
269270
/// induction descriptor.
270271
using InductionList = MapVector<PHINode *, InductionDescriptor>;
271272

273+
/// CSAList contains the CSA descriptors for all the CSAs that were found
274+
/// in the loop, rooted by their phis.
275+
using CSAList = MapVector<PHINode *, CSADescriptor>;
276+
272277
/// RecurrenceSet contains the phi nodes that are recurrences other than
273278
/// inductions and reductions.
274279
using RecurrenceSet = SmallPtrSet<const PHINode *, 8>;
@@ -321,6 +326,12 @@ class LoopVectorizationLegality {
321326
/// Returns True if V is a Phi node of an induction variable in this loop.
322327
bool isInductionPhi(const Value *V) const;
323328

329+
/// Returns the CSAs found in the loop.
330+
const CSAList &getCSAs() const { return CSAs; }
331+
332+
/// Returns true if Phi is the root of a CSA in the loop.
333+
bool isCSAPhi(PHINode *Phi) const { return CSAs.count(Phi) != 0; }
334+
324335
/// Returns a pointer to the induction descriptor, if \p Phi is an integer or
325336
/// floating point induction.
326337
const InductionDescriptor *getIntOrFpInductionDescriptor(PHINode *Phi) const;
@@ -545,6 +556,10 @@ class LoopVectorizationLegality {
545556
void addInductionPhi(PHINode *Phi, const InductionDescriptor &ID,
546557
SmallPtrSetImpl<Value *> &AllowedExit);
547558

559+
// Updates the vetorization state by adding \p Phi to the CSA list.
560+
void addCSAPhi(PHINode *Phi, const CSADescriptor &CSADesc,
561+
SmallPtrSetImpl<Value *> &AllowedExit);
562+
548563
/// The loop that we evaluate.
549564
Loop *TheLoop;
550565

@@ -589,6 +604,9 @@ class LoopVectorizationLegality {
589604
/// variables can be pointers.
590605
InductionList Inductions;
591606

607+
/// Holds the conditional scalar assignments
608+
CSAList CSAs;
609+
592610
/// Holds all the casts that participate in the update chain of the induction
593611
/// variables, and that have been proven to be redundant (possibly under a
594612
/// runtime guard). These casts can be ignored when creating the vectorized

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1329,6 +1329,10 @@ bool TargetTransformInfo::preferEpilogueVectorization() const {
13291329
return TTIImpl->preferEpilogueVectorization();
13301330
}
13311331

1332+
bool TargetTransformInfo::enableCSAVectorization() const {
1333+
return TTIImpl->enableCSAVectorization();
1334+
}
1335+
13321336
TargetTransformInfo::VPLegalization
13331337
TargetTransformInfo::getVPLegalizationStrategy(const VPIntrinsic &VPI) const {
13341338
return TTIImpl->getVPLegalizationStrategy(VPI);

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2315,6 +2315,11 @@ bool RISCVTTIImpl::isLegalMaskedExpandLoad(Type *DataTy, Align Alignment) {
23152315
return true;
23162316
}
23172317

2318+
bool RISCVTTIImpl::enableCSAVectorization() const {
2319+
return ST->hasVInstructions() &&
2320+
ST->getProcFamily() == RISCVSubtarget::SiFive7;
2321+
}
2322+
23182323
bool RISCVTTIImpl::isLegalMaskedCompressStore(Type *DataTy, Align Alignment) {
23192324
auto *VTy = dyn_cast<VectorType>(DataTy);
23202325
if (!VTy || VTy->isScalableTy())

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,10 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
308308
return TLI->isVScaleKnownToBeAPowerOfTwo();
309309
}
310310

311+
/// \returns true if the loop vectorizer should vectorize conditional
312+
/// scalar assignments for the target.
313+
bool enableCSAVectorization() const;
314+
311315
/// \returns How the target needs this vector-predicated operation to be
312316
/// transformed.
313317
TargetTransformInfo::VPLegalization

llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,10 @@ static cl::opt<bool> EnableHistogramVectorization(
8383
"enable-histogram-loop-vectorization", cl::init(false), cl::Hidden,
8484
cl::desc("Enables autovectorization of some loops containing histograms"));
8585

86+
static cl::opt<bool>
87+
EnableCSA("enable-csa-vectorization", cl::init(false), cl::Hidden,
88+
cl::desc("Control whether CSA loop vectorization is enabled"));
89+
8690
/// Maximum vectorization interleave count.
8791
static const unsigned MaxInterleaveFactor = 16;
8892

@@ -750,6 +754,15 @@ bool LoopVectorizationLegality::setupOuterLoopInductions() {
750754
return llvm::all_of(Header->phis(), IsSupportedPhi);
751755
}
752756

757+
void LoopVectorizationLegality::addCSAPhi(
758+
PHINode *Phi, const CSADescriptor &CSADesc,
759+
SmallPtrSetImpl<Value *> &AllowedExit) {
760+
assert(CSADesc.isValid() && "Expected Valid CSADescriptor");
761+
LLVM_DEBUG(dbgs() << "LV: found legal CSA opportunity" << *Phi << "\n");
762+
AllowedExit.insert(Phi);
763+
CSAs.insert({Phi, CSADesc});
764+
}
765+
753766
/// Checks if a function is scalarizable according to the TLI, in
754767
/// the sense that it should be vectorized and then expanded in
755768
/// multiple scalar calls. This is represented in the
@@ -867,14 +880,23 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
867880
continue;
868881
}
869882

870-
// As a last resort, coerce the PHI to a AddRec expression
871-
// and re-try classifying it a an induction PHI.
883+
// Try to coerce the PHI to a AddRec expression and re-try classifying
884+
// it a an induction PHI.
872885
if (InductionDescriptor::isInductionPHI(Phi, TheLoop, PSE, ID, true) &&
873886
!IsDisallowedStridedPointerInduction(ID)) {
874887
addInductionPhi(Phi, ID, AllowedExit);
875888
continue;
876889
}
877890

891+
// Check if the PHI can be classified as a CSA PHI.
892+
if (EnableCSA || (TTI->enableCSAVectorization() &&
893+
EnableCSA.getNumOccurrences() == 0)) {
894+
if (auto CSADesc = CSADescriptor::isCSAPhi(Phi, TheLoop)) {
895+
addCSAPhi(Phi, CSADesc, AllowedExit);
896+
continue;
897+
}
898+
}
899+
878900
reportVectorizationFailure("Found an unidentified PHI",
879901
"value that could not be identified as "
880902
"reduction is used outside the loop",
@@ -1846,11 +1868,15 @@ bool LoopVectorizationLegality::canFoldTailByMasking() const {
18461868
for (const auto &Reduction : getReductionVars())
18471869
ReductionLiveOuts.insert(Reduction.second.getLoopExitInstr());
18481870

1871+
SmallPtrSet<const Value *, 8> CSALiveOuts;
1872+
for (const auto &CSA : getCSAs())
1873+
CSALiveOuts.insert(CSA.second.getAssignment());
1874+
18491875
// TODO: handle non-reduction outside users when tail is folded by masking.
18501876
for (auto *AE : AllowedExit) {
18511877
// Check that all users of allowed exit values are inside the loop or
1852-
// are the live-out of a reduction.
1853-
if (ReductionLiveOuts.count(AE))
1878+
// are the live-out of a reduction or a CSA
1879+
if (ReductionLiveOuts.count(AE) || CSALiveOuts.count(AE))
18541880
continue;
18551881
for (User *U : AE->users()) {
18561882
Instruction *UI = cast<Instruction>(U);

0 commit comments

Comments
 (0)