Skip to content

Commit 5d0a49d

Browse files
[RISCV][GISEL] Legalize G_INSERT_SUBVECTOR
This code is heavily based on the SelectionDAG lowerINSERT_SUBVECTOR code.
1 parent fa6c02a commit 5d0a49d

File tree

7 files changed

+813
-4
lines changed

7 files changed

+813
-4
lines changed

llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -811,6 +811,18 @@ class GExtractSubvector : public GenericMachineInstr {
811811
}
812812
};
813813

814+
/// Represents a insert subvector.
815+
class GInsertSubvector : public GenericMachineInstr {
816+
public:
817+
Register getBigVec() const { return getOperand(1).getReg(); }
818+
Register getSubVec() const { return getOperand(2).getReg(); }
819+
uint64_t getIndexImm() const { return getOperand(3).getImm(); }
820+
821+
static bool classof(const MachineInstr *MI) {
822+
return MI->getOpcode() == TargetOpcode::G_INSERT_SUBVECTOR;
823+
}
824+
};
825+
814826
/// Represents a freeze.
815827
class GFreeze : public GenericMachineInstr {
816828
public:

llvm/include/llvm/CodeGen/GlobalISel/LegalizerHelper.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,8 @@ class LegalizerHelper {
380380
LLT CastTy);
381381
LegalizeResult bitcastExtractSubvector(MachineInstr &MI, unsigned TypeIdx,
382382
LLT CastTy);
383+
LegalizeResult bitcastInsertSubvector(MachineInstr &MI, unsigned TypeIdx,
384+
LLT CastTy);
383385

384386
LegalizeResult lowerConstant(MachineInstr &MI);
385387
LegalizeResult lowerFConstant(MachineInstr &MI);

llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3725,6 +3725,77 @@ LegalizerHelper::bitcastExtractSubvector(MachineInstr &MI, unsigned TypeIdx,
37253725
return Legalized;
37263726
}
37273727

3728+
/// This attempts to bitcast G_INSERT_SUBVECTOR to CastTy.
3729+
///
3730+
/// <vscale x 16 x i1> = G_INSERT_SUBVECTOR <vscale x 16 x i1>,
3731+
/// <vscale x 8 x i1>,
3732+
/// N
3733+
///
3734+
/// ===>
3735+
///
3736+
/// <vscale x 2 x i8> = G_BITCAST <vscale x 16 x i1>
3737+
/// <vscale x 1 x i8> = G_BITCAST <vscale x 8 x i1>
3738+
/// <vscale x 2 x i8> = G_INSERT_SUBVECTOR <vscale x 2 x i8>,
3739+
/// <vscale x 1 x i8>, N / 8
3740+
/// <vscale x 16 x i1> = G_BITCAST <vscale x 2 x i8>
3741+
LegalizerHelper::LegalizeResult
3742+
LegalizerHelper::bitcastInsertSubvector(MachineInstr &MI, unsigned TypeIdx,
3743+
LLT CastTy) {
3744+
auto ES = cast<GInsertSubvector>(&MI);
3745+
3746+
if (!CastTy.isVector())
3747+
return UnableToLegalize;
3748+
3749+
if (TypeIdx != 0)
3750+
return UnableToLegalize;
3751+
3752+
Register Dst = ES->getReg(0);
3753+
Register BigVec = ES->getBigVec();
3754+
Register SubVec = ES->getSubVec();
3755+
uint64_t Idx = ES->getIndexImm();
3756+
3757+
MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
3758+
3759+
LLT DstTy = MRI.getType(Dst);
3760+
LLT BigVecTy = MRI.getType(BigVec);
3761+
LLT SubVecTy = MRI.getType(SubVec);
3762+
3763+
if (DstTy == CastTy)
3764+
return Legalized;
3765+
3766+
if (DstTy.getSizeInBits() != CastTy.getSizeInBits())
3767+
return UnableToLegalize;
3768+
3769+
ElementCount DstTyEC = DstTy.getElementCount();
3770+
ElementCount BigVecTyEC = BigVecTy.getElementCount();
3771+
ElementCount SubVecTyEC = SubVecTy.getElementCount();
3772+
auto DstTyMinElts = DstTyEC.getKnownMinValue();
3773+
auto BigVecTyMinElts = BigVecTyEC.getKnownMinValue();
3774+
auto SubVecTyMinElts = SubVecTyEC.getKnownMinValue();
3775+
3776+
unsigned CastEltSize = CastTy.getElementType().getSizeInBits();
3777+
unsigned DstEltSize = DstTy.getElementType().getSizeInBits();
3778+
if (CastEltSize < DstEltSize)
3779+
return UnableToLegalize;
3780+
3781+
auto AdjustAmt = CastEltSize / DstEltSize;
3782+
if (Idx % AdjustAmt != 0 || DstTyMinElts % AdjustAmt != 0 ||
3783+
BigVecTyMinElts % AdjustAmt != 0 || SubVecTyMinElts % AdjustAmt != 0)
3784+
return UnableToLegalize;
3785+
3786+
Idx /= AdjustAmt;
3787+
BigVecTy = LLT::vector(BigVecTyEC.divideCoefficientBy(AdjustAmt), AdjustAmt);
3788+
SubVecTy = LLT::vector(SubVecTyEC.divideCoefficientBy(AdjustAmt), AdjustAmt);
3789+
auto CastBigVec = MIRBuilder.buildBitcast(BigVecTy, BigVec);
3790+
auto CastSubVec = MIRBuilder.buildBitcast(SubVecTy, SubVec);
3791+
auto PromotedIS =
3792+
MIRBuilder.buildInsertSubvector(CastTy, CastBigVec, CastSubVec, Idx);
3793+
MIRBuilder.buildBitcast(Dst, PromotedIS);
3794+
3795+
ES->eraseFromParent();
3796+
return Legalized;
3797+
}
3798+
37283799
LegalizerHelper::LegalizeResult LegalizerHelper::lowerLoad(GAnyLoad &LoadMI) {
37293800
// Lower to a memory-width G_LOAD and a G_SEXT/G_ZEXT/G_ANYEXT
37303801
Register DstReg = LoadMI.getDstReg();
@@ -4033,6 +4104,8 @@ LegalizerHelper::bitcast(MachineInstr &MI, unsigned TypeIdx, LLT CastTy) {
40334104
return bitcastConcatVector(MI, TypeIdx, CastTy);
40344105
case TargetOpcode::G_EXTRACT_SUBVECTOR:
40354106
return bitcastExtractSubvector(MI, TypeIdx, CastTy);
4107+
case TargetOpcode::G_INSERT_SUBVECTOR:
4108+
return bitcastInsertSubvector(MI, TypeIdx, CastTy);
40364109
default:
40374110
return UnableToLegalize;
40384111
}

llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp

Lines changed: 147 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,13 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
622622
all(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST),
623623
typeIsLegalIntOrFPVec(1, IntOrFPVecTys, ST))));
624624

625+
626+
getActionDefinitionsBuilder(G_INSERT_SUBVECTOR)
627+
.customIf(all(typeIsLegalBoolVec(0, BoolVecTys, ST),
628+
typeIsLegalBoolVec(1, BoolVecTys, ST)))
629+
.customIf(all(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST),
630+
typeIsLegalIntOrFPVec(1, IntOrFPVecTys, ST)));
631+
625632
getLegacyLegalizerInfo().computeTables();
626633
}
627634

@@ -865,9 +872,7 @@ static MachineInstrBuilder buildAllOnesMask(LLT VecTy, const SrcOp &VL,
865872
/// Gets the two common "VL" operands: an all-ones mask and the vector length.
866873
/// VecTy is a scalable vector type.
867874
static std::pair<MachineInstrBuilder, Register>
868-
buildDefaultVLOps(const DstOp &Dst, MachineIRBuilder &MIB,
869-
MachineRegisterInfo &MRI) {
870-
LLT VecTy = Dst.getLLTTy(MRI);
875+
buildDefaultVLOps(LLT VecTy, MachineIRBuilder &MIB, MachineRegisterInfo &MRI) {
871876
assert(VecTy.isScalableVector() && "Expecting scalable container type");
872877
Register VL(RISCV::X0);
873878
MachineInstrBuilder Mask = buildAllOnesMask(VecTy, VL, MIB, MRI);
@@ -919,7 +924,7 @@ bool RISCVLegalizerInfo::legalizeSplatVector(MachineInstr &MI,
919924
// Handle case of s64 element vectors on rv32
920925
if (XLenTy.getSizeInBits() == 32 &&
921926
VecTy.getElementType().getSizeInBits() == 64) {
922-
auto [_, VL] = buildDefaultVLOps(Dst, MIB, MRI);
927+
auto [_, VL] = buildDefaultVLOps(MRI.getType(Dst), MIB, MRI);
923928
buildSplatSplitS64WithVL(Dst, MIB.buildUndef(VecTy), SplatVal, VL, MIB,
924929
MRI);
925930
MI.eraseFromParent();
@@ -1054,6 +1059,142 @@ bool RISCVLegalizerInfo::legalizeExtractSubvector(MachineInstr &MI,
10541059
return true;
10551060
}
10561061

1062+
bool RISCVLegalizerInfo::legalizeInsertSubvector(MachineInstr &MI,
1063+
LegalizerHelper &Helper,
1064+
MachineIRBuilder &MIB) const {
1065+
GInsertSubvector &IS = cast<GInsertSubvector>(MI);
1066+
1067+
MachineRegisterInfo &MRI = *MIB.getMRI();
1068+
1069+
Register Dst = IS.getReg(0);
1070+
Register BigVec = IS.getBigVec();
1071+
Register LitVec = IS.getSubVec();
1072+
uint64_t Idx = IS.getIndexImm();
1073+
1074+
LLT BigTy = MRI.getType(BigVec);
1075+
LLT LitTy = MRI.getType(LitVec);
1076+
1077+
if (Idx == 0 ||
1078+
MRI.getVRegDef(BigVec)->getOpcode() == TargetOpcode::G_IMPLICIT_DEF)
1079+
return true;
1080+
1081+
// We don't have the ability to slide mask vectors up indexed by their i1
1082+
// elements; the smallest we can do is i8. Often we are able to bitcast to
1083+
// equivalent i8 vectors. Otherwise, we can must zeroextend to equivalent i8
1084+
// vectors and truncate down after the insert.
1085+
if (LitTy.getElementType() == LLT::scalar(1)) {
1086+
auto BigTyMinElts = BigTy.getElementCount().getKnownMinValue();
1087+
auto LitTyMinElts = LitTy.getElementCount().getKnownMinValue();
1088+
if (BigTyMinElts >= 8 && LitTyMinElts >= 8)
1089+
return Helper.bitcast(
1090+
IS, 0,
1091+
LLT::vector(BigTy.getElementCount().divideCoefficientBy(8), 8));
1092+
1093+
// We can't slide this mask vector up indexed by its i1 elements.
1094+
// This poses a problem when we wish to insert a scalable vector which
1095+
// can't be re-expressed as a larger type. Just choose the slow path and
1096+
// extend to a larger type, then truncate back down.
1097+
BigTy = BigTy.changeElementType(LLT::scalar(8));
1098+
LitTy = LitTy.changeElementType(LLT::scalar(8));
1099+
auto BigZExt = MIB.buildZExt(BigTy, BigVec);
1100+
auto LitZExt = MIB.buildZExt(LitTy, LitVec);
1101+
auto Insert = MIB.buildInsertSubvector(BigTy, BigZExt, LitZExt, Idx);
1102+
auto SplatZero = MIB.buildSplatVector(
1103+
BigTy, MIB.buildConstant(BigTy.getElementType(), 0));
1104+
MIB.buildICmp(CmpInst::Predicate::ICMP_NE, Dst, Insert, SplatZero);
1105+
MI.eraseFromParent();
1106+
return true;
1107+
}
1108+
1109+
const RISCVRegisterInfo *TRI = STI.getRegisterInfo();
1110+
unsigned SubRegIdx, RemIdx;
1111+
std::tie(SubRegIdx, RemIdx) =
1112+
RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
1113+
getMVTForLLT(BigTy), getMVTForLLT(LitTy), Idx, TRI);
1114+
1115+
TypeSize VecRegSize = TypeSize::getScalable(RISCV::RVVBitsPerBlock);
1116+
assert(isPowerOf2_64(
1117+
STI.expandVScale(LitTy.getSizeInBits()).getKnownMinValue()));
1118+
bool ExactlyVecRegSized =
1119+
STI.expandVScale(LitTy.getSizeInBits())
1120+
.isKnownMultipleOf(STI.expandVScale(VecRegSize));
1121+
1122+
// If the Idx has been completely eliminated and this subvector's size is a
1123+
// vector register or a multiple thereof, or the surrounding elements are
1124+
// undef, then this is a subvector insert which naturally aligns to a vector
1125+
// register. These can easily be handled using subregister manipulation.
1126+
if (RemIdx == 0 &&
1127+
(ExactlyVecRegSized ||
1128+
MRI.getVRegDef(BigVec)->getOpcode() == TargetOpcode::G_IMPLICIT_DEF))
1129+
return true;
1130+
1131+
// If the subvector is smaller than a vector register, then the insertion
1132+
// must preserve the undisturbed elements of the register. We do this by
1133+
// lowering to an EXTRACT_SUBVECTOR grabbing the nearest LMUL=1 vector type
1134+
// (which resolves to a subregister copy), performing a VSLIDEUP to place the
1135+
// subvector within the vector register, and an INSERT_SUBVECTOR of that
1136+
// LMUL=1 type back into the larger vector (resolving to another subregister
1137+
// operation). See below for how our VSLIDEUP works. We go via a LMUL=1 type
1138+
// to avoid allocating a large register group to hold our subvector.
1139+
1140+
// VSLIDEUP works by leaving elements 0<i<OFFSET undisturbed, elements
1141+
// OFFSET<=i<VL set to the "subvector" and vl<=i<VLMAX set to the tail policy
1142+
// (in our case undisturbed). This means we can set up a subvector insertion
1143+
// where OFFSET is the insertion offset, and the VL is the OFFSET plus the
1144+
// size of the subvector.
1145+
const LLT XLenTy(STI.getXLenVT());
1146+
LLT InterLitTy = BigTy;
1147+
Register AlignedExtract = BigVec;
1148+
unsigned AlignedIdx = Idx - RemIdx;
1149+
if (TypeSize::isKnownGT(BigTy.getSizeInBits(),
1150+
getLMUL1Ty(BigTy).getSizeInBits())) {
1151+
InterLitTy = getLMUL1Ty(BigTy);
1152+
// Extract a subvector equal to the nearest full vector register type. This
1153+
// should resolve to a G_EXTRACT on a subreg.
1154+
AlignedExtract =
1155+
MIB.buildExtractSubvector(InterLitTy, BigVec, AlignedIdx).getReg(0);
1156+
}
1157+
1158+
auto Insert = MIB.buildInsertSubvector(InterLitTy, MIB.buildUndef(InterLitTy),
1159+
LitVec, 0);
1160+
1161+
auto [Mask, _] = buildDefaultVLOps(BigTy, MIB, MRI);
1162+
auto VL = MIB.buildVScale(XLenTy, LitTy.getElementCount().getKnownMinValue());
1163+
1164+
// Use tail agnostic policy if we're inserting over InterLitTy's tail.
1165+
ElementCount EndIndex =
1166+
ElementCount::getScalable(RemIdx) + LitTy.getElementCount();
1167+
uint64_t Policy = RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED;
1168+
if (STI.expandVScale(EndIndex) ==
1169+
STI.expandVScale(InterLitTy.getElementCount()))
1170+
Policy = RISCVII::TAIL_AGNOSTIC;
1171+
1172+
// If we're inserting into the lowest elements, use a tail undisturbed
1173+
// vmv.v.v.
1174+
MachineInstrBuilder Inserted;
1175+
if (RemIdx == 0) {
1176+
Inserted = MIB.buildInstr(RISCV::G_VMV_V_V_VL, {InterLitTy},
1177+
{AlignedExtract, Insert, VL});
1178+
} else {
1179+
auto SlideupAmt = MIB.buildVScale(XLenTy, RemIdx);
1180+
// Construct the vector length corresponding to RemIdx + length(LitTy).
1181+
VL = MIB.buildAdd(XLenTy, SlideupAmt, VL);
1182+
Inserted =
1183+
MIB.buildInstr(RISCV::G_VSLIDEUP_VL, {InterLitTy},
1184+
{AlignedExtract, Insert, SlideupAmt, Mask, VL, Policy});
1185+
}
1186+
1187+
// If required, insert this subvector back into the correct vector register.
1188+
// This should resolve to an INSERT_SUBREG instruction.
1189+
if (TypeSize::isKnownGT(BigTy.getSizeInBits(), InterLitTy.getSizeInBits()))
1190+
MIB.buildInsertSubvector(Dst, BigVec, LitVec, AlignedIdx);
1191+
else
1192+
Inserted->getOperand(0).setReg(Dst);
1193+
1194+
MI.eraseFromParent();
1195+
return true;
1196+
}
1197+
10571198
bool RISCVLegalizerInfo::legalizeCustom(
10581199
LegalizerHelper &Helper, MachineInstr &MI,
10591200
LostDebugLocObserver &LocObserver) const {
@@ -1126,6 +1267,8 @@ bool RISCVLegalizerInfo::legalizeCustom(
11261267
return legalizeSplatVector(MI, MIRBuilder);
11271268
case TargetOpcode::G_EXTRACT_SUBVECTOR:
11281269
return legalizeExtractSubvector(MI, MIRBuilder);
1270+
case TargetOpcode::G_INSERT_SUBVECTOR:
1271+
return legalizeInsertSubvector(MI, Helper, MIRBuilder);
11291272
case TargetOpcode::G_LOAD:
11301273
case TargetOpcode::G_STORE:
11311274
return legalizeLoadStore(MI, Helper, MIRBuilder);

llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ class RISCVLegalizerInfo : public LegalizerInfo {
4747
bool legalizeExt(MachineInstr &MI, MachineIRBuilder &MIRBuilder) const;
4848
bool legalizeSplatVector(MachineInstr &MI, MachineIRBuilder &MIB) const;
4949
bool legalizeExtractSubvector(MachineInstr &MI, MachineIRBuilder &MIB) const;
50+
bool legalizeInsertSubvector(MachineInstr &MI, LegalizerHelper &Helper,
51+
MachineIRBuilder &MIB) const;
5052
bool legalizeLoadStore(MachineInstr &MI, LegalizerHelper &Helper,
5153
MachineIRBuilder &MIB) const;
5254
};

llvm/lib/Target/RISCV/RISCVInstrGISel.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,20 @@ def G_VSLIDEDOWN_VL : RISCVGenericInstruction {
6767
}
6868
def : GINodeEquiv<G_VSLIDEDOWN_VL, riscv_slidedown_vl>;
6969

70+
// Pseudo equivalent to a RISCVISD::VMV_V_V_VL
71+
def G_VMV_V_V_VL : RISCVGenericInstruction {
72+
let OutOperandList = (outs type0:$dst);
73+
let InOperandList = (ins type0:$vec, type2:$vl);
74+
let hasSideEffects = false;
75+
}
76+
def : GINodeEquiv<G_VMV_V_V_VL, riscv_vmv_v_v_vl>;
77+
78+
// Pseudo equivalent to a RISCVISD::VSLIDEUP_VL
79+
def G_VSLIDEUP_VL : RISCVGenericInstruction {
80+
let OutOperandList = (outs type0:$dst);
81+
let InOperandList = (ins type0:$merge, type0:$vec, type1:$idx, type2:$mask,
82+
type3:$vl, type4:$policy);
83+
let hasSideEffects = false;
84+
}
85+
def : GINodeEquiv<G_VSLIDEUP_VL, riscv_slideup_vl>;
86+

0 commit comments

Comments
 (0)