Skip to content

[RISCV][GlobalISel] Legalize Scalable Vector Loads and Stores #84965

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jul 31, 2024
Merged
3 changes: 2 additions & 1 deletion llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,8 @@ LegalityPredicate LegalityPredicates::memSizeNotByteSizePow2(unsigned MMOIdx) {
return [=](const LegalityQuery &Query) {
const LLT MemTy = Query.MMODescrs[MMOIdx].MemoryTy;
return !MemTy.isByteSized() ||
!llvm::has_single_bit<uint32_t>(MemTy.getSizeInBytes());
!llvm::has_single_bit<uint32_t>(
MemTy.getSizeInBytes().getKnownMinValue());
};
}

Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/MIRParser/MIParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3388,7 +3388,7 @@ bool MIParser::parseMachineMemoryOperand(MachineMemOperand *&Dest) {
if (expectAndConsume(MIToken::rparen))
return true;

Size = MemoryType.getSizeInBytes();
Size = MemoryType.getSizeInBytes().getKnownMinValue();
}

MachinePointerInfo Ptr = MachinePointerInfo();
Expand Down
102 changes: 101 additions & 1 deletion llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/MachineConstantPool.h"
#include "llvm/CodeGen/MachineMemOperand.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/CodeGen/ValueTypes.h"
Expand Down Expand Up @@ -67,6 +68,17 @@ typeIsLegalBoolVec(unsigned TypeIdx, std::initializer_list<LLT> BoolVecTys,
return all(typeInSet(TypeIdx, BoolVecTys), P);
}

static LegalityPredicate typeIsLegalPtrVec(unsigned TypeIdx,
std::initializer_list<LLT> PtrVecTys,
const RISCVSubtarget &ST) {
LegalityPredicate P = [=, &ST](const LegalityQuery &Query) {
return ST.hasVInstructions() &&
(Query.Types[TypeIdx].getElementCount().getKnownMinValue() != 1 ||
ST.getELen() == 64);
};
return all(typeInSet(TypeIdx, PtrVecTys), P);
}

RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
: STI(ST), XLen(STI.getXLen()), sXLen(LLT::scalar(XLen)) {
const LLT sDoubleXLen = LLT::scalar(2 * XLen);
Expand Down Expand Up @@ -111,6 +123,11 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
const LLT nxv4s64 = LLT::scalable_vector(4, s64);
const LLT nxv8s64 = LLT::scalable_vector(8, s64);

const LLT nxv1p0 = LLT::scalable_vector(1, p0);
const LLT nxv2p0 = LLT::scalable_vector(2, p0);
const LLT nxv4p0 = LLT::scalable_vector(4, p0);
const LLT nxv8p0 = LLT::scalable_vector(8, p0);

using namespace TargetOpcode;

auto BoolVecTys = {nxv1s1, nxv2s1, nxv4s1, nxv8s1, nxv16s1, nxv32s1, nxv64s1};
Expand All @@ -120,6 +137,8 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
nxv32s16, nxv1s32, nxv2s32, nxv4s32, nxv8s32, nxv16s32,
nxv1s64, nxv2s64, nxv4s64, nxv8s64};

auto PtrVecTys = {nxv1p0, nxv2p0, nxv4p0, nxv8p0};

getActionDefinitionsBuilder({G_ADD, G_SUB, G_AND, G_OR, G_XOR})
.legalFor({s32, sXLen})
.legalIf(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST))
Expand Down Expand Up @@ -266,6 +285,23 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
{s32, p0, s16, 16},
{s32, p0, s32, 32},
{p0, p0, sXLen, XLen}});
if (ST.hasVInstructions())
LoadStoreActions.legalForTypesWithMemDesc({{nxv2s8, p0, nxv2s8, 8},
{nxv4s8, p0, nxv4s8, 8},
{nxv8s8, p0, nxv8s8, 8},
{nxv16s8, p0, nxv16s8, 8},
{nxv32s8, p0, nxv32s8, 8},
{nxv64s8, p0, nxv64s8, 8},
{nxv2s16, p0, nxv2s16, 16},
{nxv4s16, p0, nxv4s16, 16},
{nxv8s16, p0, nxv8s16, 16},
{nxv16s16, p0, nxv16s16, 16},
{nxv32s16, p0, nxv32s16, 16},
{nxv2s32, p0, nxv2s32, 32},
{nxv4s32, p0, nxv4s32, 32},
{nxv8s32, p0, nxv8s32, 32},
{nxv16s32, p0, nxv16s32, 32}});

auto &ExtLoadActions =
getActionDefinitionsBuilder({G_SEXTLOAD, G_ZEXTLOAD})
.legalForTypesWithMemDesc({{s32, p0, s8, 8}, {s32, p0, s16, 16}});
Expand All @@ -279,7 +315,28 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
} else if (ST.hasStdExtD()) {
LoadStoreActions.legalForTypesWithMemDesc({{s64, p0, s64, 64}});
}
LoadStoreActions.clampScalar(0, s32, sXLen).lower();
if (ST.hasVInstructions() && ST.getELen() == 64)
LoadStoreActions.legalForTypesWithMemDesc({{nxv1s8, p0, nxv1s8, 8},
{nxv1s16, p0, nxv1s16, 16},
{nxv1s32, p0, nxv1s32, 32}});

if (ST.hasVInstructionsI64())
LoadStoreActions.legalForTypesWithMemDesc({{nxv1s64, p0, nxv1s64, 64},

{nxv2s64, p0, nxv2s64, 64},
{nxv4s64, p0, nxv4s64, 64},
{nxv8s64, p0, nxv8s64, 64}});

LoadStoreActions.widenScalarToNextPow2(0, /* MinSize = */ 8)
.lowerIfMemSizeNotByteSizePow2()
// we will take the custom lowering logic if we have scalable vector types
// with non-standard alignments
.customIf(LegalityPredicate(
LegalityPredicates::any(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST),
typeIsLegalPtrVec(0, PtrVecTys, ST))))
.clampScalar(0, s32, sXLen)
.lower();

ExtLoadActions.widenScalarToNextPow2(0).clampScalar(0, s32, sXLen).lower();

getActionDefinitionsBuilder({G_PTR_ADD, G_PTRMASK}).legalFor({{p0, sXLen}});
Expand Down Expand Up @@ -651,6 +708,46 @@ bool RISCVLegalizerInfo::legalizeExt(MachineInstr &MI,
return true;
}

bool RISCVLegalizerInfo::legalizeLoadStore(MachineInstr &MI,
LegalizerHelper &Helper,
MachineIRBuilder &MIB) const {
assert((isa<GLoad>(MI) || isa<GStore>(MI)) &&
"Machine instructions must be Load/Store.");
MachineRegisterInfo &MRI = *MIB.getMRI();
MachineFunction *MF = MI.getMF();
const DataLayout &DL = MIB.getDataLayout();
LLVMContext &Ctx = MF->getFunction().getContext();

Register DstReg = MI.getOperand(0).getReg();
LLT DataTy = MRI.getType(DstReg);
if (!DataTy.isVector())
return false;

if (!MI.hasOneMemOperand())
return false;

MachineMemOperand *MMO = *MI.memoperands_begin();

const auto *TLI = STI.getTargetLowering();
EVT VT = EVT::getEVT(getTypeForLLT(DataTy, Ctx));

if (TLI->allowsMemoryAccessForAlignment(Ctx, DL, VT, *MMO))
return true;
Comment on lines +734 to +735
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should really be considered in the default lower action. I've had a patch for years I need to get back to


unsigned EltSizeBits = DataTy.getScalarSizeInBits();
assert((EltSizeBits == 16 || EltSizeBits == 32 || EltSizeBits == 64) &&
"Unexpected unaligned RVV load type");

// Calculate the new vector type with i8 elements
unsigned NumElements =
DataTy.getElementCount().getKnownMinValue() * (EltSizeBits / 8);
LLT NewDataTy = LLT::scalable_vector(NumElements, 8);

Helper.bitcast(MI, 0, NewDataTy);

return true;
}

/// Return the type of the mask type suitable for masking the provided
/// vector type. This is simply an i1 element type vector of the same
/// (possibly scalable) length.
Expand Down Expand Up @@ -828,6 +925,9 @@ bool RISCVLegalizerInfo::legalizeCustom(
return legalizeExt(MI, MIRBuilder);
case TargetOpcode::G_SPLAT_VECTOR:
return legalizeSplatVector(MI, MIRBuilder);
case TargetOpcode::G_LOAD:
case TargetOpcode::G_STORE:
return legalizeLoadStore(MI, Helper, MIRBuilder);
}

llvm_unreachable("expected switch to return");
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#ifndef LLVM_LIB_TARGET_RISCV_RISCVMACHINELEGALIZER_H
#define LLVM_LIB_TARGET_RISCV_RISCVMACHINELEGALIZER_H

#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
#include "llvm/CodeGen/Register.h"

Expand Down Expand Up @@ -45,6 +46,8 @@ class RISCVLegalizerInfo : public LegalizerInfo {
bool legalizeVScale(MachineInstr &MI, MachineIRBuilder &MIB) const;
bool legalizeExt(MachineInstr &MI, MachineIRBuilder &MIRBuilder) const;
bool legalizeSplatVector(MachineInstr &MI, MachineIRBuilder &MIB) const;
bool legalizeLoadStore(MachineInstr &MI, LegalizerHelper &Helper,
MachineIRBuilder &MIB) const;
};
} // end namespace llvm
#endif
Loading