Skip to content

[RISCV][GISEL] Legalize G_EXTRACT_SUBVECTOR #109426

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,17 @@ class GInsertVectorElement : public GenericMachineInstr {
}
};

/// Wrapper around a G_EXTRACT_SUBVECTOR machine instruction.
///
/// Operand 1 is read as the source vector register and operand 2 as the
/// immediate index.
class GExtractSubvector : public GenericMachineInstr {
public:
  /// Returns the register held by operand 1 (the source vector).
  Register getSrcVec() const {
    const auto &SrcOp = getOperand(1);
    return SrcOp.getReg();
  }

  /// Returns the immediate held by operand 2 (the index).
  uint64_t getIndexImm() const {
    const auto &IdxOp = getOperand(2);
    return IdxOp.getImm();
  }

  /// LLVM-style RTTI hook: matches instructions whose opcode is
  /// G_EXTRACT_SUBVECTOR.
  static bool classof(const MachineInstr *MI) {
    const auto Opc = MI->getOpcode();
    return Opc == TargetOpcode::G_EXTRACT_SUBVECTOR;
  }
};

/// Represents a freeze.
class GFreeze : public GenericMachineInstr {
public:
Expand Down
2 changes: 2 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/LegalizerHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,8 @@ class LegalizerHelper {
LLT CastTy);
LegalizeResult bitcastConcatVector(MachineInstr &MI, unsigned TypeIdx,
LLT CastTy);
/// Bitcast legalization for G_EXTRACT_SUBVECTOR. Returns UnableToLegalize if
/// \p TypeIdx is not 0 or \p CastTy is not a vector (among other unsupported
/// cases).
LegalizeResult bitcastExtractSubvector(MachineInstr &MI, unsigned TypeIdx,
LLT CastTy);

LegalizeResult lowerConstant(MachineInstr &MI);
LegalizeResult lowerFConstant(MachineInstr &MI);
Expand Down
61 changes: 61 additions & 0 deletions llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3666,6 +3666,65 @@ LegalizerHelper::bitcastConcatVector(MachineInstr &MI, unsigned TypeIdx,
return Legalized;
}

/// This attempts to bitcast G_EXTRACT_SUBVECTOR to CastTy.
///
///  <vscale x 8 x i1> = G_EXTRACT_SUBVECTOR <vscale x 16 x i1>, N
///
/// ===>
///
///  <vscale x 2 x i8> = G_BITCAST <vscale x 16 x i1>
///  <vscale x 1 x i8> = G_EXTRACT_SUBVECTOR <vscale x 2 x i8>, N / 8
///  <vscale x 8 x i1> = G_BITCAST <vscale x 1 x i8>
///
/// \param MI       The G_EXTRACT_SUBVECTOR to rewrite; erased on success.
/// \param TypeIdx  Must be 0 (the result type); anything else is rejected.
/// \param CastTy   Vector type the result is bitcast to. It must have the same
///                 total size as the result type, and its element size must be
///                 a whole multiple of the result element size.
/// \returns Legalized if the instruction was rewritten (or the result already
///          has type CastTy), UnableToLegalize otherwise.
LegalizerHelper::LegalizeResult
LegalizerHelper::bitcastExtractSubvector(MachineInstr &MI, unsigned TypeIdx,
                                         LLT CastTy) {
  auto ES = cast<GExtractSubvector>(&MI);

  if (!CastTy.isVector())
    return UnableToLegalize;

  if (TypeIdx != 0)
    return UnableToLegalize;

  Register Dst = ES->getReg(0);
  Register Src = ES->getSrcVec();
  uint64_t Idx = ES->getIndexImm();

  MachineRegisterInfo &MRI = *MIRBuilder.getMRI();

  LLT DstTy = MRI.getType(Dst);
  LLT SrcTy = MRI.getType(Src);
  ElementCount DstTyEC = DstTy.getElementCount();
  ElementCount SrcTyEC = SrcTy.getElementCount();
  auto DstTyMinElts = DstTyEC.getKnownMinValue();
  auto SrcTyMinElts = SrcTyEC.getKnownMinValue();

  if (DstTy == CastTy)
    return Legalized;

  if (DstTy.getSizeInBits() != CastTy.getSizeInBits())
    return UnableToLegalize;

  unsigned CastEltSize = CastTy.getElementType().getSizeInBits();
  unsigned DstEltSize = DstTy.getElementType().getSizeInBits();
  // Each cast element must pack a whole number of original elements,
  // otherwise the index and element counts cannot be rescaled exactly.
  if (CastEltSize < DstEltSize || CastEltSize % DstEltSize != 0)
    return UnableToLegalize;

  // Number of original elements folded into one cast element.
  auto AdjustAmt = CastEltSize / DstEltSize;
  if (Idx % AdjustAmt != 0 || DstTyMinElts % AdjustAmt != 0 ||
      SrcTyMinElts % AdjustAmt != 0)
    return UnableToLegalize;

  Idx /= AdjustAmt;
  // The source is reinterpreted with the same element type as CastTy. Using
  // AdjustAmt as the element width would only be correct when the original
  // elements are 1 bit wide.
  const LLT CastSrcTy = LLT::vector(SrcTyEC.divideCoefficientBy(AdjustAmt),
                                    CastTy.getElementType());
  auto CastVec = MIRBuilder.buildBitcast(CastSrcTy, Src);
  auto PromotedES = MIRBuilder.buildExtractSubvector(CastTy, CastVec, Idx);
  MIRBuilder.buildBitcast(Dst, PromotedES);

  ES->eraseFromParent();
  return Legalized;
}

LegalizerHelper::LegalizeResult LegalizerHelper::lowerLoad(GAnyLoad &LoadMI) {
// Lower to a memory-width G_LOAD and a G_SEXT/G_ZEXT/G_ANYEXT
Register DstReg = LoadMI.getDstReg();
Expand Down Expand Up @@ -3972,6 +4031,8 @@ LegalizerHelper::bitcast(MachineInstr &MI, unsigned TypeIdx, LLT CastTy) {
return bitcastInsertVectorElt(MI, TypeIdx, CastTy);
case TargetOpcode::G_CONCAT_VECTORS:
return bitcastConcatVector(MI, TypeIdx, CastTy);
case TargetOpcode::G_EXTRACT_SUBVECTOR:
return bitcastExtractSubvector(MI, TypeIdx, CastTy);
default:
return UnableToLegalize;
}
Expand Down
126 changes: 126 additions & 0 deletions llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,31 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)

SplatActions.clampScalar(1, sXLen, sXLen);

// True when the result (type 0) has s1 elements and both the result and the
// source (type 1) have a known-minimum element count of at least 8.
LegalityPredicate ExtractSubvecBitcastPred = [=](const LegalityQuery &Query) {
LLT DstTy = Query.Types[0];
LLT SrcTy = Query.Types[1];
return DstTy.getElementType() == LLT::scalar(1) &&
DstTy.getElementCount().getKnownMinValue() >= 8 &&
SrcTy.getElementCount().getKnownMinValue() >= 8;
};
getActionDefinitionsBuilder(G_EXTRACT_SUBVECTOR)
// We don't have the ability to slide mask vectors down indexed by their
// i1 elements; the smallest we can do is i8. Often we are able to bitcast
// to equivalent i8 vectors.
.bitcastIf(
all(typeIsLegalBoolVec(0, BoolVecTys, ST),
typeIsLegalBoolVec(1, BoolVecTys, ST), ExtractSubvecBitcastPred),
[=](const LegalityQuery &Query) {
// Cast type 0 to a vector of 8-bit elements with one eighth of its
// element count.
LLT CastTy = LLT::vector(
Query.Types[0].getElementCount().divideCoefficientBy(8), 8);
return std::pair(0, CastTy);
})
// Everything else whose result and source are both legal bool vectors, or
// both legal int/FP vectors, goes through custom legalization.
.customIf(LegalityPredicates::any(
all(typeIsLegalBoolVec(0, BoolVecTys, ST),
typeIsLegalBoolVec(1, BoolVecTys, ST)),
all(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST),
typeIsLegalIntOrFPVec(1, IntOrFPVecTys, ST))));

getLegacyLegalizerInfo().computeTables();
}

Expand Down Expand Up @@ -931,6 +956,105 @@ bool RISCVLegalizerInfo::legalizeSplatVector(MachineInstr &MI,
return true;
}

/// Builds the scalable vector type that keeps VecTy's element type and has a
/// minimum element count of RISCV::RVVBitsPerBlock divided by the element
/// size in bits. Elements wider than 64 bits are not expected.
static LLT getLMUL1Ty(LLT VecTy) {
  const LLT EltTy = VecTy.getElementType();
  const auto EltBits = EltTy.getSizeInBits();
  assert(EltBits <= 64 && "Unexpected vector LLT");
  return LLT::scalable_vector(RISCV::RVVBitsPerBlock / EltBits, EltTy);
}

/// Custom legalization of G_EXTRACT_SUBVECTOR.
///
/// Extracts at index 0, and extracts whose index decomposes entirely into a
/// subregister index, are left untouched. Extracts of s1-element vectors are
/// performed on zero-extended s8 vectors and converted back with a compare
/// against zero. Everything else is lowered to a G_VSLIDEDOWN_VL followed by
/// an index-0 G_EXTRACT_SUBVECTOR.
///
/// \p Helper is not used by this function. Every path returns true.
bool RISCVLegalizerInfo::legalizeExtractSubvector(MachineInstr &MI,
LegalizerHelper &Helper,
MachineIRBuilder &MIB) const {
GExtractSubvector &ES = cast<GExtractSubvector>(MI);

MachineRegisterInfo &MRI = *MIB.getMRI();

Register Dst = ES.getReg(0);
Register Src = ES.getSrcVec();
uint64_t Idx = ES.getIndexImm();

// With an index of 0 this is a cast-like subvector, which can be performed
// with subregister operations. The instruction is kept as is.
if (Idx == 0)
return true;

// LitTy is the type of the extracted (smaller) vector, BigTy the type of the
// vector it is extracted from.
LLT LitTy = MRI.getType(Dst);
LLT BigTy = MRI.getType(Src);

if (LitTy.getElementType() == LLT::scalar(1)) {
// We can't slide this mask vector down indexed by its i1 elements.
// This poses a problem when we wish to extract a scalable vector which
// can't be re-expressed as a larger type. Just choose the slow path and
// extend to a larger type, then convert back down to i1 by comparing the
// extracted i8 elements against zero.
LLT ExtBigTy = BigTy.changeElementType(LLT::scalar(8));
LLT ExtLitTy = LitTy.changeElementType(LLT::scalar(8));
auto BigZExt = MIB.buildZExt(ExtBigTy, Src);
auto ExtractZExt = MIB.buildExtractSubvector(ExtLitTy, BigZExt, Idx);
auto SplatZero = MIB.buildSplatVector(
ExtLitTy, MIB.buildConstant(ExtLitTy.getElementType(), 0));
MIB.buildICmp(CmpInst::Predicate::ICMP_NE, Dst, ExtractZExt, SplatZero);
MI.eraseFromParent();
return true;
}

// extract_subvector scales the index by vscale if the subvector is scalable,
// and decomposeSubvectorInsertExtractToSubRegs takes this into account.
const RISCVRegisterInfo *TRI = STI.getRegisterInfo();
MVT LitTyMVT = getMVTForLLT(LitTy);
auto Decompose =
RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
getMVTForLLT(BigTy), LitTyMVT, Idx, TRI);
// Decompose.first is the subregister index; Decompose.second is the part of
// the element index that is left over after the subregister is selected.
unsigned RemIdx = Decompose.second;

// If the Idx has been completely eliminated then this is a subvector extract
// which naturally aligns to a vector register. These can easily be handled
// using subregister manipulation.
if (RemIdx == 0)
return true;

// Else LitTy is M1 or smaller and may need to be slid down: if LitTy
// was > M1 then the index would need to be a multiple of VLMAX, and so would
// divide exactly.
assert(
RISCVVType::decodeVLMUL(RISCVTargetLowering::getLMUL(LitTyMVT)).second ||
RISCVTargetLowering::getLMUL(LitTyMVT) == RISCVII::VLMUL::LMUL_1);

// If the vector type is an LMUL-group type, extract a subvector equal to the
// nearest full vector register type.
LLT InterLitTy = BigTy;
Register Vec = Src;
if (TypeSize::isKnownGT(BigTy.getSizeInBits(),
getLMUL1Ty(BigTy).getSizeInBits())) {
// If BigTy has an LMUL > 1, then LitTy should have a smaller LMUL, and
// we should have successfully decomposed the extract into a subregister.
assert(Decompose.first != RISCV::NoSubRegister);
InterLitTy = getLMUL1Ty(BigTy);
// SDAG builds a TargetExtractSubreg. We cannot create a Copy with SubReg
// specified on the source Register (the equivalent) since generic virtual
// register does not allow subregister index.
Vec = MIB.buildExtractSubvector(InterLitTy, Src, Idx - RemIdx).getReg(0);
}

// Slide this vector register down by the desired number of elements in order
// to place the desired subvector starting at element 0.
const LLT XLenTy(STI.getXLenVT());
auto SlidedownAmt = MIB.buildVScale(XLenTy, RemIdx);
auto [Mask, VL] = buildDefaultVLOps(LitTy, MIB, MRI);
uint64_t Policy = RISCVII::TAIL_AGNOSTIC | RISCVII::MASK_AGNOSTIC;
// The first source operand (the passthru/merge value) is undef.
auto Slidedown = MIB.buildInstr(
RISCV::G_VSLIDEDOWN_VL, {InterLitTy},
{MIB.buildUndef(InterLitTy), Vec, SlidedownAmt, Mask, VL, Policy});

// Now the vector is in the right position, extract our final subvector. This
// should resolve to a COPY.
MIB.buildExtractSubvector(Dst, Slidedown, 0);

MI.eraseFromParent();
return true;
}

bool RISCVLegalizerInfo::legalizeCustom(
LegalizerHelper &Helper, MachineInstr &MI,
LostDebugLocObserver &LocObserver) const {
Expand Down Expand Up @@ -1001,6 +1125,8 @@ bool RISCVLegalizerInfo::legalizeCustom(
return legalizeExt(MI, MIRBuilder);
case TargetOpcode::G_SPLAT_VECTOR:
return legalizeSplatVector(MI, MIRBuilder);
case TargetOpcode::G_EXTRACT_SUBVECTOR:
return legalizeExtractSubvector(MI, Helper, MIRBuilder);
case TargetOpcode::G_LOAD:
case TargetOpcode::G_STORE:
return legalizeLoadStore(MI, Helper, MIRBuilder);
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ class RISCVLegalizerInfo : public LegalizerInfo {
bool legalizeVScale(MachineInstr &MI, MachineIRBuilder &MIB) const;
bool legalizeExt(MachineInstr &MI, MachineIRBuilder &MIRBuilder) const;
bool legalizeSplatVector(MachineInstr &MI, MachineIRBuilder &MIB) const;
/// Custom legalization hook for G_EXTRACT_SUBVECTOR, dispatched from
/// legalizeCustom.
bool legalizeExtractSubvector(MachineInstr &MI, LegalizerHelper &Helper,
MachineIRBuilder &MIB) const;
bool legalizeLoadStore(MachineInstr &MI, LegalizerHelper &Helper,
MachineIRBuilder &MIB) const;
};
Expand Down
10 changes: 10 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrGISel.td
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,13 @@ def G_SPLAT_VECTOR_SPLIT_I64_VL : RISCVGenericInstruction {
let InOperandList = (ins type0:$passthru, type1:$hi, type1:$lo, type2:$vl);
let hasSideEffects = false;
}

// Pseudo equivalent to a RISCVISD::VSLIDEDOWN_VL
// The result, $merge and $vec share one type (type0); $idx, $vl and $policy
// share another (type1); $mask has its own type (type2).
def G_VSLIDEDOWN_VL : RISCVGenericInstruction {
let OutOperandList = (outs type0:$dst);
let InOperandList = (ins type0:$merge, type0:$vec, type1:$idx, type2:$mask,
type1:$vl, type1:$policy);
let hasSideEffects = false;
}
// Declares G_VSLIDEDOWN_VL as the GlobalISel equivalent of the
// riscv_slidedown_vl node.
def : GINodeEquiv<G_VSLIDEDOWN_VL, riscv_slidedown_vl>;

Loading