Skip to content

[GlobalISel] Refactor extractParts() #75223

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 0 additions & 16 deletions llvm/include/llvm/CodeGen/GlobalISel/LegalizerHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,22 +187,6 @@ class LegalizerHelper {
LegalizeResult widenScalarMulo(MachineInstr &MI, unsigned TypeIdx,
LLT WideTy);

/// Helper function to split a wide generic register into bitwise blocks with
/// the given Type (which implies the number of blocks needed). The generic
/// registers created are appended to \p VRegs, starting at bit 0 of \p Reg.
///
/// \param Reg      Register to split.
/// \param Ty       Type given to every newly created part.
/// \param NumParts Number of parts to create.
/// \param VRegs    Receives the newly created registers.
void extractParts(Register Reg, LLT Ty, int NumParts,
SmallVectorImpl<Register> &VRegs);

/// Version which handles irregular splits.
///
/// \p Reg (of type \p RegTy) is split into as many \p MainTy sized pieces as
/// fit; those are appended to \p VRegs. Whatever does not fit is appended to
/// \p LeftoverVRegs and its type is written to \p LeftoverTy. When \p RegTy
/// divides evenly into \p MainTy pieces, \p LeftoverTy and \p LeftoverVRegs
/// are left untouched.
///
/// \p LeftoverTy is a pure out argument and must be invalid on entry.
bool extractParts(Register Reg, LLT RegTy, LLT MainTy,
LLT &LeftoverTy,
SmallVectorImpl<Register> &VRegs,
SmallVectorImpl<Register> &LeftoverVRegs);

/// Version which handles irregular sub-vector splits.
///
/// \p Reg must have a vector type. It is split into pieces of \p NumElts
/// elements each (plain scalars when \p NumElts is 1), appended to \p VRegs.
/// If the element count of \p Reg is not a multiple of \p NumElts, the last
/// register appended holds the remaining element(s): a scalar when exactly
/// one element is left over, otherwise a shorter vector.
void extractVectorParts(Register Reg, unsigned NumElts,
                        SmallVectorImpl<Register> &VRegs);

/// Helper function to build a wide generic register \p DstReg of type \p
/// RegTy from smaller parts. This will produce a G_MERGE_VALUES,
/// G_BUILD_VECTOR, G_CONCAT_VECTORS, or sequence of G_INSERT as appropriate
Expand Down
19 changes: 19 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class BlockFrequencyInfo;
class GISelKnownBits;
class MachineFunction;
class MachineInstr;
class MachineIRBuilder;
class MachineOperand;
class MachineOptimizationRemarkEmitter;
class MachineOptimizationRemarkMissed;
Expand Down Expand Up @@ -247,6 +248,24 @@ MachineInstr *getDefIgnoringCopies(Register Reg,
/// Also walks through hints such as G_ASSERT_ZEXT.
Register getSrcRegIgnoringCopies(Register Reg, const MachineRegisterInfo &MRI);

/// Helper function to split a wide generic register into bitwise blocks with
/// the given Type (which implies the number of blocks needed). The generic
/// registers created are appended to \p VRegs, starting at bit 0 of \p Reg.
///
/// \param Reg        Register to split.
/// \param Ty         Type of each part.
/// \param NumParts   Number of parts.
/// \param VRegs      Receives the registers created for the parts.
/// \param MIRBuilder Builder supplied by the caller.
/// \param MRI        Register info supplied by the caller.
// NOTE(review): presumably MIRBuilder emits the split and MRI creates the new
// virtual registers, as the former LegalizerHelper member did - confirm
// against the definition.
void extractParts(Register Reg, LLT Ty, int NumParts,
SmallVectorImpl<Register> &VRegs,
MachineIRBuilder &MIRBuilder, MachineRegisterInfo &MRI);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: MachineIRBuilder already provides getMRI(), so there is no need to pass MRI separately.


/// Version which handles irregular splits.
///
/// \param Reg           Register to split.
/// \param RegTy         Type of \p Reg.
/// \param MainTy        Type requested for the regular pieces.
/// \param LeftoverTy    Out argument for the type of the leftover piece(s).
/// \param VRegs         Receives the \p MainTy pieces.
/// \param LeftoverVRegs Receives the leftover piece(s).
/// \param MIRBuilder    Builder supplied by the caller.
/// \param MRI           Register info supplied by the caller.
// NOTE(review): presumably keeps the contract of the former LegalizerHelper
// member (LeftoverTy invalid on entry, leftover outputs untouched on an even
// split) - confirm against the definition.
bool extractParts(Register Reg, LLT RegTy, LLT MainTy, LLT &LeftoverTy,
SmallVectorImpl<Register> &VRegs,
SmallVectorImpl<Register> &LeftoverVRegs,
MachineIRBuilder &MIRBuilder, MachineRegisterInfo &MRI);

/// Version which handles irregular sub-vector splits.
///
/// \param Reg        Register to split.
/// \param NumElts    Number of elements requested per piece.
/// \param VRegs      Receives the resulting registers.
/// \param MIRBuilder Builder supplied by the caller.
/// \param MRI        Register info supplied by the caller.
// NOTE(review): presumably \p Reg must be a vector and any leftover elements
// end up in the last register appended, as in the former LegalizerHelper
// member - confirm against the definition.
void extractVectorParts(Register Reg, unsigned NumElts,
SmallVectorImpl<Register> &VRegs,
MachineIRBuilder &MIRBuilder, MachineRegisterInfo &MRI);

// Templated variant of getOpcodeDef returning a MachineInstr derived T.
/// See if Reg is defined by a single def instruction of type T
/// Also try to do trivial folding if it's a COPY with
Expand Down
152 changes: 33 additions & 119 deletions llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,100 +158,6 @@ LegalizerHelper::legalizeInstrStep(MachineInstr &MI,
}
}

/// Split \p Reg into \p NumParts registers of type \p Ty.
///
/// A fresh generic virtual register is created for every part and appended to
/// \p VRegs; a single unmerge of \p Reg is then built with the whole of
/// \p VRegs as its destinations.
void LegalizerHelper::extractParts(Register Reg, LLT Ty, int NumParts,
                                   SmallVectorImpl<Register> &VRegs) {
  for (int Created = 0; Created < NumParts; ++Created) {
    Register Part = MRI.createGenericVirtualRegister(Ty);
    VRegs.push_back(Part);
  }

  // Every entry of VRegs (not just the ones added above) becomes a result of
  // the unmerge.
  MIRBuilder.buildUnmerge(VRegs, Reg);
}

/// Split \p Reg (of type \p RegTy) into \p MainTy pieces plus a leftover.
///
/// The \p MainTy pieces are appended to \p VRegs. If \p RegTy is not an exact
/// multiple of \p MainTy, the remainder is appended to \p LeftoverRegs and its
/// type is stored in \p LeftoverTy; otherwise both are left untouched.
/// \p LeftoverTy must be invalid on entry. Always returns true.
bool LegalizerHelper::extractParts(Register Reg, LLT RegTy, LLT MainTy,
                                   LLT &LeftoverTy,
                                   SmallVectorImpl<Register> &VRegs,
                                   SmallVectorImpl<Register> &LeftoverRegs) {
  assert(!LeftoverTy.isValid() && "this is an out argument");

  const unsigned TotalBits = RegTy.getSizeInBits();
  const unsigned PieceBits = MainTy.getSizeInBits();
  const unsigned FullPieces = TotalBits / PieceBits;
  const unsigned TailBits = TotalBits % PieceBits;

  // Even split: a single unmerge produces every piece.
  if (TailBits == 0) {
    for (unsigned Idx = 0; Idx != FullPieces; ++Idx) {
      Register Piece = MRI.createGenericVirtualRegister(MainTy);
      VRegs.push_back(Piece);
    }
    MIRBuilder.buildUnmerge(VRegs, Reg);
    return true;
  }

  // Uneven vector split: delegate to extractVectorParts, whose final result
  // is treated as the leftover.
  if (MainTy.isVector()) {
    SmallVector<Register, 8> Pieces;
    extractVectorParts(Reg, MainTy.getNumElements(), Pieces);

    const auto LastIdx = Pieces.size() - 1;
    for (unsigned Idx = 0; Idx < LastIdx; ++Idx)
      VRegs.push_back(Pieces[Idx]);
    LeftoverRegs.push_back(Pieces[LastIdx]);

    LeftoverTy = MRI.getType(LeftoverRegs[0]);
    return true;
  }

  // Uneven scalar split: extract each piece at its bit offset.
  LeftoverTy = LLT::scalar(TailBits);

  // Creates a register of PieceTy, records it in Dest, then builds the
  // extract of Reg at BitOffset into it.
  auto ExtractInto = [&](LLT PieceTy, unsigned BitOffset,
                         SmallVectorImpl<Register> &Dest) {
    Register Piece = MRI.createGenericVirtualRegister(PieceTy);
    Dest.push_back(Piece);
    MIRBuilder.buildExtract(Piece, Reg, BitOffset);
  };

  for (unsigned Idx = 0; Idx != FullPieces; ++Idx)
    ExtractInto(MainTy, PieceBits * Idx, VRegs);

  for (unsigned BitPos = PieceBits * FullPieces; BitPos < TotalBits;
       BitPos += TailBits)
    ExtractInto(LeftoverTy, BitPos, LeftoverRegs);

  return true;
}

/// Split the vector register \p Reg into pieces of \p NumElts elements.
///
/// Pieces are appended to \p VRegs (scalars when \p NumElts is 1). If the
/// element count of \p Reg is not a multiple of \p NumElts, the last register
/// appended holds the remaining element(s): the bare element when one is
/// left, otherwise a vector of the leftover elements.
void LegalizerHelper::extractVectorParts(Register Reg, unsigned NumElts,
                                         SmallVectorImpl<Register> &VRegs) {
  LLT SrcTy = MRI.getType(Reg);
  assert(SrcTy.isVector() && "Expected a vector type");

  LLT ScalarTy = SrcTy.getElementType();
  LLT PieceTy = ScalarTy;
  if (NumElts != 1)
    PieceTy = LLT::fixed_vector(NumElts, ScalarTy);

  const unsigned SrcNumElts = SrcTy.getNumElements();
  const unsigned WholePieces = SrcNumElts / NumElts;
  const unsigned TailElts = SrcNumElts % NumElts;

  // Nothing left over: a plain regular split is enough.
  if (TailElts == 0) {
    extractParts(Reg, PieceTy, WholePieces, VRegs);
    return;
  }

  // Otherwise unmerge down to individual elements first (this gives the
  // artifact combiner direct access to every element) and rebuild the
  // requested pieces from them.
  SmallVector<Register, 8> Scalars;
  extractParts(Reg, ScalarTy, SrcNumElts, Scalars);

  unsigned Cursor = 0;
  for (unsigned Piece = 0; Piece != WholePieces; ++Piece) {
    ArrayRef<Register> Group(&Scalars[Cursor], NumElts);
    Register Merged = MIRBuilder.buildMergeLikeInstr(PieceTy, Group).getReg(0);
    VRegs.push_back(Merged);
    Cursor += NumElts;
  }

  // A single trailing element is handed back as-is.
  if (TailElts == 1) {
    VRegs.push_back(Scalars[Cursor]);
    return;
  }

  // Several trailing elements are merged into a shorter vector.
  LLT TailTy = LLT::fixed_vector(TailElts, ScalarTy);
  ArrayRef<Register> TailGroup(&Scalars[Cursor], TailElts);
  VRegs.push_back(MIRBuilder.buildMergeLikeInstr(TailTy, TailGroup).getReg(0));
}

void LegalizerHelper::insertParts(Register DstReg,
LLT ResultTy, LLT PartTy,
ArrayRef<Register> PartRegs,
Expand Down Expand Up @@ -293,7 +199,8 @@ void LegalizerHelper::appendVectorElts(SmallVectorImpl<Register> &Elts,
Register Reg) {
LLT Ty = MRI.getType(Reg);
SmallVector<Register, 8> RegElts;
extractParts(Reg, Ty.getScalarType(), Ty.getNumElements(), RegElts);
extractParts(Reg, Ty.getScalarType(), Ty.getNumElements(), RegElts,
MIRBuilder, MRI);
Elts.append(RegElts);
}

Expand Down Expand Up @@ -1542,7 +1449,7 @@ LegalizerHelper::LegalizeResult LegalizerHelper::narrowScalar(MachineInstr &MI,
MachineBasicBlock &OpMBB = *MI.getOperand(i + 1).getMBB();
MIRBuilder.setInsertPt(OpMBB, OpMBB.getFirstTerminatorForward());
extractParts(MI.getOperand(i).getReg(), NarrowTy, NumParts,
SrcRegs[i / 2]);
SrcRegs[i / 2], MIRBuilder, MRI);
}
MachineBasicBlock &MBB = *MI.getParent();
MIRBuilder.setInsertPt(MBB, MI);
Expand Down Expand Up @@ -1584,13 +1491,13 @@ LegalizerHelper::LegalizeResult LegalizerHelper::narrowScalar(MachineInstr &MI,
LLT LeftoverTy; // Example: s88 -> s64 (NarrowTy) + s24 (leftover)
SmallVector<Register, 4> LHSPartRegs, LHSLeftoverRegs;
if (!extractParts(LHS, SrcTy, NarrowTy, LeftoverTy, LHSPartRegs,
LHSLeftoverRegs))
LHSLeftoverRegs, MIRBuilder, MRI))
return UnableToLegalize;

LLT Unused; // Matches LeftoverTy; G_ICMP LHS and RHS are the same type.
SmallVector<Register, 4> RHSPartRegs, RHSLeftoverRegs;
if (!extractParts(MI.getOperand(3).getReg(), SrcTy, NarrowTy, Unused,
RHSPartRegs, RHSLeftoverRegs))
RHSPartRegs, RHSLeftoverRegs, MIRBuilder, MRI))
return UnableToLegalize;

// We now have the LHS and RHS of the compare split into narrow-type
Expand Down Expand Up @@ -1744,7 +1651,8 @@ LegalizerHelper::LegalizeResult LegalizerHelper::narrowScalar(MachineInstr &MI,
Observer.changingInstr(MI);
SmallVector<Register, 2> SrcRegs, DstRegs;
unsigned NumParts = SizeOp0 / NarrowSize;
extractParts(MI.getOperand(1).getReg(), NarrowTy, NumParts, SrcRegs);
extractParts(MI.getOperand(1).getReg(), NarrowTy, NumParts, SrcRegs,
MIRBuilder, MRI);

for (unsigned i = 0; i < NumParts; ++i) {
auto DstPart = MIRBuilder.buildInstr(MI.getOpcode(), {NarrowTy},
Expand Down Expand Up @@ -4194,7 +4102,8 @@ LegalizerHelper::fewerElementsVectorMultiEltType(
MI.getOperand(UseIdx));
} else {
SmallVector<Register, 8> SplitPieces;
extractVectorParts(MI.getReg(UseIdx), NumElts, SplitPieces);
extractVectorParts(MI.getReg(UseIdx), NumElts, SplitPieces, MIRBuilder,
MRI);
for (auto Reg : SplitPieces)
InputOpsPieces[UseNo].push_back(Reg);
}
Expand Down Expand Up @@ -4250,7 +4159,8 @@ LegalizerHelper::fewerElementsVectorPhi(GenericMachineInstr &MI,
UseIdx += 2, ++UseNo) {
MachineBasicBlock &OpMBB = *MI.getOperand(UseIdx + 1).getMBB();
MIRBuilder.setInsertPt(OpMBB, OpMBB.getFirstTerminatorForward());
extractVectorParts(MI.getReg(UseIdx), NumElts, InputOpsPieces[UseNo]);
extractVectorParts(MI.getReg(UseIdx), NumElts, InputOpsPieces[UseNo],
MIRBuilder, MRI);
}

// Build PHIs with fewer elements.
Expand Down Expand Up @@ -4519,7 +4429,7 @@ LegalizerHelper::reduceLoadStoreWidth(GLoadStore &LdStMI, unsigned TypeIdx,
std::tie(NumParts, NumLeftover) = getNarrowTypeBreakDown(ValTy, NarrowTy, LeftoverTy);
} else {
if (extractParts(ValReg, ValTy, NarrowTy, LeftoverTy, NarrowRegs,
NarrowLeftoverRegs)) {
NarrowLeftoverRegs, MIRBuilder, MRI)) {
NumParts = NarrowRegs.size();
NumLeftover = NarrowLeftoverRegs.size();
}
Expand Down Expand Up @@ -4765,8 +4675,8 @@ LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorShuffle(
unsigned NewElts = NarrowTy.getNumElements();

SmallVector<Register> SplitSrc1Regs, SplitSrc2Regs;
extractParts(Src1Reg, NarrowTy, 2, SplitSrc1Regs);
extractParts(Src2Reg, NarrowTy, 2, SplitSrc2Regs);
extractParts(Src1Reg, NarrowTy, 2, SplitSrc1Regs, MIRBuilder, MRI);
extractParts(Src2Reg, NarrowTy, 2, SplitSrc2Regs, MIRBuilder, MRI);
Register Inputs[4] = {SplitSrc1Regs[0], SplitSrc1Regs[1], SplitSrc2Regs[0],
SplitSrc2Regs[1]};

Expand Down Expand Up @@ -4900,7 +4810,7 @@ LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorReductions(
NarrowTy.isVector() ? SrcTy.getNumElements() / NarrowTy.getNumElements()
: SrcTy.getNumElements();

extractParts(SrcReg, NarrowTy, NumParts, SplitSrcs);
extractParts(SrcReg, NarrowTy, NumParts, SplitSrcs, MIRBuilder, MRI);
if (NarrowTy.isScalar()) {
if (DstTy != NarrowTy)
return UnableToLegalize; // FIXME: handle implicit extensions.
Expand Down Expand Up @@ -4983,7 +4893,7 @@ LegalizerHelper::fewerElementsVectorSeqReductions(MachineInstr &MI,

SmallVector<Register> SplitSrcs;
unsigned NumParts = SrcTy.getNumElements();
extractParts(SrcReg, NarrowTy, NumParts, SplitSrcs);
extractParts(SrcReg, NarrowTy, NumParts, SplitSrcs, MIRBuilder, MRI);
Register Acc = ScalarReg;
for (unsigned i = 0; i < NumParts; i++)
Acc = MIRBuilder.buildInstr(ScalarOpc, {NarrowTy}, {Acc, SplitSrcs[i]})
Expand All @@ -5001,7 +4911,8 @@ LegalizerHelper::tryNarrowPow2Reduction(MachineInstr &MI, Register SrcReg,
SmallVector<Register> SplitSrcs;
// Split the sources into NarrowTy size pieces.
extractParts(SrcReg, NarrowTy,
SrcTy.getNumElements() / NarrowTy.getNumElements(), SplitSrcs);
SrcTy.getNumElements() / NarrowTy.getNumElements(), SplitSrcs,
MIRBuilder, MRI);
// We're going to do a tree reduction using vector operations until we have
// one NarrowTy size value left.
while (SplitSrcs.size() > 1) {
Expand Down Expand Up @@ -5640,8 +5551,10 @@ LegalizerHelper::narrowScalarAddSub(MachineInstr &MI, unsigned TypeIdx,
LLT RegTy = MRI.getType(MI.getOperand(0).getReg());
LLT LeftoverTy, DummyTy;
SmallVector<Register, 2> Src1Regs, Src2Regs, Src1Left, Src2Left, DstRegs;
extractParts(Src1, RegTy, NarrowTy, LeftoverTy, Src1Regs, Src1Left);
extractParts(Src2, RegTy, NarrowTy, DummyTy, Src2Regs, Src2Left);
extractParts(Src1, RegTy, NarrowTy, LeftoverTy, Src1Regs, Src1Left,
MIRBuilder, MRI);
extractParts(Src2, RegTy, NarrowTy, DummyTy, Src2Regs, Src2Left, MIRBuilder,
MRI);

int NarrowParts = Src1Regs.size();
for (int I = 0, E = Src1Left.size(); I != E; ++I) {
Expand Down Expand Up @@ -5699,8 +5612,8 @@ LegalizerHelper::narrowScalarMul(MachineInstr &MI, LLT NarrowTy) {

SmallVector<Register, 2> Src1Parts, Src2Parts;
SmallVector<Register, 2> DstTmpRegs(DstTmpParts);
extractParts(Src1, NarrowTy, NumParts, Src1Parts);
extractParts(Src2, NarrowTy, NumParts, Src2Parts);
extractParts(Src1, NarrowTy, NumParts, Src1Parts, MIRBuilder, MRI);
extractParts(Src2, NarrowTy, NumParts, Src2Parts, MIRBuilder, MRI);
multiplyRegisters(DstTmpRegs, Src1Parts, Src2Parts, NarrowTy);

// Take only high half of registers if this is high mul.
Expand Down Expand Up @@ -5752,7 +5665,8 @@ LegalizerHelper::narrowScalarExtract(MachineInstr &MI, unsigned TypeIdx,

SmallVector<Register, 2> SrcRegs, DstRegs;
SmallVector<uint64_t, 2> Indexes;
extractParts(MI.getOperand(1).getReg(), NarrowTy, NumParts, SrcRegs);
extractParts(MI.getOperand(1).getReg(), NarrowTy, NumParts, SrcRegs,
MIRBuilder, MRI);

Register OpReg = MI.getOperand(0).getReg();
uint64_t OpStart = MI.getOperand(2).getImm();
Expand Down Expand Up @@ -5814,7 +5728,7 @@ LegalizerHelper::narrowScalarInsert(MachineInstr &MI, unsigned TypeIdx,
LLT RegTy = MRI.getType(MI.getOperand(0).getReg());
LLT LeftoverTy;
extractParts(MI.getOperand(1).getReg(), RegTy, NarrowTy, LeftoverTy, SrcRegs,
LeftoverRegs);
LeftoverRegs, MIRBuilder, MRI);

for (Register Reg : LeftoverRegs)
SrcRegs.push_back(Reg);
Expand Down Expand Up @@ -5899,12 +5813,12 @@ LegalizerHelper::narrowScalarBasic(MachineInstr &MI, unsigned TypeIdx,
SmallVector<Register, 4> Src1Regs, Src1LeftoverRegs;
LLT LeftoverTy;
if (!extractParts(MI.getOperand(1).getReg(), DstTy, NarrowTy, LeftoverTy,
Src0Regs, Src0LeftoverRegs))
Src0Regs, Src0LeftoverRegs, MIRBuilder, MRI))
return UnableToLegalize;

LLT Unused;
if (!extractParts(MI.getOperand(2).getReg(), DstTy, NarrowTy, Unused,
Src1Regs, Src1LeftoverRegs))
Src1Regs, Src1LeftoverRegs, MIRBuilder, MRI))
llvm_unreachable("inconsistent extractParts result");

for (unsigned I = 0, E = Src1Regs.size(); I != E; ++I) {
Expand Down Expand Up @@ -5967,12 +5881,12 @@ LegalizerHelper::narrowScalarSelect(MachineInstr &MI, unsigned TypeIdx,
SmallVector<Register, 4> Src2Regs, Src2LeftoverRegs;
LLT LeftoverTy;
if (!extractParts(MI.getOperand(2).getReg(), DstTy, NarrowTy, LeftoverTy,
Src1Regs, Src1LeftoverRegs))
Src1Regs, Src1LeftoverRegs, MIRBuilder, MRI))
return UnableToLegalize;

LLT Unused;
if (!extractParts(MI.getOperand(3).getReg(), DstTy, NarrowTy, Unused,
Src2Regs, Src2LeftoverRegs))
Src2Regs, Src2LeftoverRegs, MIRBuilder, MRI))
llvm_unreachable("inconsistent extractParts result");

for (unsigned I = 0, E = Src1Regs.size(); I != E; ++I) {
Expand Down Expand Up @@ -6468,7 +6382,7 @@ LegalizerHelper::LegalizeResult LegalizerHelper::lowerTRUNC(MachineInstr &MI) {

// First, split the source into two smaller vectors.
SmallVector<Register, 2> SplitSrcs;
extractParts(SrcReg, SplitSrcTy, 2, SplitSrcs);
extractParts(SrcReg, SplitSrcTy, 2, SplitSrcs, MIRBuilder, MRI);

// Truncate the splits into intermediate narrower elements.
LLT InterTy;
Expand Down Expand Up @@ -7208,7 +7122,7 @@ LegalizerHelper::lowerExtractInsertVectorElt(MachineInstr &MI) {
int64_t IdxVal;
if (mi_match(Idx, MRI, m_ICst(IdxVal)) && IdxVal <= NumElts) {
SmallVector<Register, 8> SrcRegs;
extractParts(SrcVec, EltTy, NumElts, SrcRegs);
extractParts(SrcVec, EltTy, NumElts, SrcRegs, MIRBuilder, MRI);

if (InsertVal) {
SrcRegs[IdxVal] = MI.getOperand(2).getReg();
Expand Down
Loading