Skip to content

Commit a0d8fa5

Browse files
authored
[RISCV][GlobalISel] Legalize Scalable Vector Loads and Stores (llvm#84965)
This patch supports legalizing load and store instruction for scalable vectors in RISCV
1 parent ee0f43a commit a0d8fa5

File tree

6 files changed

+2193
-3
lines changed

6 files changed

+2193
-3
lines changed

llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,8 @@ LegalityPredicate LegalityPredicates::memSizeNotByteSizePow2(unsigned MMOIdx) {
194194
return [=](const LegalityQuery &Query) {
195195
const LLT MemTy = Query.MMODescrs[MMOIdx].MemoryTy;
196196
return !MemTy.isByteSized() ||
197-
!llvm::has_single_bit<uint32_t>(MemTy.getSizeInBytes());
197+
!llvm::has_single_bit<uint32_t>(
198+
MemTy.getSizeInBytes().getKnownMinValue());
198199
};
199200
}
200201

llvm/lib/CodeGen/MIRParser/MIParser.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3388,7 +3388,7 @@ bool MIParser::parseMachineMemoryOperand(MachineMemOperand *&Dest) {
33883388
if (expectAndConsume(MIToken::rparen))
33893389
return true;
33903390

3391-
Size = MemoryType.getSizeInBytes();
3391+
Size = MemoryType.getSizeInBytes().getKnownMinValue();
33923392
}
33933393

33943394
MachinePointerInfo Ptr = MachinePointerInfo();

llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
2020
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
2121
#include "llvm/CodeGen/MachineConstantPool.h"
22+
#include "llvm/CodeGen/MachineMemOperand.h"
2223
#include "llvm/CodeGen/MachineRegisterInfo.h"
2324
#include "llvm/CodeGen/TargetOpcodes.h"
2425
#include "llvm/CodeGen/ValueTypes.h"
@@ -67,6 +68,17 @@ typeIsLegalBoolVec(unsigned TypeIdx, std::initializer_list<LLT> BoolVecTys,
6768
return all(typeInSet(TypeIdx, BoolVecTys), P);
6869
}
6970

71+
static LegalityPredicate typeIsLegalPtrVec(unsigned TypeIdx,
72+
std::initializer_list<LLT> PtrVecTys,
73+
const RISCVSubtarget &ST) {
74+
LegalityPredicate P = [=, &ST](const LegalityQuery &Query) {
75+
return ST.hasVInstructions() &&
76+
(Query.Types[TypeIdx].getElementCount().getKnownMinValue() != 1 ||
77+
ST.getELen() == 64);
78+
};
79+
return all(typeInSet(TypeIdx, PtrVecTys), P);
80+
}
81+
7082
RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
7183
: STI(ST), XLen(STI.getXLen()), sXLen(LLT::scalar(XLen)) {
7284
const LLT sDoubleXLen = LLT::scalar(2 * XLen);
@@ -111,6 +123,11 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
111123
const LLT nxv4s64 = LLT::scalable_vector(4, s64);
112124
const LLT nxv8s64 = LLT::scalable_vector(8, s64);
113125

126+
const LLT nxv1p0 = LLT::scalable_vector(1, p0);
127+
const LLT nxv2p0 = LLT::scalable_vector(2, p0);
128+
const LLT nxv4p0 = LLT::scalable_vector(4, p0);
129+
const LLT nxv8p0 = LLT::scalable_vector(8, p0);
130+
114131
using namespace TargetOpcode;
115132

116133
auto BoolVecTys = {nxv1s1, nxv2s1, nxv4s1, nxv8s1, nxv16s1, nxv32s1, nxv64s1};
@@ -120,6 +137,8 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
120137
nxv32s16, nxv1s32, nxv2s32, nxv4s32, nxv8s32, nxv16s32,
121138
nxv1s64, nxv2s64, nxv4s64, nxv8s64};
122139

140+
auto PtrVecTys = {nxv1p0, nxv2p0, nxv4p0, nxv8p0};
141+
123142
getActionDefinitionsBuilder({G_ADD, G_SUB, G_AND, G_OR, G_XOR})
124143
.legalFor({s32, sXLen})
125144
.legalIf(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST))
@@ -266,6 +285,23 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
266285
{s32, p0, s16, 16},
267286
{s32, p0, s32, 32},
268287
{p0, p0, sXLen, XLen}});
288+
if (ST.hasVInstructions())
289+
LoadStoreActions.legalForTypesWithMemDesc({{nxv2s8, p0, nxv2s8, 8},
290+
{nxv4s8, p0, nxv4s8, 8},
291+
{nxv8s8, p0, nxv8s8, 8},
292+
{nxv16s8, p0, nxv16s8, 8},
293+
{nxv32s8, p0, nxv32s8, 8},
294+
{nxv64s8, p0, nxv64s8, 8},
295+
{nxv2s16, p0, nxv2s16, 16},
296+
{nxv4s16, p0, nxv4s16, 16},
297+
{nxv8s16, p0, nxv8s16, 16},
298+
{nxv16s16, p0, nxv16s16, 16},
299+
{nxv32s16, p0, nxv32s16, 16},
300+
{nxv2s32, p0, nxv2s32, 32},
301+
{nxv4s32, p0, nxv4s32, 32},
302+
{nxv8s32, p0, nxv8s32, 32},
303+
{nxv16s32, p0, nxv16s32, 32}});
304+
269305
auto &ExtLoadActions =
270306
getActionDefinitionsBuilder({G_SEXTLOAD, G_ZEXTLOAD})
271307
.legalForTypesWithMemDesc({{s32, p0, s8, 8}, {s32, p0, s16, 16}});
@@ -279,7 +315,28 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
279315
} else if (ST.hasStdExtD()) {
280316
LoadStoreActions.legalForTypesWithMemDesc({{s64, p0, s64, 64}});
281317
}
282-
LoadStoreActions.clampScalar(0, s32, sXLen).lower();
318+
if (ST.hasVInstructions() && ST.getELen() == 64)
319+
LoadStoreActions.legalForTypesWithMemDesc({{nxv1s8, p0, nxv1s8, 8},
320+
{nxv1s16, p0, nxv1s16, 16},
321+
{nxv1s32, p0, nxv1s32, 32}});
322+
323+
if (ST.hasVInstructionsI64())
324+
LoadStoreActions.legalForTypesWithMemDesc({{nxv1s64, p0, nxv1s64, 64},
325+
326+
{nxv2s64, p0, nxv2s64, 64},
327+
{nxv4s64, p0, nxv4s64, 64},
328+
{nxv8s64, p0, nxv8s64, 64}});
329+
330+
LoadStoreActions.widenScalarToNextPow2(0, /* MinSize = */ 8)
331+
.lowerIfMemSizeNotByteSizePow2()
332+
// we will take the custom lowering logic if we have scalable vector types
333+
// with non-standard alignments
334+
.customIf(LegalityPredicate(
335+
LegalityPredicates::any(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST),
336+
typeIsLegalPtrVec(0, PtrVecTys, ST))))
337+
.clampScalar(0, s32, sXLen)
338+
.lower();
339+
283340
ExtLoadActions.widenScalarToNextPow2(0).clampScalar(0, s32, sXLen).lower();
284341

285342
getActionDefinitionsBuilder({G_PTR_ADD, G_PTRMASK}).legalFor({{p0, sXLen}});
@@ -651,6 +708,46 @@ bool RISCVLegalizerInfo::legalizeExt(MachineInstr &MI,
651708
return true;
652709
}
653710

711+
bool RISCVLegalizerInfo::legalizeLoadStore(MachineInstr &MI,
712+
LegalizerHelper &Helper,
713+
MachineIRBuilder &MIB) const {
714+
assert((isa<GLoad>(MI) || isa<GStore>(MI)) &&
715+
"Machine instructions must be Load/Store.");
716+
MachineRegisterInfo &MRI = *MIB.getMRI();
717+
MachineFunction *MF = MI.getMF();
718+
const DataLayout &DL = MIB.getDataLayout();
719+
LLVMContext &Ctx = MF->getFunction().getContext();
720+
721+
Register DstReg = MI.getOperand(0).getReg();
722+
LLT DataTy = MRI.getType(DstReg);
723+
if (!DataTy.isVector())
724+
return false;
725+
726+
if (!MI.hasOneMemOperand())
727+
return false;
728+
729+
MachineMemOperand *MMO = *MI.memoperands_begin();
730+
731+
const auto *TLI = STI.getTargetLowering();
732+
EVT VT = EVT::getEVT(getTypeForLLT(DataTy, Ctx));
733+
734+
if (TLI->allowsMemoryAccessForAlignment(Ctx, DL, VT, *MMO))
735+
return true;
736+
737+
unsigned EltSizeBits = DataTy.getScalarSizeInBits();
738+
assert((EltSizeBits == 16 || EltSizeBits == 32 || EltSizeBits == 64) &&
739+
"Unexpected unaligned RVV load type");
740+
741+
// Calculate the new vector type with i8 elements
742+
unsigned NumElements =
743+
DataTy.getElementCount().getKnownMinValue() * (EltSizeBits / 8);
744+
LLT NewDataTy = LLT::scalable_vector(NumElements, 8);
745+
746+
Helper.bitcast(MI, 0, NewDataTy);
747+
748+
return true;
749+
}
750+
654751
/// Return the type of the mask type suitable for masking the provided
655752
/// vector type. This is simply an i1 element type vector of the same
656753
/// (possibly scalable) length.
@@ -828,6 +925,9 @@ bool RISCVLegalizerInfo::legalizeCustom(
828925
return legalizeExt(MI, MIRBuilder);
829926
case TargetOpcode::G_SPLAT_VECTOR:
830927
return legalizeSplatVector(MI, MIRBuilder);
928+
case TargetOpcode::G_LOAD:
929+
case TargetOpcode::G_STORE:
930+
return legalizeLoadStore(MI, Helper, MIRBuilder);
831931
}
832932

833933
llvm_unreachable("expected switch to return");

llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef LLVM_LIB_TARGET_RISCV_RISCVMACHINELEGALIZER_H
1414
#define LLVM_LIB_TARGET_RISCV_RISCVMACHINELEGALIZER_H
1515

16+
#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
1617
#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
1718
#include "llvm/CodeGen/Register.h"
1819

@@ -45,6 +46,8 @@ class RISCVLegalizerInfo : public LegalizerInfo {
4546
bool legalizeVScale(MachineInstr &MI, MachineIRBuilder &MIB) const;
4647
bool legalizeExt(MachineInstr &MI, MachineIRBuilder &MIRBuilder) const;
4748
bool legalizeSplatVector(MachineInstr &MI, MachineIRBuilder &MIB) const;
49+
bool legalizeLoadStore(MachineInstr &MI, LegalizerHelper &Helper,
50+
MachineIRBuilder &MIB) const;
4851
};
4952
} // end namespace llvm
5053
#endif

0 commit comments

Comments
 (0)