Skip to content

Commit 490594c

Browse files
[RISCV][GISEL] Legalize G_INSERT_SUBVECTOR
This code is heavily based on the SelectionDAG lowerINSERT_SUBVECTOR code.
1 parent 2f09c72 commit 490594c

File tree

7 files changed

+814
-5
lines changed

7 files changed

+814
-5
lines changed

llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -811,6 +811,18 @@ class GExtractSubvector : public GenericMachineInstr {
811811
}
812812
};
813813

814+
/// Represents a insert subvector.
815+
class GInsertSubvector : public GenericMachineInstr {
816+
public:
817+
Register getBigVec() const { return getOperand(1).getReg(); }
818+
Register getSubVec() const { return getOperand(2).getReg(); }
819+
uint64_t getIndexImm() const { return getOperand(3).getImm(); }
820+
821+
static bool classof(const MachineInstr *MI) {
822+
return MI->getOpcode() == TargetOpcode::G_INSERT_SUBVECTOR;
823+
}
824+
};
825+
814826
/// Represents a freeze.
815827
class GFreeze : public GenericMachineInstr {
816828
public:

llvm/include/llvm/CodeGen/GlobalISel/LegalizerHelper.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,8 @@ class LegalizerHelper {
380380
LLT CastTy);
381381
LegalizeResult bitcastExtractSubvector(MachineInstr &MI, unsigned TypeIdx,
382382
LLT CastTy);
383+
LegalizeResult bitcastInsertSubvector(MachineInstr &MI, unsigned TypeIdx,
384+
LLT CastTy);
383385

384386
LegalizeResult lowerConstant(MachineInstr &MI);
385387
LegalizeResult lowerFConstant(MachineInstr &MI);

llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3725,6 +3725,77 @@ LegalizerHelper::bitcastExtractSubvector(MachineInstr &MI, unsigned TypeIdx,
37253725
return Legalized;
37263726
}
37273727

3728+
/// This attempts to bitcast G_INSERT_SUBVECTOR to CastTy.
3729+
///
3730+
/// <vscale x 16 x i1> = G_INSERT_SUBVECTOR <vscale x 16 x i1>,
3731+
/// <vscale x 8 x i1>,
3732+
/// N
3733+
///
3734+
/// ===>
3735+
///
3736+
/// <vscale x 2 x i8> = G_BITCAST <vscale x 16 x i1>
3737+
/// <vscale x 1 x i8> = G_BITCAST <vscale x 8 x i1>
3738+
/// <vscale x 2 x i8> = G_INSERT_SUBVECTOR <vscale x 2 x i8>,
3739+
/// <vscale x 1 x i8>, N / 8
3740+
/// <vscale x 16 x i1> = G_BITCAST <vscale x 2 x i8>
3741+
LegalizerHelper::LegalizeResult
3742+
LegalizerHelper::bitcastInsertSubvector(MachineInstr &MI, unsigned TypeIdx,
3743+
LLT CastTy) {
3744+
auto ES = cast<GInsertSubvector>(&MI);
3745+
3746+
if (!CastTy.isVector())
3747+
return UnableToLegalize;
3748+
3749+
if (TypeIdx != 0)
3750+
return UnableToLegalize;
3751+
3752+
Register Dst = ES->getReg(0);
3753+
Register BigVec = ES->getBigVec();
3754+
Register SubVec = ES->getSubVec();
3755+
uint64_t Idx = ES->getIndexImm();
3756+
3757+
MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
3758+
3759+
LLT DstTy = MRI.getType(Dst);
3760+
LLT BigVecTy = MRI.getType(BigVec);
3761+
LLT SubVecTy = MRI.getType(SubVec);
3762+
3763+
if (DstTy == CastTy)
3764+
return Legalized;
3765+
3766+
if (DstTy.getSizeInBits() != CastTy.getSizeInBits())
3767+
return UnableToLegalize;
3768+
3769+
ElementCount DstTyEC = DstTy.getElementCount();
3770+
ElementCount BigVecTyEC = BigVecTy.getElementCount();
3771+
ElementCount SubVecTyEC = SubVecTy.getElementCount();
3772+
auto DstTyMinElts = DstTyEC.getKnownMinValue();
3773+
auto BigVecTyMinElts = BigVecTyEC.getKnownMinValue();
3774+
auto SubVecTyMinElts = SubVecTyEC.getKnownMinValue();
3775+
3776+
unsigned CastEltSize = CastTy.getElementType().getSizeInBits();
3777+
unsigned DstEltSize = DstTy.getElementType().getSizeInBits();
3778+
if (CastEltSize < DstEltSize)
3779+
return UnableToLegalize;
3780+
3781+
auto AdjustAmt = CastEltSize / DstEltSize;
3782+
if (Idx % AdjustAmt != 0 || DstTyMinElts % AdjustAmt != 0 ||
3783+
BigVecTyMinElts % AdjustAmt != 0 || SubVecTyMinElts % AdjustAmt != 0)
3784+
return UnableToLegalize;
3785+
3786+
Idx /= AdjustAmt;
3787+
BigVecTy = LLT::vector(BigVecTyEC.divideCoefficientBy(AdjustAmt), AdjustAmt);
3788+
SubVecTy = LLT::vector(SubVecTyEC.divideCoefficientBy(AdjustAmt), AdjustAmt);
3789+
auto CastBigVec = MIRBuilder.buildBitcast(BigVecTy, BigVec);
3790+
auto CastSubVec = MIRBuilder.buildBitcast(SubVecTy, SubVec);
3791+
auto PromotedIS =
3792+
MIRBuilder.buildInsertSubvector(CastTy, CastBigVec, CastSubVec, Idx);
3793+
MIRBuilder.buildBitcast(Dst, PromotedIS);
3794+
3795+
ES->eraseFromParent();
3796+
return Legalized;
3797+
}
3798+
37283799
LegalizerHelper::LegalizeResult LegalizerHelper::lowerLoad(GAnyLoad &LoadMI) {
37293800
// Lower to a memory-width G_LOAD and a G_SEXT/G_ZEXT/G_ANYEXT
37303801
Register DstReg = LoadMI.getDstReg();
@@ -4033,6 +4104,8 @@ LegalizerHelper::bitcast(MachineInstr &MI, unsigned TypeIdx, LLT CastTy) {
40334104
return bitcastConcatVector(MI, TypeIdx, CastTy);
40344105
case TargetOpcode::G_EXTRACT_SUBVECTOR:
40354106
return bitcastExtractSubvector(MI, TypeIdx, CastTy);
4107+
case TargetOpcode::G_INSERT_SUBVECTOR:
4108+
return bitcastInsertSubvector(MI, TypeIdx, CastTy);
40364109
default:
40374110
return UnableToLegalize;
40384111
}

llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp

Lines changed: 148 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,13 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
615615
all(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST),
616616
typeIsLegalIntOrFPVec(1, IntOrFPVecTys, ST))));
617617

618+
619+
getActionDefinitionsBuilder(G_INSERT_SUBVECTOR)
620+
.customIf(all(typeIsLegalBoolVec(0, BoolVecTys, ST),
621+
typeIsLegalBoolVec(1, BoolVecTys, ST)))
622+
.customIf(all(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST),
623+
typeIsLegalIntOrFPVec(1, IntOrFPVecTys, ST)));
624+
618625
getLegacyLegalizerInfo().computeTables();
619626
}
620627

@@ -833,10 +840,8 @@ static MachineInstrBuilder buildAllOnesMask(LLT VecTy, const SrcOp &VL,
833840

834841
/// Gets the two common "VL" operands: an all-ones mask and the vector length.
835842
/// VecTy is a scalable vector type.
836-
static std::pair<MachineInstrBuilder, MachineInstrBuilder>
837-
buildDefaultVLOps(const DstOp &Dst, MachineIRBuilder &MIB,
838-
MachineRegisterInfo &MRI) {
839-
LLT VecTy = Dst.getLLTTy(MRI);
843+
static std::pair<MachineInstrBuilder, Register>
844+
buildDefaultVLOps(LLT VecTy, MachineIRBuilder &MIB, MachineRegisterInfo &MRI) {
840845
assert(VecTy.isScalableVector() && "Expecting scalable container type");
841846
const RISCVSubtarget &STI = MIB.getMF().getSubtarget<RISCVSubtarget>();
842847
LLT XLenTy(STI.getXLenVT());
@@ -890,7 +895,7 @@ bool RISCVLegalizerInfo::legalizeSplatVector(MachineInstr &MI,
890895
// Handle case of s64 element vectors on rv32
891896
if (XLenTy.getSizeInBits() == 32 &&
892897
VecTy.getElementType().getSizeInBits() == 64) {
893-
auto [_, VL] = buildDefaultVLOps(Dst, MIB, MRI);
898+
auto [_, VL] = buildDefaultVLOps(MRI.getType(Dst), MIB, MRI);
894899
buildSplatSplitS64WithVL(Dst, MIB.buildUndef(VecTy), SplatVal, VL, MIB,
895900
MRI);
896901
MI.eraseFromParent();
@@ -1025,6 +1030,142 @@ bool RISCVLegalizerInfo::legalizeExtractSubvector(MachineInstr &MI,
10251030
return true;
10261031
}
10271032

1033+
/// Custom legalization of G_INSERT_SUBVECTOR.
///
/// Mask (i1 element) vectors are either bitcast to i8 vectors through
/// \p Helper, or zero-extended to i8 vectors, inserted, and compared against
/// zero to recover the mask. Other inserts are either left untouched (when
/// they are treated as plain subregister operations) or rewritten as a
/// G_VMV_V_V_VL / G_VSLIDEUP_VL on an LMUL=1 (or smaller) container which is
/// then inserted back into the original vector.
///
/// \returns true in all visible paths except the bitcast path, which returns
/// the converted result of Helper.bitcast.
bool RISCVLegalizerInfo::legalizeInsertSubvector(MachineInstr &MI,
                                                 LegalizerHelper &Helper,
                                                 MachineIRBuilder &MIB) const {
  GInsertSubvector &IS = cast<GInsertSubvector>(MI);

  MachineRegisterInfo &MRI = *MIB.getMRI();

  Register Dst = IS.getReg(0);
  Register BigVec = IS.getBigVec();
  Register LitVec = IS.getSubVec();
  uint64_t Idx = IS.getIndexImm();

  LLT BigTy = MRI.getType(BigVec);
  LLT LitTy = MRI.getType(LitVec);

  // NOTE(review): the instruction is kept as-is when *either* the index is
  // zero or the big vector is undefined. Confirm that `||` (rather than `&&`)
  // is the intended condition.
  if (Idx == 0 ||
      MRI.getVRegDef(BigVec)->getOpcode() == TargetOpcode::G_IMPLICIT_DEF)
    return true;

  // We don't have the ability to slide mask vectors up indexed by their i1
  // elements; the smallest we can do is i8. Often we are able to bitcast to
  // equivalent i8 vectors. Otherwise, we must zero-extend to equivalent i8
  // vectors and truncate down after the insert.
  if (LitTy.getElementType() == LLT::scalar(1)) {
    auto BigTyMinElts = BigTy.getElementCount().getKnownMinValue();
    auto LitTyMinElts = LitTy.getElementCount().getKnownMinValue();
    if (BigTyMinElts >= 8 && LitTyMinElts >= 8)
      return Helper.bitcast(
          IS, 0,
          LLT::vector(BigTy.getElementCount().divideCoefficientBy(8), 8));

    // We can't slide this mask vector up indexed by its i1 elements.
    // This poses a problem when we wish to insert a scalable vector which
    // can't be re-expressed as a larger type. Just choose the slow path and
    // extend to a larger type, then truncate back down.
    BigTy = BigTy.changeElementType(LLT::scalar(8));
    LitTy = LitTy.changeElementType(LLT::scalar(8));
    auto BigZExt = MIB.buildZExt(BigTy, BigVec);
    auto LitZExt = MIB.buildZExt(LitTy, LitVec);
    auto Insert = MIB.buildInsertSubvector(BigTy, BigZExt, LitZExt, Idx);
    // The "truncate" back to i1 is expressed as a compare against zero.
    auto SplatZero = MIB.buildSplatVector(
        BigTy, MIB.buildConstant(BigTy.getElementType(), 0));
    MIB.buildICmp(CmpInst::Predicate::ICMP_NE, Dst, Insert, SplatZero);
    MI.eraseFromParent();
    return true;
  }

  const RISCVRegisterInfo *TRI = STI.getRegisterInfo();
  unsigned SubRegIdx, RemIdx;
  std::tie(SubRegIdx, RemIdx) =
      RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
          getMVTForLLT(BigTy), getMVTForLLT(LitTy), Idx, TRI);

  TypeSize VecRegSize = TypeSize::getScalable(RISCV::RVVBitsPerBlock);
  assert(isPowerOf2_64(
      STI.expandVScale(LitTy.getSizeInBits()).getKnownMinValue()));
  bool ExactlyVecRegSized =
      STI.expandVScale(LitTy.getSizeInBits())
          .isKnownMultipleOf(STI.expandVScale(VecRegSize));

  // If the Idx has been completely eliminated and this subvector's size is a
  // vector register or a multiple thereof, or the surrounding elements are
  // undef, then this is a subvector insert which naturally aligns to a vector
  // register. These can easily be handled using subregister manipulation.
  if (RemIdx == 0 &&
      (ExactlyVecRegSized ||
       MRI.getVRegDef(BigVec)->getOpcode() == TargetOpcode::G_IMPLICIT_DEF))
    return true;

  // If the subvector is smaller than a vector register, then the insertion
  // must preserve the undisturbed elements of the register. We do this by
  // lowering to an EXTRACT_SUBVECTOR grabbing the nearest LMUL=1 vector type
  // (which resolves to a subregister copy), performing a VSLIDEUP to place the
  // subvector within the vector register, and an INSERT_SUBVECTOR of that
  // LMUL=1 type back into the larger vector (resolving to another subregister
  // operation). See below for how our VSLIDEUP works. We go via a LMUL=1 type
  // to avoid allocating a large register group to hold our subvector.

  // VSLIDEUP works by leaving elements 0<=i<OFFSET undisturbed, elements
  // OFFSET<=i<VL set to the "subvector" and VL<=i<VLMAX set to the tail policy
  // (in our case undisturbed). This means we can set up a subvector insertion
  // where OFFSET is the insertion offset, and the VL is the OFFSET plus the
  // size of the subvector.
  const LLT XLenTy(STI.getXLenVT());
  LLT InterLitTy = BigTy;
  Register AlignedExtract = BigVec;
  unsigned AlignedIdx = Idx - RemIdx;
  if (TypeSize::isKnownGT(BigTy.getSizeInBits(),
                          getLMUL1Ty(BigTy).getSizeInBits())) {
    InterLitTy = getLMUL1Ty(BigTy);
    // Extract a subvector equal to the nearest full vector register type. This
    // should resolve to a G_EXTRACT on a subreg.
    AlignedExtract =
        MIB.buildExtractSubvector(InterLitTy, BigVec, AlignedIdx).getReg(0);
  }

  // Widen the subvector to the intermediate type by placing it at index 0 of
  // an undefined vector.
  auto Insert = MIB.buildInsertSubvector(InterLitTy, MIB.buildUndef(InterLitTy),
                                         LitVec, 0);

  auto [Mask, _] = buildDefaultVLOps(BigTy, MIB, MRI);
  auto VL = MIB.buildVScale(XLenTy, LitTy.getElementCount().getKnownMinValue());

  // Use tail agnostic policy if we're inserting over InterLitTy's tail.
  ElementCount EndIndex =
      ElementCount::getScalable(RemIdx) + LitTy.getElementCount();
  uint64_t Policy = RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED;
  if (STI.expandVScale(EndIndex) ==
      STI.expandVScale(InterLitTy.getElementCount()))
    Policy = RISCVII::TAIL_AGNOSTIC;

  // If we're inserting into the lowest elements, use a tail undisturbed
  // vmv.v.v.
  MachineInstrBuilder Inserted;
  if (RemIdx == 0) {
    Inserted = MIB.buildInstr(RISCV::G_VMV_V_V_VL, {InterLitTy},
                              {AlignedExtract, Insert, VL});
  } else {
    auto SlideupAmt = MIB.buildVScale(XLenTy, RemIdx);
    // Construct the vector length corresponding to RemIdx + length(LitTy).
    VL = MIB.buildAdd(XLenTy, SlideupAmt, VL);
    Inserted =
        MIB.buildInstr(RISCV::G_VSLIDEUP_VL, {InterLitTy},
                       {AlignedExtract, Insert, SlideupAmt, Mask, VL, Policy});
  }

  // If required, insert this subvector back into the correct vector register.
  // This should resolve to an INSERT_SUBREG instruction. The value inserted
  // must be the merged register computed above (Inserted), not the original
  // LitVec; otherwise the slide/move result would be dead and the subvector
  // would land at AlignedIdx instead of Idx.
  if (TypeSize::isKnownGT(BigTy.getSizeInBits(), InterLitTy.getSizeInBits()))
    MIB.buildInsertSubvector(Dst, BigVec, Inserted, AlignedIdx);
  else
    Inserted->getOperand(0).setReg(Dst);

  MI.eraseFromParent();
  return true;
}
1168+
10281169
bool RISCVLegalizerInfo::legalizeCustom(
10291170
LegalizerHelper &Helper, MachineInstr &MI,
10301171
LostDebugLocObserver &LocObserver) const {
@@ -1092,6 +1233,8 @@ bool RISCVLegalizerInfo::legalizeCustom(
10921233
return legalizeSplatVector(MI, MIRBuilder);
10931234
case TargetOpcode::G_EXTRACT_SUBVECTOR:
10941235
return legalizeExtractSubvector(MI, MIRBuilder);
1236+
case TargetOpcode::G_INSERT_SUBVECTOR:
1237+
return legalizeInsertSubvector(MI, Helper, MIRBuilder);
10951238
case TargetOpcode::G_LOAD:
10961239
case TargetOpcode::G_STORE:
10971240
return legalizeLoadStore(MI, Helper, MIRBuilder);

llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ class RISCVLegalizerInfo : public LegalizerInfo {
4747
bool legalizeExt(MachineInstr &MI, MachineIRBuilder &MIRBuilder) const;
4848
bool legalizeSplatVector(MachineInstr &MI, MachineIRBuilder &MIB) const;
4949
bool legalizeExtractSubvector(MachineInstr &MI, MachineIRBuilder &MIB) const;
50+
bool legalizeInsertSubvector(MachineInstr &MI, LegalizerHelper &Helper,
51+
MachineIRBuilder &MIB) const;
5052
bool legalizeLoadStore(MachineInstr &MI, LegalizerHelper &Helper,
5153
MachineIRBuilder &MIB) const;
5254
};

llvm/lib/Target/RISCV/RISCVInstrGISel.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,20 @@ def G_VSLIDEDOWN_VL : RISCVGenericInstruction {
6767
}
6868
def : GINodeEquiv<G_VSLIDEDOWN_VL, riscv_slidedown_vl>;
6969

70+
// Pseudo equivalent to a RISCVISD::VMV_V_V_VL.
// Inputs are the passthru vector (supplies the elements that are not
// overwritten), the vector to move, and the vector length, matching the
// three sources the legalizer passes when building this instruction.
def G_VMV_V_V_VL : RISCVGenericInstruction {
  let OutOperandList = (outs type0:$dst);
  let InOperandList = (ins type0:$passthru, type0:$vec, type1:$vl);
  let hasSideEffects = false;
}
def : GINodeEquiv<G_VMV_V_V_VL, riscv_vmv_v_v_vl>;
77+
78+
// Generic pseudo that mirrors RISCVISD::VSLIDEUP_VL.
// Sources, in order: merge vector, vector to slide, slide amount, mask,
// vector length, and policy immediate.
def G_VSLIDEUP_VL : RISCVGenericInstruction {
  let hasSideEffects = false;
  let OutOperandList = (outs type0:$dst);
  let InOperandList = (ins type0:$merge,
                           type0:$vec,
                           type1:$idx,
                           type2:$mask,
                           type3:$vl,
                           type4:$policy);
}
def : GINodeEquiv<G_VSLIDEUP_VL, riscv_slideup_vl>;
86+

0 commit comments

Comments
 (0)