AMDGPU/GlobalISel: RegBankLegalize rules for load #112882

Merged: 1 commit, Jan 24, 2025
288 changes: 284 additions & 4 deletions in llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp
@@ -50,6 +50,83 @@ void RegBankLegalizeHelper::findRuleAndApplyMapping(MachineInstr &MI) {
lower(MI, Mapping, WaterfallSgprs);
}

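// Split the load MI into a sequence of smaller loads whose result types are
// given by LLTBreakdown, building a G_PTR_ADD at the accumulated byte offset
// for every part after the first. If all parts have the same type, Dst is
// rebuilt with a single merge-like instruction; otherwise each part is first
// unmerged into MergeTy-sized pieces. For example, a v3s32 load split as
// {v2s32, s32} loads v2s32 at offset 0 and s32 at offset 8, unmerges the
// v2s32 into two s32, and builds the v3s32 Dst from the three s32 pieces.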
void RegBankLegalizeHelper::splitLoad(MachineInstr &MI,
ArrayRef<LLT> LLTBreakdown, LLT MergeTy) {
MachineFunction &MF = B.getMF();
assert(MI.getNumMemOperands() == 1);
MachineMemOperand &BaseMMO = **MI.memoperands_begin();
Register Dst = MI.getOperand(0).getReg();
const RegisterBank *DstRB = MRI.getRegBankOrNull(Dst);
Register Base = MI.getOperand(1).getReg();
LLT PtrTy = MRI.getType(Base);
const RegisterBank *PtrRB = MRI.getRegBankOrNull(Base);
LLT OffsetTy = LLT::scalar(PtrTy.getSizeInBits());
SmallVector<Register, 4> LoadPartRegs;

unsigned ByteOffset = 0;
for (LLT PartTy : LLTBreakdown) {
Register BasePlusOffset;
if (ByteOffset == 0) {
BasePlusOffset = Base;
} else {
auto Offset = B.buildConstant({PtrRB, OffsetTy}, ByteOffset);
BasePlusOffset = B.buildPtrAdd({PtrRB, PtrTy}, Base, Offset).getReg(0);
}
auto *OffsetMMO = MF.getMachineMemOperand(&BaseMMO, ByteOffset, PartTy);
auto LoadPart = B.buildLoad({DstRB, PartTy}, BasePlusOffset, *OffsetMMO);
LoadPartRegs.push_back(LoadPart.getReg(0));
ByteOffset += PartTy.getSizeInBytes();
}

if (!MergeTy.isValid()) {
// Loads are all of the same size: concat or merge them together.
B.buildMergeLikeInstr(Dst, LoadPartRegs);
} else {
// Loads are not all of the same size: unmerge them into smaller pieces
// of MergeTy, then merge the pieces into Dst.
SmallVector<Register, 4> MergeTyParts;
for (Register Reg : LoadPartRegs) {
if (MRI.getType(Reg) == MergeTy) {
MergeTyParts.push_back(Reg);
} else {
auto Unmerge = B.buildUnmerge({DstRB, MergeTy}, Reg);
for (unsigned i = 0; i < Unmerge->getNumOperands() - 1; ++i)
MergeTyParts.push_back(Unmerge.getReg(i));
}
}
B.buildMergeLikeInstr(Dst, MergeTyParts);
}
MI.eraseFromParent();
}

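// Replace the load MI with a single load of WideTy, then cut the result back
// down: a wide scalar is truncated to Dst, while a wide vector is unmerged
// into MergeTy pieces of which the first Dst-size worth are merged into Dst.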
void RegBankLegalizeHelper::widenLoad(MachineInstr &MI, LLT WideTy,
LLT MergeTy) {
MachineFunction &MF = B.getMF();
assert(MI.getNumMemOperands() == 1);
MachineMemOperand &BaseMMO = **MI.memoperands_begin();
Register Dst = MI.getOperand(0).getReg();
const RegisterBank *DstRB = MRI.getRegBankOrNull(Dst);
Register Base = MI.getOperand(1).getReg();

MachineMemOperand *WideMMO = MF.getMachineMemOperand(&BaseMMO, 0, WideTy);
auto WideLoad = B.buildLoad({DstRB, WideTy}, Base, *WideMMO);

if (WideTy.isScalar()) {
B.buildTrunc(Dst, WideLoad);
} else {
SmallVector<Register, 4> MergeTyParts;
auto Unmerge = B.buildUnmerge({DstRB, MergeTy}, WideLoad);

LLT DstTy = MRI.getType(Dst);
unsigned NumElts = DstTy.getSizeInBits() / MergeTy.getSizeInBits();
for (unsigned i = 0; i < NumElts; ++i) {
MergeTyParts.push_back(Unmerge.getReg(i));
}
B.buildMergeLikeInstr(Dst, MergeTyParts);
}
MI.eraseFromParent();
}

void RegBankLegalizeHelper::lower(MachineInstr &MI,
const RegBankLLTMapping &Mapping,
SmallSet<Register, 4> &WaterfallSgprs) {
@@ -128,6 +205,54 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI,
MI.eraseFromParent();
break;
}
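// Split loads that have no direct mapping: results wider than 128 bits are
// split evenly into 128-bit parts; 96-bit results are split into a 64-bit
// and a 32-bit part, e.g. v3s32 becomes {v2s32, s32} merged through s32.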
case SplitLoad: {
LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
unsigned Size = DstTy.getSizeInBits();
// Even split into 128-bit loads
if (Size > 128) {
LLT B128;
if (DstTy.isVector()) {
LLT EltTy = DstTy.getElementType();
B128 = LLT::fixed_vector(128 / EltTy.getSizeInBits(), EltTy);
} else {
B128 = LLT::scalar(128);
}
if (Size / 128 == 2)
splitLoad(MI, {B128, B128});
else if (Size / 128 == 4)
splitLoad(MI, {B128, B128, B128, B128});
else {
LLVM_DEBUG(dbgs() << "MI: "; MI.dump(););
llvm_unreachable("SplitLoad type not supported for MI");
}
}
// Split into 64- and 32-bit pieces
else if (DstTy == S96)
splitLoad(MI, {S64, S32}, S32);
else if (DstTy == V3S32)
splitLoad(MI, {V2S32, S32}, S32);
else if (DstTy == V6S16)
splitLoad(MI, {V4S16, V2S16}, V2S16);
else {
LLVM_DEBUG(dbgs() << "MI: "; MI.dump(););
llvm_unreachable("SplitLoad type not supported for MI");
}
break;
}
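// Widen 96-bit loads to 128 bits and shrink the result: an s96 load becomes
// an s128 load followed by a trunc to s96, and a v3s32 load becomes a v4s32
// load whose first three elements are merged into the v3s32 dst.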
case WidenLoad: {
LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
if (DstTy == S96)
widenLoad(MI, S128);
else if (DstTy == V3S32)
widenLoad(MI, V4S32, S32);
else if (DstTy == V6S16)
widenLoad(MI, V8S16, V2S16);
else {
LLVM_DEBUG(dbgs() << "MI: "; MI.dump(););
llvm_unreachable("WidenLoad type not supported for MI");
}
break;
}
}

// TODO: executeInWaterfallLoop(... WaterfallSgprs)
@@ -151,12 +276,73 @@ LLT RegBankLegalizeHelper::getTyFromID(RegBankLLTMappingApplyID ID) {
case Sgpr64:
case Vgpr64:
return LLT::scalar(64);
case SgprP1:
case VgprP1:
return LLT::pointer(1, 64);
case SgprP3:
case VgprP3:
return LLT::pointer(3, 32);
case SgprP4:
case VgprP4:
return LLT::pointer(4, 64);
case SgprP5:
case VgprP5:
return LLT::pointer(5, 32);
case SgprV4S32:
case VgprV4S32:
case UniInVgprV4S32:
return LLT::fixed_vector(4, 32);
default:
return LLT();
}
}

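// B-type IDs (B32...B512) cover any scalar, vector or pointer type with the
// given total size. Return Ty itself when it belongs to ID's size bucket,
// or an invalid LLT otherwise.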
LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMappingApplyID ID, LLT Ty) {
switch (ID) {
case SgprB32:
case VgprB32:
case UniInVgprB32:
if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
Ty == LLT::pointer(3, 32) || Ty == LLT::pointer(5, 32) ||
Ty == LLT::pointer(6, 32))
return Ty;
return LLT();
case SgprB64:
case VgprB64:
case UniInVgprB64:
if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
Ty == LLT::fixed_vector(4, 16) || Ty == LLT::pointer(0, 64) ||
Ty == LLT::pointer(1, 64) || Ty == LLT::pointer(4, 64))
return Ty;
return LLT();
case SgprB96:
case VgprB96:
case UniInVgprB96:
if (Ty == LLT::scalar(96) || Ty == LLT::fixed_vector(3, 32) ||
Ty == LLT::fixed_vector(6, 16))
return Ty;
return LLT();
case SgprB128:
case VgprB128:
case UniInVgprB128:
if (Ty == LLT::scalar(128) || Ty == LLT::fixed_vector(4, 32) ||
Ty == LLT::fixed_vector(2, 64))
return Ty;
return LLT();
case SgprB256:
case VgprB256:
case UniInVgprB256:
if (Ty == LLT::scalar(256) || Ty == LLT::fixed_vector(8, 32) ||
Ty == LLT::fixed_vector(4, 64) || Ty == LLT::fixed_vector(16, 16))
return Ty;
return LLT();
case SgprB512:
case VgprB512:
case UniInVgprB512:
if (Ty == LLT::scalar(512) || Ty == LLT::fixed_vector(16, 32) ||
Ty == LLT::fixed_vector(8, 64))
return Ty;
return LLT();
default:
return LLT();
}
@@ -170,10 +356,26 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
case Sgpr16:
case Sgpr32:
case Sgpr64:
case SgprP1:
case SgprP3:
case SgprP4:
case SgprP5:
case SgprV4S32:
case SgprB32:
case SgprB64:
case SgprB96:
case SgprB128:
case SgprB256:
case SgprB512:
case UniInVcc:
case UniInVgprS32:
case UniInVgprV4S32:
case UniInVgprB32:
case UniInVgprB64:
case UniInVgprB96:
case UniInVgprB128:
case UniInVgprB256:
case UniInVgprB512:
case Sgpr32Trunc:
case Sgpr32AExt:
case Sgpr32AExtBoolInReg:
@@ -182,7 +384,16 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
case Vgpr32:
case Vgpr64:
case VgprP1:
case VgprP3:
case VgprP4:
case VgprP5:
case VgprV4S32:
case VgprB32:
case VgprB64:
case VgprB96:
case VgprB128:
case VgprB256:
case VgprB512:
return VgprRB;
default:
return nullptr;
@@ -207,16 +418,40 @@ void RegBankLegalizeHelper::applyMappingDst(
case Sgpr16:
case Sgpr32:
case Sgpr64:
case SgprP1:
case SgprP3:
case SgprP4:
case SgprP5:
case SgprV4S32:
case Vgpr32:
case Vgpr64:
case VgprP1:
case VgprP3:
case VgprP4:
case VgprP5:
case VgprV4S32: {
assert(Ty == getTyFromID(MethodIDs[OpIdx]));
assert(RB == getRegBankFromID(MethodIDs[OpIdx]));
break;
}
// sgpr and vgpr B-types
case SgprB32:
case SgprB64:
case SgprB96:
case SgprB128:
case SgprB256:
case SgprB512:
case VgprB32:
case VgprB64:
case VgprB96:
case VgprB128:
case VgprB256:
case VgprB512: {
assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
assert(RB == getRegBankFromID(MethodIDs[OpIdx]));
break;
}
// uniform in vcc/vgpr: scalars, vectors and B-types
case UniInVcc: {
assert(Ty == S1);
assert(RB == SgprRB);
@@ -236,6 +471,19 @@ void RegBankLegalizeHelper::applyMappingDst(
buildReadAnyLane(B, Reg, NewVgprDst, RBI);
break;
}
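// Uniform B-type results are computed in a VGPR and copied back to the
// original SGPR dst with readanylane.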
case UniInVgprB32:
case UniInVgprB64:
case UniInVgprB96:
case UniInVgprB128:
case UniInVgprB256:
case UniInVgprB512: {
assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
assert(RB == SgprRB);
Register NewVgprDst = MRI.createVirtualRegister({VgprRB, Ty});
Op.setReg(NewVgprDst);
AMDGPU::buildReadAnyLane(B, Reg, NewVgprDst, RBI);
break;
}
// sgpr trunc
case Sgpr32Trunc: {
assert(Ty.getSizeInBits() < 32);
@@ -284,15 +532,33 @@ void RegBankLegalizeHelper::applyMappingSrc(
case Sgpr16:
case Sgpr32:
case Sgpr64:
case SgprP1:
case SgprP3:
case SgprP4:
case SgprP5:
case SgprV4S32: {
assert(Ty == getTyFromID(MethodIDs[i]));
assert(RB == getRegBankFromID(MethodIDs[i]));
break;
}
// sgpr B-types
case SgprB32:
case SgprB64:
case SgprB96:
case SgprB128:
case SgprB256:
case SgprB512: {
assert(Ty == getBTyFromID(MethodIDs[i], Ty));
assert(RB == getRegBankFromID(MethodIDs[i]));
break;
}
// vgpr scalars, pointers and vectors
case Vgpr32:
case Vgpr64:
case VgprP1:
case VgprP3:
case VgprP4:
case VgprP5:
case VgprV4S32: {
assert(Ty == getTyFromID(MethodIDs[i]));
if (RB != VgprRB) {
@@ -301,6 +567,20 @@
}
break;
}
// vgpr B-types
case VgprB32:
case VgprB64:
case VgprB96:
case VgprB128:
case VgprB256:
case VgprB512: {
assert(Ty == getBTyFromID(MethodIDs[i], Ty));
if (RB != VgprRB) {
auto CopyToVgpr = B.buildCopy({VgprRB, Ty}, Reg);
Op.setReg(CopyToVgpr.getReg(0));
}
break;
}
// sgpr and vgpr scalars with extend
case Sgpr32AExt: {
// Note: this ext allows S1, and it is meant to be combined away.
@@ -373,7 +653,7 @@ void RegBankLegalizeHelper::applyMappingPHI(MachineInstr &MI) {
// We accept all types that can fit in some register class.
// Uniform G_PHIs have all sgpr registers.
// Divergent G_PHIs have vgpr dst but inputs can be sgpr or vgpr.
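// p4 is the 64-bit constant-address-space pointer; like s32 it already fits
// a register class and needs no lowering here.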
if (Ty == LLT::scalar(32) || Ty == LLT::pointer(4, 64)) {
return;
}
