[AArch64] Support scalable offsets with isLegalAddressingMode #83255

Merged
15 changes: 10 additions & 5 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -707,11 +707,15 @@ class TargetTransformInfo {
   /// The type may be VoidTy, in which case only return true if the addressing
   /// mode is legal for a load/store of any legal type.
   /// If target returns true in LSRWithInstrQueries(), I may be valid.
+  /// \param ScalableOffset represents a quantity of bytes multiplied by vscale,
+  /// an invariant value known only at runtime. Most targets should not accept
+  /// a scalable offset.
+  ///
   /// TODO: Handle pre/postinc as well.
   bool isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset,
                              bool HasBaseReg, int64_t Scale,
-                             unsigned AddrSpace = 0,
-                             Instruction *I = nullptr) const;
+                             unsigned AddrSpace = 0, Instruction *I = nullptr,
+                             int64_t ScalableOffset = 0) const;

   /// Return true if LSR cost of C1 is lower than C2.
   bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
@@ -1839,7 +1843,8 @@ class TargetTransformInfo::Concept {
   virtual bool isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV,
                                      int64_t BaseOffset, bool HasBaseReg,
                                      int64_t Scale, unsigned AddrSpace,
-                                     Instruction *I) = 0;
+                                     Instruction *I,
+                                     int64_t ScalableOffset) = 0;
   virtual bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
                              const TargetTransformInfo::LSRCost &C2) = 0;
   virtual bool isNumRegsMajorCostOfLSR() = 0;
@@ -2300,9 +2305,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   }
   bool isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset,
                              bool HasBaseReg, int64_t Scale, unsigned AddrSpace,
-                             Instruction *I) override {
+                             Instruction *I, int64_t ScalableOffset) override {
     return Impl.isLegalAddressingMode(Ty, BaseGV, BaseOffset, HasBaseReg, Scale,
-                                      AddrSpace, I);
+                                      AddrSpace, I, ScalableOffset);
   }
   bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
                      const TargetTransformInfo::LSRCost &C2) override {
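Since the new trailing `ScalableOffset` parameter defaults to 0, existing callers of `TargetTransformInfo::isLegalAddressingMode` compile unchanged; only a caller that wants to fold a byte offset scaled by vscale needs to pass it. Below is a minimal sketch of such a query; the helper name and the concrete offset of one 16-byte SVE vector are illustrative, not part of this patch.

```cpp
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Type.h"

using namespace llvm;

// Hypothetical helper (not in this patch): ask whether an address of the form
// `base + 16 * vscale` bytes is legal for a load/store of AccessTy.
static bool canFoldOneVectorIncrement(const TargetTransformInfo &TTI,
                                      Type *AccessTy) {
  return TTI.isLegalAddressingMode(AccessTy, /*BaseGV=*/nullptr,
                                   /*BaseOffset=*/0, /*HasBaseReg=*/true,
                                   /*Scale=*/0, /*AddrSpace=*/0,
                                   /*I=*/nullptr, /*ScalableOffset=*/16);
}
```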
3 changes: 2 additions & 1 deletion llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -220,7 +220,8 @@ class TargetTransformInfoImplBase {

   bool isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset,
                              bool HasBaseReg, int64_t Scale, unsigned AddrSpace,
-                             Instruction *I = nullptr) const {
+                             Instruction *I = nullptr,
+                             int64_t ScalableOffset = 0) const {
     // Guess that only reg and reg+reg addressing is allowed. This heuristic is
     // taken from the implementation of LSR.
     return !BaseGV && BaseOffset == 0 && (Scale == 0 || Scale == 1);
6 changes: 4 additions & 2 deletions llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -333,13 +333,15 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
   }

   bool isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset,
-                             bool HasBaseReg, int64_t Scale,
-                             unsigned AddrSpace, Instruction *I = nullptr) {
+                             bool HasBaseReg, int64_t Scale, unsigned AddrSpace,
+                             Instruction *I = nullptr,
+                             int64_t ScalableOffset = 0) {
     TargetLoweringBase::AddrMode AM;
     AM.BaseGV = BaseGV;
     AM.BaseOffs = BaseOffset;
     AM.HasBaseReg = HasBaseReg;
     AM.Scale = Scale;
+    AM.ScalableOffset = ScalableOffset;
     return getTLI()->isLegalAddressingMode(DL, AM, Ty, AddrSpace, I);
   }
4 changes: 3 additions & 1 deletion llvm/include/llvm/CodeGen/TargetLowering.h
@@ -2722,17 +2722,19 @@ class TargetLoweringBase {
   }

   /// This represents an addressing mode of:
-  /// BaseGV + BaseOffs + BaseReg + Scale*ScaleReg
+  /// BaseGV + BaseOffs + BaseReg + Scale*ScaleReg + ScalableOffset*vscale
   /// If BaseGV is null, there is no BaseGV.
   /// If BaseOffs is zero, there is no base offset.
   /// If HasBaseReg is false, there is no base register.
   /// If Scale is zero, there is no ScaleReg. Scale of 1 indicates a reg with
   /// no scale.
+  /// If ScalableOffset is zero, there is no scalable offset.
   struct AddrMode {
     GlobalValue *BaseGV = nullptr;
     int64_t BaseOffs = 0;
     bool HasBaseReg = false;
     int64_t Scale = 0;
+    int64_t ScalableOffset = 0;
     AddrMode() = default;
   };
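To make the extended formula concrete: an SVE access two whole vectors past a base pointer, i.e. `base + 2 * (16 * vscale)` bytes for a `<vscale x 16 x i8>` value, decomposes into `HasBaseReg = true` and `ScalableOffset = 32` with every other field at its default. A small sketch follows; the helper name is illustrative only.

```cpp
#include "llvm/CodeGen/TargetLowering.h"

using namespace llvm;

// Sketch: the AddrMode decomposition of `base + 32 * vscale` bytes, i.e. two
// <vscale x 16 x i8> vectors past the base register.
static TargetLoweringBase::AddrMode makeTwoVectorOffsetAM() {
  TargetLoweringBase::AddrMode AM;
  AM.HasBaseReg = true;   // a base register is present
  AM.ScalableOffset = 32; // 32 bytes * vscale == 2 * (16 bytes * vscale)
  // BaseGV, BaseOffs and Scale keep their defaults: no global, no fixed
  // offset, no scaled index register.
  return AM;
}
```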
5 changes: 3 additions & 2 deletions llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -403,9 +403,10 @@ bool TargetTransformInfo::isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV,
                                                 int64_t BaseOffset,
                                                 bool HasBaseReg, int64_t Scale,
                                                 unsigned AddrSpace,
-                                                Instruction *I) const {
+                                                Instruction *I,
+                                                int64_t ScalableOffset) const {
   return TTIImpl->isLegalAddressingMode(Ty, BaseGV, BaseOffset, HasBaseReg,
-                                        Scale, AddrSpace, I);
+                                        Scale, AddrSpace, I, ScalableOffset);
 }

 bool TargetTransformInfo::isLSRCostLess(const LSRCost &C1,
4 changes: 4 additions & 0 deletions llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -2011,6 +2011,10 @@ bool TargetLoweringBase::isLegalAddressingMode(const DataLayout &DL,
   // The default implementation of this implements a conservative RISCy, r+r and
   // r+i addr mode.

+  // Scalable offsets not supported
+  if (AM.ScalableOffset)
+    return false;
+
   // Allows a sign-extended 16-bit immediate field.
   if (AM.BaseOffs <= -(1LL << 16) || AM.BaseOffs >= (1LL << 16)-1)
     return false;
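Because the default hook now rejects any non-zero `ScalableOffset`, targets that do not override `isLegalAddressingMode` are unaffected: no scalable addressing mode will be reported as legal for them. The fragment below is a hedged sketch of how another target could opt in later; `MyTargetLowering` is a hypothetical class and is not part of this patch, but the shape mirrors what AArch64 does in the next file.

```cpp
// Hypothetical target override (fragment): handle the vscale term first, then
// defer to the conservative base-class rules for everything else.
bool MyTargetLowering::isLegalAddressingMode(const DataLayout &DL,
                                             const AddrMode &AM, Type *Ty,
                                             unsigned AddrSpace,
                                             Instruction *I) const {
  if (AM.ScalableOffset)
    return false; // or target-specific folding rules, as AArch64 adds below
  return TargetLoweringBase::isLegalAddressingMode(DL, AM, Ty, AddrSpace, I);
}
```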
18 changes: 16 additions & 2 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -16671,15 +16671,29 @@ bool AArch64TargetLowering::isLegalAddressingMode(const DataLayout &DL,

   if (Ty->isScalableTy()) {
     if (isa<ScalableVectorType>(Ty)) {
+      // See if we have a foldable vscale-based offset, for vector types which
+      // are either legal or smaller than the minimum; more work will be
+      // required if we need to consider addressing for types which need
+      // legalization by splitting.
+      uint64_t VecNumBytes = DL.getTypeSizeInBits(Ty).getKnownMinValue() / 8;
+      if (AM.HasBaseReg && !AM.BaseOffs && AM.ScalableOffset && !AM.Scale &&
+          (AM.ScalableOffset % VecNumBytes == 0) && VecNumBytes <= 16 &&
+          isPowerOf2_64(VecNumBytes))
+        return isInt<4>(AM.ScalableOffset / (int64_t)VecNumBytes);
+
       uint64_t VecElemNumBytes =
           DL.getTypeSizeInBits(cast<VectorType>(Ty)->getElementType()) / 8;
-      return AM.HasBaseReg && !AM.BaseOffs &&
+      return AM.HasBaseReg && !AM.BaseOffs && !AM.ScalableOffset &&
             (AM.Scale == 0 || (uint64_t)AM.Scale == VecElemNumBytes);
     }

-    return AM.HasBaseReg && !AM.BaseOffs && !AM.Scale;
+    return AM.HasBaseReg && !AM.BaseOffs && !AM.ScalableOffset && !AM.Scale;
   }

+  // No scalable offsets allowed for non-scalable types.
+  if (AM.ScalableOffset)
+    return false;
+
   // check reg + imm case:
   // i.e., reg + 0, reg + imm9, reg + SIZE_IN_BYTES * uimm12
   uint64_t NumBytes = 0;
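In words, the new AArch64 check accepts a vscale-scaled offset only when the access type is a scalable vector whose minimum size is a power of two of at most 16 bytes, the offset is a whole number of such vectors, and that vector count fits a signed 4-bit immediate, matching the [-8, 7] `#imm, mul vl` range of SVE contiguous loads and stores. The sketch below is an intuition-only restatement of just the offset/size part of the predicate (the HasBaseReg, no-BaseOffs, no-Scale requirements are omitted), not the code LLVM runs, with worked values mirroring the new unit tests.

```cpp
#include <cstdint>

// Intuition-only restatement of the offset check added to
// AArch64TargetLowering::isLegalAddressingMode.
static bool isLegalSVEScalableOffset(uint64_t VecNumBytes,
                                     int64_t ScalableOffset) {
  // Vector must be at most 16 bytes (the 128-bit minimum SVE register size)
  // and a power of two; the offset must be a whole number of vectors.
  if (VecNumBytes == 0 || VecNumBytes > 16 ||
      (VecNumBytes & (VecNumBytes - 1)) != 0 ||
      ScalableOffset % (int64_t)VecNumBytes != 0)
    return false;
  // The vector count must fit the signed 4-bit immediate of `#imm, mul vl`.
  int64_t NumVecs = ScalableOffset / (int64_t)VecNumBytes;
  return NumVecs >= -8 && NumVecs <= 7;
}

// Worked values, matching rows of the new AddressingModes unit test:
//   <vscale x 16 x i8>, ScalableOffset   16 ->   16/16 =  1 -> legal
//   <vscale x 4 x i32>, ScalableOffset  128 ->  128/16 =  8 -> illegal (> 7)
//   <vscale x 2 x i64>, ScalableOffset -128 -> -128/16 = -8 -> legal
```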
50 changes: 49 additions & 1 deletion llvm/unittests/Target/AArch64/AddressingModes.cpp
@@ -13,11 +13,13 @@ using namespace llvm;
 namespace {

 struct AddrMode : public TargetLowering::AddrMode {
-  constexpr AddrMode(GlobalValue *GV, int64_t Offs, bool HasBase, int64_t S) {
+  constexpr AddrMode(GlobalValue *GV, int64_t Offs, bool HasBase, int64_t S,
+                     int64_t SOffs = 0) {
     BaseGV = GV;
     BaseOffs = Offs;
     HasBaseReg = HasBase;
     Scale = S;
+    ScalableOffset = SOffs;
   }
 };
 struct TestCase {
@@ -153,6 +155,45 @@ const std::initializer_list<TestCase> Tests = {
     {{nullptr, 4096 + 1, true, 0}, 8, false},

 };
+
+struct SVETestCase {
+  AddrMode AM;
+  unsigned TypeBits;
+  unsigned NumElts;
+  bool Result;
+};
+
+const std::initializer_list<SVETestCase> SVETests = {
+    // {BaseGV, BaseOffs, HasBaseReg, Scale, SOffs}, EltBits, Count, Result
+    // Test immediate range -- [-8,7] vector's worth.
+    // <vscale x 16 x i8>, increment by one vector
+    {{nullptr, 0, true, 0, 16}, 8, 16, true},
+    // <vscale x 4 x i32>, increment by eight vectors
+    {{nullptr, 0, true, 0, 128}, 32, 4, false},
+    // <vscale x 8 x i16>, increment by seven vectors
+    {{nullptr, 0, true, 0, 112}, 16, 8, true},
+    // <vscale x 2 x i64>, decrement by eight vectors
+    {{nullptr, 0, true, 0, -128}, 64, 2, true},
+    // <vscale x 16 x i8>, decrement by nine vectors
+    {{nullptr, 0, true, 0, -144}, 8, 16, false},
+
+    // Half the size of a vector register, but allowable with extending
+    // loads and truncating stores
+    // <vscale x 8 x i8>, increment by three vectors
+    {{nullptr, 0, true, 0, 24}, 8, 8, true},
+
+    // Test invalid types or offsets
+    // <vscale x 5 x i32>, increment by one vector (base size > 16B)
+    {{nullptr, 0, true, 0, 20}, 32, 5, false},
+    // <vscale x 8 x i16>, increment by half a vector
+    {{nullptr, 0, true, 0, 8}, 16, 8, false},
+    // <vscale x 3 x i8>, increment by 3 vectors (non-power-of-two)
+    {{nullptr, 0, true, 0, 9}, 8, 3, false},
+
+    // Scalable and fixed offsets
+    // <vscale x 16 x i8>, increment by 32 then decrement by vscale x 16
+    {{nullptr, 32, true, 0, -16}, 8, 16, false},
+};
 } // namespace

 TEST(AddressingModes, AddressingModes) {
@@ -179,4 +220,11 @@ TEST(AddressingModes, AddressingModes) {
     Type *Typ = Type::getIntNTy(Ctx, Test.TypeBits);
     ASSERT_EQ(TLI->isLegalAddressingMode(DL, Test.AM, Typ, 0), Test.Result);
   }
+
+  for (const auto &SVETest : SVETests) {
+    Type *Ty = VectorType::get(Type::getIntNTy(Ctx, SVETest.TypeBits),
+                               ElementCount::getScalable(SVETest.NumElts));
+    ASSERT_EQ(TLI->isLegalAddressingMode(DL, SVETest.AM, Ty, 0),
+              SVETest.Result);
+  }
 }