Skip to content

Commit f957d08

Browse files
[RISCV][GISEL] Legalize G_EXTRACT_SUBVECTOR (#109426)
This is heavily based on the SelectionDAG lowerEXTRACT_SUBVECTOR code.
1 parent d8df118 commit f957d08

File tree

7 files changed

+695
-0
lines changed

7 files changed

+695
-0
lines changed

llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -800,6 +800,17 @@ class GInsertVectorElement : public GenericMachineInstr {
800800
}
801801
};
802802

803+
/// Represents an extract subvector, i.e. a G_EXTRACT_SUBVECTOR instruction.
/// Operand 1 is read as a register and operand 2 as an immediate.
804+
class GExtractSubvector : public GenericMachineInstr {
805+
public:
806+
/// Returns the register held by operand 1 (the source vector).
Register getSrcVec() const { return getOperand(1).getReg(); }
807+
/// Returns the immediate held by operand 2 (the index), converted to an
/// unsigned 64-bit value.
uint64_t getIndexImm() const { return getOperand(2).getImm(); }
808+
809+
/// Type test: true iff \p MI has opcode G_EXTRACT_SUBVECTOR.
static bool classof(const MachineInstr *MI) {
810+
return MI->getOpcode() == TargetOpcode::G_EXTRACT_SUBVECTOR;
811+
}
812+
};
813+
803814
/// Represents a freeze.
804815
class GFreeze : public GenericMachineInstr {
805816
public:

llvm/include/llvm/CodeGen/GlobalISel/LegalizerHelper.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,8 @@ class LegalizerHelper {
378378
LLT CastTy);
379379
LegalizeResult bitcastConcatVector(MachineInstr &MI, unsigned TypeIdx,
380380
LLT CastTy);
381+
LegalizeResult bitcastExtractSubvector(MachineInstr &MI, unsigned TypeIdx,
382+
LLT CastTy);
381383

382384
LegalizeResult lowerConstant(MachineInstr &MI);
383385
LegalizeResult lowerFConstant(MachineInstr &MI);

llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3666,6 +3666,65 @@ LegalizerHelper::bitcastConcatVector(MachineInstr &MI, unsigned TypeIdx,
36663666
return Legalized;
36673667
}
36683668

3669+
/// This attempts to bitcast G_EXTRACT_SUBVECTOR to CastTy.
3670+
///
3671+
/// <vscale x 8 x i1> = G_EXTRACT_SUBVECTOR <vscale x 16 x i1>, N
3672+
///
3673+
/// ===>
3674+
///
3675+
/// <vscale x 2 x i1> = G_BITCAST <vscale x 16 x i1>
3676+
/// <vscale x 1 x i8> = G_EXTRACT_SUBVECTOR <vscale x 2 x i1>, N / 8
3677+
/// <vscale x 8 x i1> = G_BITCAST <vscale x 1 x i8>
3678+
LegalizerHelper::LegalizeResult
3679+
LegalizerHelper::bitcastExtractSubvector(MachineInstr &MI, unsigned TypeIdx,
3680+
LLT CastTy) {
3681+
auto ES = cast<GExtractSubvector>(&MI);
3682+
3683+
if (!CastTy.isVector())
3684+
return UnableToLegalize;
3685+
3686+
if (TypeIdx != 0)
3687+
return UnableToLegalize;
3688+
3689+
Register Dst = ES->getReg(0);
3690+
Register Src = ES->getSrcVec();
3691+
uint64_t Idx = ES->getIndexImm();
3692+
3693+
MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
3694+
3695+
LLT DstTy = MRI.getType(Dst);
3696+
LLT SrcTy = MRI.getType(Src);
3697+
ElementCount DstTyEC = DstTy.getElementCount();
3698+
ElementCount SrcTyEC = SrcTy.getElementCount();
3699+
auto DstTyMinElts = DstTyEC.getKnownMinValue();
3700+
auto SrcTyMinElts = SrcTyEC.getKnownMinValue();
3701+
3702+
if (DstTy == CastTy)
3703+
return Legalized;
3704+
3705+
if (DstTy.getSizeInBits() != CastTy.getSizeInBits())
3706+
return UnableToLegalize;
3707+
3708+
unsigned CastEltSize = CastTy.getElementType().getSizeInBits();
3709+
unsigned DstEltSize = DstTy.getElementType().getSizeInBits();
3710+
if (CastEltSize < DstEltSize)
3711+
return UnableToLegalize;
3712+
3713+
auto AdjustAmt = CastEltSize / DstEltSize;
3714+
if (Idx % AdjustAmt != 0 || DstTyMinElts % AdjustAmt != 0 ||
3715+
SrcTyMinElts % AdjustAmt != 0)
3716+
return UnableToLegalize;
3717+
3718+
Idx /= AdjustAmt;
3719+
SrcTy = LLT::vector(SrcTyEC.divideCoefficientBy(AdjustAmt), AdjustAmt);
3720+
auto CastVec = MIRBuilder.buildBitcast(SrcTy, Src);
3721+
auto PromotedES = MIRBuilder.buildExtractSubvector(CastTy, CastVec, Idx);
3722+
MIRBuilder.buildBitcast(Dst, PromotedES);
3723+
3724+
ES->eraseFromParent();
3725+
return Legalized;
3726+
}
3727+
36693728
LegalizerHelper::LegalizeResult LegalizerHelper::lowerLoad(GAnyLoad &LoadMI) {
36703729
// Lower to a memory-width G_LOAD and a G_SEXT/G_ZEXT/G_ANYEXT
36713730
Register DstReg = LoadMI.getDstReg();
@@ -3972,6 +4031,8 @@ LegalizerHelper::bitcast(MachineInstr &MI, unsigned TypeIdx, LLT CastTy) {
39724031
return bitcastInsertVectorElt(MI, TypeIdx, CastTy);
39734032
case TargetOpcode::G_CONCAT_VECTORS:
39744033
return bitcastConcatVector(MI, TypeIdx, CastTy);
4034+
case TargetOpcode::G_EXTRACT_SUBVECTOR:
4035+
return bitcastExtractSubvector(MI, TypeIdx, CastTy);
39754036
default:
39764037
return UnableToLegalize;
39774038
}

llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,31 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
597597

598598
SplatActions.clampScalar(1, sXLen, sXLen);
599599

600+
LegalityPredicate ExtractSubvecBitcastPred = [=](const LegalityQuery &Query) {
601+
LLT DstTy = Query.Types[0];
602+
LLT SrcTy = Query.Types[1];
603+
return DstTy.getElementType() == LLT::scalar(1) &&
604+
DstTy.getElementCount().getKnownMinValue() >= 8 &&
605+
SrcTy.getElementCount().getKnownMinValue() >= 8;
606+
};
607+
getActionDefinitionsBuilder(G_EXTRACT_SUBVECTOR)
608+
// We don't have the ability to slide mask vectors down indexed by their
609+
// i1 elements; the smallest we can do is i8. Often we are able to bitcast
610+
// to equivalent i8 vectors.
611+
.bitcastIf(
612+
all(typeIsLegalBoolVec(0, BoolVecTys, ST),
613+
typeIsLegalBoolVec(1, BoolVecTys, ST), ExtractSubvecBitcastPred),
614+
[=](const LegalityQuery &Query) {
615+
LLT CastTy = LLT::vector(
616+
Query.Types[0].getElementCount().divideCoefficientBy(8), 8);
617+
return std::pair(0, CastTy);
618+
})
619+
.customIf(LegalityPredicates::any(
620+
all(typeIsLegalBoolVec(0, BoolVecTys, ST),
621+
typeIsLegalBoolVec(1, BoolVecTys, ST)),
622+
all(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST),
623+
typeIsLegalIntOrFPVec(1, IntOrFPVecTys, ST))));
624+
600625
getLegacyLegalizerInfo().computeTables();
601626
}
602627

@@ -931,6 +956,105 @@ bool RISCVLegalizerInfo::legalizeSplatVector(MachineInstr &MI,
931956
return true;
932957
}
933958

959+
/// Build the scalable vector type that has the same element type as \p VecTy
/// and RISCV::RVVBitsPerBlock / <element size in bits> (minimum) elements.
/// Element types wider than 64 bits are rejected by assertion.
static LLT getLMUL1Ty(LLT VecTy) {
  const LLT EltTy = VecTy.getElementType();
  const uint64_t EltBits = EltTy.getSizeInBits();
  assert(EltBits <= 64 && "Unexpected vector LLT");
  return LLT::scalable_vector(RISCV::RVVBitsPerBlock / EltBits, EltTy);
}
966+
967+
/// Custom legalization for G_EXTRACT_SUBVECTOR.
///
/// Returns true on every path. MI is left untouched when the index is 0 or
/// when decomposeSubvectorInsertExtractToSubRegs leaves no remaining index.
/// Otherwise MI is erased and replaced by either a zext / extract / icmp
/// sequence (vectors of s1 elements) or a G_VSLIDEDOWN_VL followed by an
/// index-0 G_EXTRACT_SUBVECTOR.
///
/// \p Helper is not referenced in this body.
bool RISCVLegalizerInfo::legalizeExtractSubvector(MachineInstr &MI,
968+
LegalizerHelper &Helper,
969+
MachineIRBuilder &MIB) const {
970+
GExtractSubvector &ES = cast<GExtractSubvector>(MI);
971+
972+
MachineRegisterInfo &MRI = *MIB.getMRI();
973+
974+
Register Dst = ES.getReg(0);
975+
Register Src = ES.getSrcVec();
976+
uint64_t Idx = ES.getIndexImm();
977+
978+
// With an index of 0 this is a cast-like subvector, which can be performed
979+
// with subregister operations.
980+
if (Idx == 0)
981+
return true;
982+
983+
// LitTy is the type of the extracted result, BigTy the type of the source.
LLT LitTy = MRI.getType(Dst);
984+
LLT BigTy = MRI.getType(Src);
985+
986+
if (LitTy.getElementType() == LLT::scalar(1)) {
987+
// We can't slide this mask vector down indexed by its i1 elements.
988+
// This poses a problem when we wish to extract a scalable vector which
989+
// can't be re-expressed as a larger type. Just choose the slow path:
990+
// zero-extend to i8 elements, extract, then compare against zero.
991+
LLT ExtBigTy = BigTy.changeElementType(LLT::scalar(8));
992+
LLT ExtLitTy = LitTy.changeElementType(LLT::scalar(8));
993+
auto BigZExt = MIB.buildZExt(ExtBigTy, Src);
994+
auto ExtractZExt = MIB.buildExtractSubvector(ExtLitTy, BigZExt, Idx);
995+
auto SplatZero = MIB.buildSplatVector(
996+
ExtLitTy, MIB.buildConstant(ExtLitTy.getElementType(), 0));
997+
// A non-zero i8 element maps back to a set i1 element in Dst.
MIB.buildICmp(CmpInst::Predicate::ICMP_NE, Dst, ExtractZExt, SplatZero);
998+
MI.eraseFromParent();
999+
return true;
1000+
}
1001+
1002+
// extract_subvector scales the index by vscale if the subvector is scalable,
1003+
// and decomposeSubvectorInsertExtractToSubRegs takes this into account.
1004+
const RISCVRegisterInfo *TRI = STI.getRegisterInfo();
1005+
MVT LitTyMVT = getMVTForLLT(LitTy);
1006+
auto Decompose =
1007+
RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
1008+
getMVTForLLT(BigTy), LitTyMVT, Idx, TRI);
1009+
// Decompose.first is compared against RISCV::NoSubRegister below;
// Decompose.second is the index that remains after the decomposition.
unsigned RemIdx = Decompose.second;
1010+
1011+
// If the Idx has been completely eliminated then this is a subvector extract
1012+
// which naturally aligns to a vector register. These can easily be handled
1013+
// using subregister manipulation.
1014+
if (RemIdx == 0)
1015+
return true;
1016+
1017+
// Else LitTy is M1 or smaller and may need to be slid down: if LitTy
1018+
// was > M1 then the index would need to be a multiple of VLMAX, and so would
1019+
// divide exactly.
1020+
assert(
1021+
RISCVVType::decodeVLMUL(RISCVTargetLowering::getLMUL(LitTyMVT)).second ||
1022+
RISCVTargetLowering::getLMUL(LitTyMVT) == RISCVII::VLMUL::LMUL_1);
1023+
1024+
// If the vector type is an LMUL-group type, extract a subvector equal to the
1025+
// nearest full vector register type.
1026+
LLT InterLitTy = BigTy;
1027+
Register Vec = Src;
1028+
if (TypeSize::isKnownGT(BigTy.getSizeInBits(),
1029+
getLMUL1Ty(BigTy).getSizeInBits())) {
1030+
// If BigTy has an LMUL > 1, then LitTy should have a smaller LMUL, and
1031+
// we should have successfully decomposed the extract into a subregister.
1032+
assert(Decompose.first != RISCV::NoSubRegister);
1033+
InterLitTy = getLMUL1Ty(BigTy);
1034+
// SDAG builds a TargetExtractSubreg. We cannot create a Copy with SubReg
1035+
// specified on the source Register (the equivalent) since generic virtual
1036+
// register does not allow subregister index.
1037+
// Idx - RemIdx is the part of the index removed by the decomposition.
Vec = MIB.buildExtractSubvector(InterLitTy, Src, Idx - RemIdx).getReg(0);
1038+
}
1039+
1040+
// Slide this vector register down by the desired number of elements in order
1041+
// to place the desired subvector starting at element 0.
1042+
const LLT XLenTy(STI.getXLenVT());
1043+
auto SlidedownAmt = MIB.buildVScale(XLenTy, RemIdx);
1044+
// NOTE(review): Mask and VL are built from LitTy, while the slide below
// operates on InterLitTy. Confirm this is intended when the two differ.
auto [Mask, VL] = buildDefaultVLOps(LitTy, MIB, MRI);
1045+
uint64_t Policy = RISCVII::TAIL_AGNOSTIC | RISCVII::MASK_AGNOSTIC;
1046+
auto Slidedown = MIB.buildInstr(
1047+
RISCV::G_VSLIDEDOWN_VL, {InterLitTy},
1048+
{MIB.buildUndef(InterLitTy), Vec, SlidedownAmt, Mask, VL, Policy});
1049+
1050+
// Now the vector is in the right position, extract our final subvector. This
1051+
// should resolve to a COPY.
1052+
MIB.buildExtractSubvector(Dst, Slidedown, 0);
1053+
1054+
MI.eraseFromParent();
1055+
return true;
1056+
}
1057+
9341058
bool RISCVLegalizerInfo::legalizeCustom(
9351059
LegalizerHelper &Helper, MachineInstr &MI,
9361060
LostDebugLocObserver &LocObserver) const {
@@ -1001,6 +1125,8 @@ bool RISCVLegalizerInfo::legalizeCustom(
10011125
return legalizeExt(MI, MIRBuilder);
10021126
case TargetOpcode::G_SPLAT_VECTOR:
10031127
return legalizeSplatVector(MI, MIRBuilder);
1128+
case TargetOpcode::G_EXTRACT_SUBVECTOR:
1129+
return legalizeExtractSubvector(MI, Helper, MIRBuilder);
10041130
case TargetOpcode::G_LOAD:
10051131
case TargetOpcode::G_STORE:
10061132
return legalizeLoadStore(MI, Helper, MIRBuilder);

llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ class RISCVLegalizerInfo : public LegalizerInfo {
4646
bool legalizeVScale(MachineInstr &MI, MachineIRBuilder &MIB) const;
4747
bool legalizeExt(MachineInstr &MI, MachineIRBuilder &MIRBuilder) const;
4848
bool legalizeSplatVector(MachineInstr &MI, MachineIRBuilder &MIB) const;
49+
bool legalizeExtractSubvector(MachineInstr &MI, LegalizerHelper &Helper,
50+
MachineIRBuilder &MIB) const;
4951
bool legalizeLoadStore(MachineInstr &MI, LegalizerHelper &Helper,
5052
MachineIRBuilder &MIB) const;
5153
};

llvm/lib/Target/RISCV/RISCVInstrGISel.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,13 @@ def G_SPLAT_VECTOR_SPLIT_I64_VL : RISCVGenericInstruction {
5757
let InOperandList = (ins type0:$passthru, type1:$hi, type1:$lo, type2:$vl);
5858
let hasSideEffects = false;
5959
}
60+
61+
// Pseudo equivalent to a RISCVISD::VSLIDEDOWN_VL
//
// Operands, in order:
//   $dst    (type0) - result.
//   $merge  (type0) - same type as the result.
//   $vec    (type0) - same type as the result.
//   $idx    (type1)
//   $mask   (type2)
//   $vl     (type1) - same type as $idx.
//   $policy (type1) - same type as $idx.
62+
def G_VSLIDEDOWN_VL : RISCVGenericInstruction {
63+
let OutOperandList = (outs type0:$dst);
64+
let InOperandList = (ins type0:$merge, type0:$vec, type1:$idx, type2:$mask,
65+
type1:$vl, type1:$policy);
66+
let hasSideEffects = false;
67+
}
68+
// Pair the generic instruction with the riscv_slidedown_vl node.
def : GINodeEquiv<G_VSLIDEDOWN_VL, riscv_slidedown_vl>;
69+

0 commit comments

Comments
 (0)