Skip to content

Commit d5e14ba

Browse files
committed
[GlobalISel] NFC: Change LLT::vector to take ElementCount.
This also adds new interfaces for the fixed- and scalable case: * LLT::fixed_vector * LLT::scalable_vector The strategy for migrating to the new interfaces was as follows: * If the new LLT is a (modified) clone of another LLT, taking the same number of elements, then use LLT::vector(OtherTy.getElementCount()) or if the number of elements is halfed/doubled, it uses .divideCoefficientBy(2) or operator*. That is because there is no reason to specifically restrict the types to 'fixed_vector'. * If the algorithm works on the number of elements (as unsigned), then just use fixed_vector. This will need to be fixed up in the future when modifying the algorithm to also work for scalable vectors, and will need then need additional tests to confirm the behaviour works the same for scalable vectors. * If the test used the '/*Scalable=*/true` flag of LLT::vector, then this is replaced by LLT::scalable_vector. Reviewed By: aemerson Differential Revision: https://reviews.llvm.org/D104451
1 parent 9c4c2f2 commit d5e14ba

32 files changed

+419
-382
lines changed

llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1021,7 +1021,7 @@ class LegalizeRuleSet {
10211021
[=](const LegalityQuery &Query) {
10221022
LLT VecTy = Query.Types[TypeIdx];
10231023
return std::make_pair(
1024-
TypeIdx, LLT::vector(MinElements, VecTy.getElementType()));
1024+
TypeIdx, LLT::fixed_vector(MinElements, VecTy.getElementType()));
10251025
});
10261026
}
10271027
/// Limit the number of elements in EltTy vectors to at most MaxElements.

llvm/include/llvm/Support/LowLevelTypeImpl.h

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -55,30 +55,51 @@ class LLT {
5555
}
5656

5757
/// Get a low-level vector of some number of elements and element width.
58-
/// \p NumElements must be at least 2.
59-
static LLT vector(uint16_t NumElements, unsigned ScalarSizeInBits,
60-
bool Scalable = false) {
61-
assert(((!Scalable && NumElements > 1) || NumElements > 0) &&
62-
"invalid number of vector elements");
58+
static LLT vector(ElementCount EC, unsigned ScalarSizeInBits) {
59+
assert(!EC.isScalar() && "invalid number of vector elements");
6360
assert(ScalarSizeInBits > 0 && "invalid vector element size");
64-
return LLT{/*isPointer=*/false, /*isVector=*/true,
65-
ElementCount::get(NumElements, Scalable), ScalarSizeInBits,
61+
return LLT{/*isPointer=*/false, /*isVector=*/true, EC, ScalarSizeInBits,
6662
/*AddressSpace=*/0};
6763
}
6864

6965
/// Get a low-level vector of some number of elements and element type.
70-
static LLT vector(uint16_t NumElements, LLT ScalarTy, bool Scalable = false) {
71-
assert(((!Scalable && NumElements > 1) || NumElements > 0) &&
72-
"invalid number of vector elements");
66+
static LLT vector(ElementCount EC, LLT ScalarTy) {
67+
assert(!EC.isScalar() && "invalid number of vector elements");
7368
assert(!ScalarTy.isVector() && "invalid vector element type");
74-
return LLT{ScalarTy.isPointer(), /*isVector=*/true,
75-
ElementCount::get(NumElements, Scalable),
69+
return LLT{ScalarTy.isPointer(), /*isVector=*/true, EC,
7670
ScalarTy.getSizeInBits(),
7771
ScalarTy.isPointer() ? ScalarTy.getAddressSpace() : 0};
7872
}
7973

74+
/// Get a low-level fixed-width vector of some number of elements and element
75+
/// width.
76+
static LLT fixed_vector(unsigned NumElements, unsigned ScalarSizeInBits) {
77+
return vector(ElementCount::getFixed(NumElements), ScalarSizeInBits);
78+
}
79+
80+
/// Get a low-level fixed-width vector of some number of elements and element
81+
/// type.
82+
static LLT fixed_vector(unsigned NumElements, LLT ScalarTy) {
83+
return vector(ElementCount::getFixed(NumElements), ScalarTy);
84+
}
85+
86+
/// Get a low-level scalable vector of some number of elements and element
87+
/// width.
88+
static LLT scalable_vector(unsigned MinNumElements,
89+
unsigned ScalarSizeInBits) {
90+
return vector(ElementCount::getScalable(MinNumElements), ScalarSizeInBits);
91+
}
92+
93+
/// Get a low-level scalable vector of some number of elements and element
94+
/// type.
95+
static LLT scalable_vector(unsigned MinNumElements, LLT ScalarTy) {
96+
return vector(ElementCount::getScalable(MinNumElements), ScalarTy);
97+
}
98+
8099
static LLT scalarOrVector(uint16_t NumElements, LLT ScalarTy) {
81-
return NumElements == 1 ? ScalarTy : LLT::vector(NumElements, ScalarTy);
100+
// FIXME: Migrate interface to use ElementCount
101+
return NumElements == 1 ? ScalarTy
102+
: LLT::fixed_vector(NumElements, ScalarTy);
82103
}
83104

84105
static LLT scalarOrVector(uint16_t NumElements, unsigned ScalarSize) {
@@ -150,9 +171,7 @@ class LLT {
150171
/// If this type is a vector, return a vector with the same number of elements
151172
/// but the new element type. Otherwise, return the new element type.
152173
LLT changeElementType(LLT NewEltTy) const {
153-
return isVector() ? LLT::vector(getElementCount().getKnownMinValue(),
154-
NewEltTy, isScalable())
155-
: NewEltTy;
174+
return isVector() ? LLT::vector(getElementCount(), NewEltTy) : NewEltTy;
156175
}
157176

158177
/// If this type is a vector, return a vector with the same number of elements
@@ -161,8 +180,7 @@ class LLT {
161180
LLT changeElementSize(unsigned NewEltSize) const {
162181
assert(!getScalarType().isPointer() &&
163182
"invalid to directly change element size for pointers");
164-
return isVector() ? LLT::vector(getElementCount().getKnownMinValue(),
165-
NewEltSize, isScalable())
183+
return isVector() ? LLT::vector(getElementCount(), NewEltSize)
166184
: LLT::scalar(NewEltSize);
167185
}
168186

llvm/lib/CodeGen/GlobalISel/CallLowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ static void buildCopyFromRegs(MachineIRBuilder &B, ArrayRef<Register> OrigRegs,
438438
} else {
439439
// Vector was split, and elements promoted to a wider type.
440440
// FIXME: Should handle floating point promotions.
441-
LLT BVType = LLT::vector(LLTy.getNumElements(), PartLLT);
441+
LLT BVType = LLT::fixed_vector(LLTy.getNumElements(), PartLLT);
442442
auto BV = B.buildBuildVector(BVType, Regs);
443443
B.buildTrunc(OrigRegs[0], BV);
444444
}

llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1479,7 +1479,8 @@ bool IRTranslator::translateGetElementPtr(const User &U,
14791479
// are vectors.
14801480
if (VectorWidth && !PtrTy.isVector()) {
14811481
BaseReg =
1482-
MIRBuilder.buildSplatVector(LLT::vector(VectorWidth, PtrTy), BaseReg)
1482+
MIRBuilder
1483+
.buildSplatVector(LLT::fixed_vector(VectorWidth, PtrTy), BaseReg)
14831484
.getReg(0);
14841485
PtrIRTy = FixedVectorType::get(PtrIRTy, VectorWidth);
14851486
PtrTy = getLLTForType(*PtrIRTy, *DL);

llvm/lib/CodeGen/GlobalISel/LegacyLegalizerInfo.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -342,8 +342,8 @@ LegacyLegalizerInfo::findVectorLegalAction(const InstrAspect &Aspect) const {
342342
LLT IntermediateType;
343343
auto ElementSizeAndAction =
344344
findAction(ElemSizeVec, Aspect.Type.getScalarSizeInBits());
345-
IntermediateType =
346-
LLT::vector(Aspect.Type.getNumElements(), ElementSizeAndAction.first);
345+
IntermediateType = LLT::fixed_vector(Aspect.Type.getNumElements(),
346+
ElementSizeAndAction.first);
347347
if (ElementSizeAndAction.second != Legal)
348348
return {ElementSizeAndAction.second, IntermediateType};
349349

@@ -356,8 +356,8 @@ LegacyLegalizerInfo::findVectorLegalAction(const InstrAspect &Aspect) const {
356356
auto NumElementsAndAction =
357357
findAction(NumElementsVec, IntermediateType.getNumElements());
358358
return {NumElementsAndAction.second,
359-
LLT::vector(NumElementsAndAction.first,
360-
IntermediateType.getScalarSizeInBits())};
359+
LLT::fixed_vector(NumElementsAndAction.first,
360+
IntermediateType.getScalarSizeInBits())};
361361
}
362362

363363
unsigned LegacyLegalizerInfo::getOpcodeIdxForOpcode(unsigned Opcode) const {

llvm/lib/CodeGen/GlobalISel/LegalizeMutations.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ LegalizeMutation LegalizeMutations::moreElementsToNextPow2(unsigned TypeIdx,
6969
const LLT VecTy = Query.Types[TypeIdx];
7070
unsigned NewNumElements =
7171
std::max(1u << Log2_32_Ceil(VecTy.getNumElements()), Min);
72-
return std::make_pair(TypeIdx,
73-
LLT::vector(NewNumElements, VecTy.getElementType()));
72+
return std::make_pair(
73+
TypeIdx, LLT::fixed_vector(NewNumElements, VecTy.getElementType()));
7474
};
7575
}
7676

llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -798,7 +798,7 @@ LegalizerHelper::LegalizeResult LegalizerHelper::narrowScalar(MachineInstr &MI,
798798
if (SizeOp0 % NarrowSize != 0) {
799799
LLT ImplicitTy = NarrowTy;
800800
if (DstTy.isVector())
801-
ImplicitTy = LLT::vector(DstTy.getNumElements(), ImplicitTy);
801+
ImplicitTy = LLT::vector(DstTy.getElementCount(), ImplicitTy);
802802

803803
Register ImplicitReg = MIRBuilder.buildUndef(ImplicitTy).getReg(0);
804804
MIRBuilder.buildAnyExt(DstReg, ImplicitReg);
@@ -2286,9 +2286,9 @@ LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) {
22862286
LLT VecTy = MRI.getType(VecReg);
22872287
Observer.changingInstr(MI);
22882288

2289-
widenScalarSrc(MI, LLT::vector(VecTy.getNumElements(),
2290-
WideTy.getSizeInBits()),
2291-
1, TargetOpcode::G_SEXT);
2289+
widenScalarSrc(
2290+
MI, LLT::vector(VecTy.getElementCount(), WideTy.getSizeInBits()), 1,
2291+
TargetOpcode::G_SEXT);
22922292

22932293
widenScalarDst(MI, WideTy, 0);
22942294
Observer.changedInstr(MI);
@@ -2309,7 +2309,7 @@ LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) {
23092309

23102310
Register VecReg = MI.getOperand(1).getReg();
23112311
LLT VecTy = MRI.getType(VecReg);
2312-
LLT WideVecTy = LLT::vector(VecTy.getNumElements(), WideTy);
2312+
LLT WideVecTy = LLT::vector(VecTy.getElementCount(), WideTy);
23132313

23142314
widenScalarSrc(MI, WideVecTy, 1, TargetOpcode::G_ANYEXT);
23152315
widenScalarSrc(MI, WideTy, 2, TargetOpcode::G_ANYEXT);
@@ -2469,7 +2469,7 @@ LegalizerHelper::lowerBitcast(MachineInstr &MI) {
24692469
// %3:_(<2 x s8>) = G_BITCAST %2
24702470
// %4:_(<2 x s8>) = G_BITCAST %3
24712471
// %1:_(<4 x s16>) = G_CONCAT_VECTORS %3, %4
2472-
DstCastTy = LLT::vector(NumDstElt / NumSrcElt, DstEltTy);
2472+
DstCastTy = LLT::fixed_vector(NumDstElt / NumSrcElt, DstEltTy);
24732473
SrcPartTy = SrcEltTy;
24742474
} else if (NumSrcElt > NumDstElt) { // Source element type is smaller.
24752475
//
@@ -2481,7 +2481,7 @@ LegalizerHelper::lowerBitcast(MachineInstr &MI) {
24812481
// %3:_(s16) = G_BITCAST %2
24822482
// %4:_(s16) = G_BITCAST %3
24832483
// %1:_(<2 x s16>) = G_BUILD_VECTOR %3, %4
2484-
SrcPartTy = LLT::vector(NumSrcElt / NumDstElt, SrcEltTy);
2484+
SrcPartTy = LLT::fixed_vector(NumSrcElt / NumDstElt, SrcEltTy);
24852485
DstCastTy = DstEltTy;
24862486
}
24872487

@@ -3397,7 +3397,7 @@ LegalizerHelper::fewerElementsVectorCasts(MachineInstr &MI, unsigned TypeIdx,
33973397
if (NumParts * NarrowTy.getNumElements() != DstTy.getNumElements())
33983398
return UnableToLegalize;
33993399

3400-
NarrowTy1 = LLT::vector(NarrowTy.getNumElements(), SrcTy.getElementType());
3400+
NarrowTy1 = LLT::vector(NarrowTy.getElementCount(), SrcTy.getElementType());
34013401
} else {
34023402
NumParts = DstTy.getNumElements();
34033403
NarrowTy1 = SrcTy.getElementType();
@@ -3441,18 +3441,18 @@ LegalizerHelper::fewerElementsVectorCmp(MachineInstr &MI, unsigned TypeIdx,
34413441

34423442
NarrowTy0 = NarrowTy;
34433443
NumParts = NarrowTy.isVector() ? (OldElts / NewElts) : DstTy.getNumElements();
3444-
NarrowTy1 = NarrowTy.isVector() ?
3445-
LLT::vector(NarrowTy.getNumElements(), SrcTy.getScalarSizeInBits()) :
3446-
SrcTy.getElementType();
3444+
NarrowTy1 = NarrowTy.isVector() ? LLT::vector(NarrowTy.getElementCount(),
3445+
SrcTy.getScalarSizeInBits())
3446+
: SrcTy.getElementType();
34473447

34483448
} else {
34493449
unsigned NewElts = NarrowTy.isVector() ? NarrowTy.getNumElements() : 1;
34503450
unsigned OldElts = SrcTy.getNumElements();
34513451

34523452
NumParts = NarrowTy.isVector() ? (OldElts / NewElts) :
34533453
NarrowTy.getNumElements();
3454-
NarrowTy0 = LLT::vector(NarrowTy.getNumElements(),
3455-
DstTy.getScalarSizeInBits());
3454+
NarrowTy0 =
3455+
LLT::vector(NarrowTy.getElementCount(), DstTy.getScalarSizeInBits());
34563456
NarrowTy1 = NarrowTy;
34573457
}
34583458

@@ -3523,8 +3523,9 @@ LegalizerHelper::fewerElementsVectorSelect(MachineInstr &MI, unsigned TypeIdx,
35233523
if (CondTy.getNumElements() == NumParts)
35243524
NarrowTy1 = CondTy.getElementType();
35253525
else
3526-
NarrowTy1 = LLT::vector(CondTy.getNumElements() / NumParts,
3527-
CondTy.getScalarSizeInBits());
3526+
NarrowTy1 =
3527+
LLT::vector(CondTy.getElementCount().divideCoefficientBy(NumParts),
3528+
CondTy.getScalarSizeInBits());
35283529
}
35293530
} else {
35303531
NumParts = CondTy.getNumElements();

llvm/lib/CodeGen/GlobalISel/Utils.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -773,21 +773,22 @@ LLT llvm::getLCMType(LLT OrigTy, LLT TargetTy) {
773773
int GCDElts = greatestCommonDivisor(OrigTy.getNumElements(),
774774
TargetTy.getNumElements());
775775
// Prefer the original element type.
776-
int Mul = OrigTy.getNumElements() * TargetTy.getNumElements();
777-
return LLT::vector(Mul / GCDElts, OrigTy.getElementType());
776+
ElementCount Mul = OrigTy.getElementCount() * TargetTy.getNumElements();
777+
return LLT::vector(Mul.divideCoefficientBy(GCDElts),
778+
OrigTy.getElementType());
778779
}
779780
} else {
780781
if (OrigElt.getSizeInBits() == TargetSize)
781782
return OrigTy;
782783
}
783784

784785
unsigned LCMSize = getLCMSize(OrigSize, TargetSize);
785-
return LLT::vector(LCMSize / OrigElt.getSizeInBits(), OrigElt);
786+
return LLT::fixed_vector(LCMSize / OrigElt.getSizeInBits(), OrigElt);
786787
}
787788

788789
if (TargetTy.isVector()) {
789790
unsigned LCMSize = getLCMSize(OrigSize, TargetSize);
790-
return LLT::vector(LCMSize / OrigSize, OrigTy);
791+
return LLT::fixed_vector(LCMSize / OrigSize, OrigTy);
791792
}
792793

793794
unsigned LCMSize = getLCMSize(OrigSize, TargetSize);
@@ -831,7 +832,7 @@ LLT llvm::getGCDType(LLT OrigTy, LLT TargetTy) {
831832
// scalar.
832833
if (GCD < OrigElt.getSizeInBits())
833834
return LLT::scalar(GCD);
834-
return LLT::vector(GCD / OrigElt.getSizeInBits(), OrigElt);
835+
return LLT::fixed_vector(GCD / OrigElt.getSizeInBits(), OrigElt);
835836
}
836837

837838
if (TargetTy.isVector()) {

llvm/lib/CodeGen/LowLevelType.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ LLT llvm::getLLTForType(Type &Ty, const DataLayout &DL) {
2424
LLT ScalarTy = getLLTForType(*VTy->getElementType(), DL);
2525
if (EC.isScalar())
2626
return ScalarTy;
27-
return LLT::vector(EC.getKnownMinValue(), ScalarTy, EC.isScalable());
27+
return LLT::vector(EC, ScalarTy);
2828
}
2929

3030
if (auto PTy = dyn_cast<PointerType>(&Ty)) {
@@ -56,7 +56,7 @@ LLT llvm::getLLTForMVT(MVT Ty) {
5656
if (!Ty.isVector())
5757
return LLT::scalar(Ty.getSizeInBits());
5858

59-
return LLT::vector(Ty.getVectorNumElements(),
59+
return LLT::vector(Ty.getVectorElementCount(),
6060
Ty.getVectorElementType().getSizeInBits());
6161
}
6262

llvm/lib/CodeGen/MIRParser/MIParser.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1729,7 +1729,7 @@ bool MIParser::parseLowLevelType(StringRef::iterator Loc, LLT &Ty) {
17291729
return error(Loc, "expected <M x sN> or <M x pA> for vector type");
17301730
lex();
17311731

1732-
Ty = LLT::vector(NumElements, Ty);
1732+
Ty = LLT::fixed_vector(NumElements, Ty);
17331733
return false;
17341734
}
17351735

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1818,7 +1818,7 @@ bool AArch64TargetLowering::allowsMisalignedMemoryAccesses(
18181818

18191819
// Disregard v2i64. Memcpy lowering produces those and splitting
18201820
// them regresses performance on micro-benchmarks and olden/bh.
1821-
Ty == LLT::vector(2, 64);
1821+
Ty == LLT::fixed_vector(2, 64);
18221822
}
18231823
return true;
18241824
}
@@ -11756,7 +11756,7 @@ LLT AArch64TargetLowering::getOptimalMemOpLLT(
1175611756

1175711757
if (CanUseNEON && Op.isMemset() && !IsSmallMemset &&
1175811758
AlignmentIsAcceptable(MVT::v2i64, Align(16)))
11759-
return LLT::vector(2, 64);
11759+
return LLT::fixed_vector(2, 64);
1176011760
if (CanUseFP && !IsSmallMemset && AlignmentIsAcceptable(MVT::f128, Align(16)))
1176111761
return LLT::scalar(128);
1176211762
if (Op.size() >= 8 && AlignmentIsAcceptable(MVT::i64, Align(8)))

0 commit comments

Comments
 (0)