Skip to content

Commit cd768ec

Browse files
authored
[AArch64] Support scalable offsets with isLegalAddressingMode (#83255)
Allows us to indicate that an addressing mode featuring a vscale-relative immediate offset is supported.
1 parent fe13412 commit cd768ec

File tree

8 files changed

+91
-14
lines changed

8 files changed

+91
-14
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -707,11 +707,15 @@ class TargetTransformInfo {
707707
/// The type may be VoidTy, in which case only return true if the addressing
708708
/// mode is legal for a load/store of any legal type.
709709
/// If target returns true in LSRWithInstrQueries(), I may be valid.
710+
/// \param ScalableOffset represents a quantity of bytes multiplied by vscale,
711+
/// an invariant value known only at runtime. Most targets should not accept
712+
/// a scalable offset.
713+
///
710714
/// TODO: Handle pre/postinc as well.
711715
bool isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset,
712716
bool HasBaseReg, int64_t Scale,
713-
unsigned AddrSpace = 0,
714-
Instruction *I = nullptr) const;
717+
unsigned AddrSpace = 0, Instruction *I = nullptr,
718+
int64_t ScalableOffset = 0) const;
715719

716720
/// Return true if LSR cost of C1 is lower than C2.
717721
bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
@@ -1842,7 +1846,8 @@ class TargetTransformInfo::Concept {
18421846
virtual bool isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV,
18431847
int64_t BaseOffset, bool HasBaseReg,
18441848
int64_t Scale, unsigned AddrSpace,
1845-
Instruction *I) = 0;
1849+
Instruction *I,
1850+
int64_t ScalableOffset) = 0;
18461851
virtual bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
18471852
const TargetTransformInfo::LSRCost &C2) = 0;
18481853
virtual bool isNumRegsMajorCostOfLSR() = 0;
@@ -2303,9 +2308,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
23032308
}
23042309
bool isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset,
23052310
bool HasBaseReg, int64_t Scale, unsigned AddrSpace,
2306-
Instruction *I) override {
2311+
Instruction *I, int64_t ScalableOffset) override {
23072312
return Impl.isLegalAddressingMode(Ty, BaseGV, BaseOffset, HasBaseReg, Scale,
2308-
AddrSpace, I);
2313+
AddrSpace, I, ScalableOffset);
23092314
}
23102315
bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
23112316
const TargetTransformInfo::LSRCost &C2) override {

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -220,7 +220,8 @@ class TargetTransformInfoImplBase {
220220

221221
bool isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset,
222222
bool HasBaseReg, int64_t Scale, unsigned AddrSpace,
223-
Instruction *I = nullptr) const {
223+
Instruction *I = nullptr,
224+
int64_t ScalableOffset = 0) const {
224225
// Guess that only reg and reg+reg addressing is allowed. This heuristic is
225226
// taken from the implementation of LSR.
226227
return !BaseGV && BaseOffset == 0 && (Scale == 0 || Scale == 1);

llvm/include/llvm/CodeGen/BasicTTIImpl.h

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -333,13 +333,15 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
333333
}
334334

335335
bool isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset,
336-
bool HasBaseReg, int64_t Scale,
337-
unsigned AddrSpace, Instruction *I = nullptr) {
336+
bool HasBaseReg, int64_t Scale, unsigned AddrSpace,
337+
Instruction *I = nullptr,
338+
int64_t ScalableOffset = 0) {
338339
TargetLoweringBase::AddrMode AM;
339340
AM.BaseGV = BaseGV;
340341
AM.BaseOffs = BaseOffset;
341342
AM.HasBaseReg = HasBaseReg;
342343
AM.Scale = Scale;
344+
AM.ScalableOffset = ScalableOffset;
343345
return getTLI()->isLegalAddressingMode(DL, AM, Ty, AddrSpace, I);
344346
}
345347

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2722,17 +2722,19 @@ class TargetLoweringBase {
27222722
}
27232723

27242724
/// This represents an addressing mode of:
2725-
/// BaseGV + BaseOffs + BaseReg + Scale*ScaleReg
2725+
/// BaseGV + BaseOffs + BaseReg + Scale*ScaleReg + ScalableOffset*vscale
27262726
/// If BaseGV is null, there is no BaseGV.
27272727
/// If BaseOffs is zero, there is no base offset.
27282728
/// If HasBaseReg is false, there is no base register.
27292729
/// If Scale is zero, there is no ScaleReg. Scale of 1 indicates a reg with
27302730
/// no scale.
2731+
/// If ScalableOffset is zero, there is no scalable offset.
27312732
struct AddrMode {
27322733
GlobalValue *BaseGV = nullptr;
27332734
int64_t BaseOffs = 0;
27342735
bool HasBaseReg = false;
27352736
int64_t Scale = 0;
2737+
int64_t ScalableOffset = 0;
27362738
AddrMode() = default;
27372739
};
27382740

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -404,9 +404,10 @@ bool TargetTransformInfo::isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV,
404404
int64_t BaseOffset,
405405
bool HasBaseReg, int64_t Scale,
406406
unsigned AddrSpace,
407-
Instruction *I) const {
407+
Instruction *I,
408+
int64_t ScalableOffset) const {
408409
return TTIImpl->isLegalAddressingMode(Ty, BaseGV, BaseOffset, HasBaseReg,
409-
Scale, AddrSpace, I);
410+
Scale, AddrSpace, I, ScalableOffset);
410411
}
411412

412413
bool TargetTransformInfo::isLSRCostLess(const LSRCost &C1,

llvm/lib/CodeGen/TargetLoweringBase.cpp

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2011,6 +2011,10 @@ bool TargetLoweringBase::isLegalAddressingMode(const DataLayout &DL,
20112011
// The default implementation of this implements a conservative RISCy, r+r and
20122012
// r+i addr mode.
20132013

2014+
// Scalable offsets not supported
2015+
if (AM.ScalableOffset)
2016+
return false;
2017+
20142018
// Allows a sign-extended 16-bit immediate field.
20152019
if (AM.BaseOffs <= -(1LL << 16) || AM.BaseOffs >= (1LL << 16)-1)
20162020
return false;

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 16 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -16692,15 +16692,29 @@ bool AArch64TargetLowering::isLegalAddressingMode(const DataLayout &DL,
1669216692

1669316693
if (Ty->isScalableTy()) {
1669416694
if (isa<ScalableVectorType>(Ty)) {
16695+
// See if we have a foldable vscale-based offset, for vector types which
16696+
// are either legal or smaller than the minimum; more work will be
16697+
// required if we need to consider addressing for types which need
16698+
// legalization by splitting.
16699+
uint64_t VecNumBytes = DL.getTypeSizeInBits(Ty).getKnownMinValue() / 8;
16700+
if (AM.HasBaseReg && !AM.BaseOffs && AM.ScalableOffset && !AM.Scale &&
16701+
(AM.ScalableOffset % VecNumBytes == 0) && VecNumBytes <= 16 &&
16702+
isPowerOf2_64(VecNumBytes))
16703+
return isInt<4>(AM.ScalableOffset / (int64_t)VecNumBytes);
16704+
1669516705
uint64_t VecElemNumBytes =
1669616706
DL.getTypeSizeInBits(cast<VectorType>(Ty)->getElementType()) / 8;
16697-
return AM.HasBaseReg && !AM.BaseOffs &&
16707+
return AM.HasBaseReg && !AM.BaseOffs && !AM.ScalableOffset &&
1669816708
(AM.Scale == 0 || (uint64_t)AM.Scale == VecElemNumBytes);
1669916709
}
1670016710

16701-
return AM.HasBaseReg && !AM.BaseOffs && !AM.Scale;
16711+
return AM.HasBaseReg && !AM.BaseOffs && !AM.ScalableOffset && !AM.Scale;
1670216712
}
1670316713

16714+
// No scalable offsets allowed for non-scalable types.
16715+
if (AM.ScalableOffset)
16716+
return false;
16717+
1670416718
// check reg + imm case:
1670516719
// i.e., reg + 0, reg + imm9, reg + SIZE_IN_BYTES * uimm12
1670616720
uint64_t NumBytes = 0;

llvm/unittests/Target/AArch64/AddressingModes.cpp

Lines changed: 49 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,11 +13,13 @@ using namespace llvm;
1313
namespace {
1414

1515
struct AddrMode : public TargetLowering::AddrMode {
16-
constexpr AddrMode(GlobalValue *GV, int64_t Offs, bool HasBase, int64_t S) {
16+
constexpr AddrMode(GlobalValue *GV, int64_t Offs, bool HasBase, int64_t S,
17+
int64_t SOffs = 0) {
1718
BaseGV = GV;
1819
BaseOffs = Offs;
1920
HasBaseReg = HasBase;
2021
Scale = S;
22+
ScalableOffset = SOffs;
2123
}
2224
};
2325
struct TestCase {
@@ -153,6 +155,45 @@ const std::initializer_list<TestCase> Tests = {
153155
{{nullptr, 4096 + 1, true, 0}, 8, false},
154156

155157
};
158+
159+
struct SVETestCase {
160+
AddrMode AM;
161+
unsigned TypeBits;
162+
unsigned NumElts;
163+
bool Result;
164+
};
165+
166+
const std::initializer_list<SVETestCase> SVETests = {
167+
// {BaseGV, BaseOffs, HasBaseReg, Scale, SOffs}, EltBits, Count, Result
168+
// Test immediate range -- [-8,7] vector's worth.
169+
// <vscale x 16 x i8>, increment by one vector
170+
{{nullptr, 0, true, 0, 16}, 8, 16, true},
171+
// <vscale x 4 x i32>, increment by eight vectors
172+
{{nullptr, 0, true, 0, 128}, 32, 4, false},
173+
// <vscale x 8 x i16>, increment by seven vectors
174+
{{nullptr, 0, true, 0, 112}, 16, 8, true},
175+
// <vscale x 2 x i64>, decrement by eight vectors
176+
{{nullptr, 0, true, 0, -128}, 64, 2, true},
177+
// <vscale x 16 x i8>, decrement by nine vectors
178+
{{nullptr, 0, true, 0, -144}, 8, 16, false},
179+
180+
// Half the size of a vector register, but allowable with extending
181+
// loads and truncating stores
182+
// <vscale x 8 x i8>, increment by three vectors
183+
{{nullptr, 0, true, 0, 24}, 8, 8, true},
184+
185+
// Test invalid types or offsets
186+
// <vscale x 5 x i32>, increment by one vector (base size > 16B)
187+
{{nullptr, 0, true, 0, 20}, 32, 5, false},
188+
// <vscale x 8 x i16>, increment by half a vector
189+
{{nullptr, 0, true, 0, 8}, 16, 8, false},
190+
// <vscale x 3 x i8>, increment by 3 vectors (non-power-of-two)
191+
{{nullptr, 0, true, 0, 9}, 8, 3, false},
192+
193+
// Scalable and fixed offsets
194+
// <vscale x 16 x i8>, increment by 32 then decrement by vscale x 16
195+
{{nullptr, 32, true, 0, -16}, 8, 16, false},
196+
};
156197
} // namespace
157198

158199
TEST(AddressingModes, AddressingModes) {
@@ -179,4 +220,11 @@ TEST(AddressingModes, AddressingModes) {
179220
Type *Typ = Type::getIntNTy(Ctx, Test.TypeBits);
180221
ASSERT_EQ(TLI->isLegalAddressingMode(DL, Test.AM, Typ, 0), Test.Result);
181222
}
223+
224+
for (const auto &SVETest : SVETests) {
225+
Type *Ty = VectorType::get(Type::getIntNTy(Ctx, SVETest.TypeBits),
226+
ElementCount::getScalable(SVETest.NumElts));
227+
ASSERT_EQ(TLI->isLegalAddressingMode(DL, SVETest.AM, Ty, 0),
228+
SVETest.Result);
229+
}
182230
}

0 commit comments

Comments
 (0)