AMDGPU/GlobalISel: RegBankLegalize rules for load #112882
Conversation
@llvm/pr-subscribers-llvm-globalisel @llvm/pr-subscribers-backend-amdgpu

Author: Petar Avramovic (petar-avramovic)

Changes: Add IDs for bit widths that cover multiple LLTs: B32, B64, etc.

Patch is 81.23 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/112882.diff

6 Files Affected:
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.cpp b/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.cpp
index a0f6ecedab7a83..f58f0a315096d2 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.cpp
@@ -37,6 +37,97 @@ bool RegBankLegalizeHelper::findRuleAndApplyMapping(MachineInstr &MI) {
return true;
}
+void RegBankLegalizeHelper::splitLoad(MachineInstr &MI,
+ ArrayRef<LLT> LLTBreakdown, LLT MergeTy) {
+ MachineFunction &MF = B.getMF();
+ assert(MI.getNumMemOperands() == 1);
+ MachineMemOperand &BaseMMO = **MI.memoperands_begin();
+ Register Dst = MI.getOperand(0).getReg();
+ const RegisterBank *DstRB = MRI.getRegBankOrNull(Dst);
+ Register BasePtrReg = MI.getOperand(1).getReg();
+ LLT PtrTy = MRI.getType(BasePtrReg);
+ const RegisterBank *PtrRB = MRI.getRegBankOrNull(BasePtrReg);
+ LLT OffsetTy = LLT::scalar(PtrTy.getSizeInBits());
+ SmallVector<Register, 4> LoadPartRegs;
+
+ unsigned ByteOffset = 0;
+ for (LLT PartTy : LLTBreakdown) {
+ Register BasePtrPlusOffsetReg;
+ if (ByteOffset == 0) {
+ BasePtrPlusOffsetReg = BasePtrReg;
+ } else {
+ BasePtrPlusOffsetReg = MRI.createVirtualRegister({PtrRB, PtrTy});
+ Register OffsetReg = MRI.createVirtualRegister({PtrRB, OffsetTy});
+ B.buildConstant(OffsetReg, ByteOffset);
+ B.buildPtrAdd(BasePtrPlusOffsetReg, BasePtrReg, OffsetReg);
+ }
+ MachineMemOperand *BasePtrPlusOffsetMMO =
+ MF.getMachineMemOperand(&BaseMMO, ByteOffset, PartTy);
+ Register PartLoad = MRI.createVirtualRegister({DstRB, PartTy});
+ B.buildLoad(PartLoad, BasePtrPlusOffsetReg, *BasePtrPlusOffsetMMO);
+ LoadPartRegs.push_back(PartLoad);
+ ByteOffset += PartTy.getSizeInBytes();
+ }
+
+ if (!MergeTy.isValid()) {
+ // Loads are of same size, concat or merge them together.
+ B.buildMergeLikeInstr(Dst, LoadPartRegs);
+ } else {
+ // Load(s) are not all of same size, need to unmerge them to smaller pieces
+ // of MergeTy type, then merge them all together in Dst.
+ SmallVector<Register, 4> MergeTyParts;
+ for (Register Reg : LoadPartRegs) {
+ if (MRI.getType(Reg) == MergeTy) {
+ MergeTyParts.push_back(Reg);
+ } else {
+ auto Unmerge = B.buildUnmerge(MergeTy, Reg);
+ for (unsigned i = 0; i < Unmerge->getNumOperands() - 1; ++i) {
+ Register UnmergeReg = Unmerge->getOperand(i).getReg();
+ MRI.setRegBank(UnmergeReg, *DstRB);
+ MergeTyParts.push_back(UnmergeReg);
+ }
+ }
+ }
+ B.buildMergeLikeInstr(Dst, MergeTyParts);
+ }
+ MI.eraseFromParent();
+}
+
+void RegBankLegalizeHelper::widenLoad(MachineInstr &MI, LLT WideTy,
+ LLT MergeTy) {
+ MachineFunction &MF = B.getMF();
+ assert(MI.getNumMemOperands() == 1);
+ MachineMemOperand &BaseMMO = **MI.memoperands_begin();
+ Register Dst = MI.getOperand(0).getReg();
+ const RegisterBank *DstRB = MRI.getRegBankOrNull(Dst);
+ Register BasePtrReg = MI.getOperand(1).getReg();
+
+ Register BasePtrPlusOffsetReg;
+ BasePtrPlusOffsetReg = BasePtrReg;
+
+ MachineMemOperand *BasePtrPlusOffsetMMO =
+ MF.getMachineMemOperand(&BaseMMO, 0, WideTy);
+ Register WideLoad = MRI.createVirtualRegister({DstRB, WideTy});
+ B.buildLoad(WideLoad, BasePtrPlusOffsetReg, *BasePtrPlusOffsetMMO);
+
+ if (WideTy.isScalar()) {
+ B.buildTrunc(Dst, WideLoad);
+ } else {
+ SmallVector<Register, 4> MergeTyParts;
+ unsigned NumEltsMerge =
+ MRI.getType(Dst).getSizeInBits() / MergeTy.getSizeInBits();
+ auto Unmerge = B.buildUnmerge(MergeTy, WideLoad);
+ for (unsigned i = 0; i < Unmerge->getNumOperands() - 1; ++i) {
+ Register UnmergeReg = Unmerge->getOperand(i).getReg();
+ MRI.setRegBank(UnmergeReg, *DstRB);
+ if (i < NumEltsMerge)
+ MergeTyParts.push_back(UnmergeReg);
+ }
+ B.buildMergeLikeInstr(Dst, MergeTyParts);
+ }
+ MI.eraseFromParent();
+}
+
void RegBankLegalizeHelper::lower(MachineInstr &MI,
const RegBankLLTMapping &Mapping,
SmallSet<Register, 4> &WaterfallSGPRs) {
@@ -119,6 +210,53 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI,
MI.eraseFromParent();
break;
}
+ case SplitLoad: {
+ LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+ LLT V8S16 = LLT::fixed_vector(8, S16);
+ LLT V4S32 = LLT::fixed_vector(4, S32);
+ LLT V2S64 = LLT::fixed_vector(2, S64);
+
+ if (DstTy == LLT::fixed_vector(8, S32))
+ splitLoad(MI, {V4S32, V4S32});
+ else if (DstTy == LLT::fixed_vector(16, S32))
+ splitLoad(MI, {V4S32, V4S32, V4S32, V4S32});
+ else if (DstTy == LLT::fixed_vector(4, S64))
+ splitLoad(MI, {V2S64, V2S64});
+ else if (DstTy == LLT::fixed_vector(8, S64))
+ splitLoad(MI, {V2S64, V2S64, V2S64, V2S64});
+ else if (DstTy == LLT::fixed_vector(16, S16))
+ splitLoad(MI, {V8S16, V8S16});
+ else if (DstTy == LLT::fixed_vector(32, S16))
+ splitLoad(MI, {V8S16, V8S16, V8S16, V8S16});
+ else if (DstTy == LLT::scalar(256))
+ splitLoad(MI, {LLT::scalar(128), LLT::scalar(128)});
+ else if (DstTy == LLT::scalar(96))
+ splitLoad(MI, {S64, S32}, S32);
+ else if (DstTy == LLT::fixed_vector(3, S32))
+ splitLoad(MI, {LLT::fixed_vector(2, S32), S32}, S32);
+ else if (DstTy == LLT::fixed_vector(6, S16))
+ splitLoad(MI, {LLT::fixed_vector(4, S16), LLT::fixed_vector(2, S16)},
+ LLT::fixed_vector(2, S16));
+ else {
+ MI.dump();
+ llvm_unreachable("SplitLoad type not supported\n");
+ }
+ break;
+ }
+ case WidenLoad: {
+ LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+ if (DstTy == LLT::scalar(96))
+ widenLoad(MI, LLT::scalar(128));
+ else if (DstTy == LLT::fixed_vector(3, S32))
+ widenLoad(MI, LLT::fixed_vector(4, S32), S32);
+ else if (DstTy == LLT::fixed_vector(6, S16))
+ widenLoad(MI, LLT::fixed_vector(8, S16), LLT::fixed_vector(2, S16));
+ else {
+ MI.dump();
+ llvm_unreachable("WidenLoad type not supported\n");
+ }
+ break;
+ }
}
// TODO: executeInWaterfallLoop(... WaterfallSGPRs)
@@ -142,13 +280,73 @@ LLT RegBankLegalizeHelper::getTyFromID(RegBankLLTMapingApplyID ID) {
case Sgpr64:
case Vgpr64:
return LLT::scalar(64);
-
+ case SgprP1:
+ case VgprP1:
+ return LLT::pointer(1, 64);
+ case SgprP3:
+ case VgprP3:
+ return LLT::pointer(3, 32);
+ case SgprP4:
+ case VgprP4:
+ return LLT::pointer(4, 64);
+ case SgprP5:
+ case VgprP5:
+ return LLT::pointer(5, 32);
case SgprV4S32:
case VgprV4S32:
case UniInVgprV4S32:
return LLT::fixed_vector(4, 32);
- case VgprP1:
- return LLT::pointer(1, 64);
+ default:
+ return LLT();
+ }
+}
+
+LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMapingApplyID ID, LLT Ty) {
+ switch (ID) {
+ case SgprB32:
+ case VgprB32:
+ case UniInVgprB32:
+ if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
+ Ty == LLT::pointer(3, 32) || Ty == LLT::pointer(5, 32) ||
+ Ty == LLT::pointer(6, 32))
+ return Ty;
+ return LLT();
+ case SgprB64:
+ case VgprB64:
+ case UniInVgprB64:
+ if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
+ Ty == LLT::fixed_vector(4, 16) || Ty == LLT::pointer(0, 64) ||
+ Ty == LLT::pointer(1, 64) || Ty == LLT::pointer(4, 64))
+ return Ty;
+ return LLT();
+ case SgprB96:
+ case VgprB96:
+ case UniInVgprB96:
+ if (Ty == LLT::scalar(96) || Ty == LLT::fixed_vector(3, 32) ||
+ Ty == LLT::fixed_vector(6, 16))
+ return Ty;
+ return LLT();
+ case SgprB128:
+ case VgprB128:
+ case UniInVgprB128:
+ if (Ty == LLT::scalar(128) || Ty == LLT::fixed_vector(4, 32) ||
+ Ty == LLT::fixed_vector(2, 64))
+ return Ty;
+ return LLT();
+ case SgprB256:
+ case VgprB256:
+ case UniInVgprB256:
+ if (Ty == LLT::scalar(256) || Ty == LLT::fixed_vector(8, 32) ||
+ Ty == LLT::fixed_vector(4, 64) || Ty == LLT::fixed_vector(16, 16))
+ return Ty;
+ return LLT();
+ case SgprB512:
+ case VgprB512:
+ case UniInVgprB512:
+ if (Ty == LLT::scalar(512) || Ty == LLT::fixed_vector(16, 32) ||
+ Ty == LLT::fixed_vector(8, 64))
+ return Ty;
+ return LLT();
default:
return LLT();
}
@@ -163,10 +361,26 @@ RegBankLegalizeHelper::getRBFromID(RegBankLLTMapingApplyID ID) {
case Sgpr16:
case Sgpr32:
case Sgpr64:
+ case SgprP1:
+ case SgprP3:
+ case SgprP4:
+ case SgprP5:
case SgprV4S32:
+ case SgprB32:
+ case SgprB64:
+ case SgprB96:
+ case SgprB128:
+ case SgprB256:
+ case SgprB512:
case UniInVcc:
case UniInVgprS32:
case UniInVgprV4S32:
+ case UniInVgprB32:
+ case UniInVgprB64:
+ case UniInVgprB96:
+ case UniInVgprB128:
+ case UniInVgprB256:
+ case UniInVgprB512:
case Sgpr32Trunc:
case Sgpr32AExt:
case Sgpr32AExtBoolInReg:
@@ -176,7 +390,16 @@ RegBankLegalizeHelper::getRBFromID(RegBankLLTMapingApplyID ID) {
case Vgpr32:
case Vgpr64:
case VgprP1:
+ case VgprP3:
+ case VgprP4:
+ case VgprP5:
case VgprV4S32:
+ case VgprB32:
+ case VgprB64:
+ case VgprB96:
+ case VgprB128:
+ case VgprB256:
+ case VgprB512:
return VgprRB;
default:
@@ -202,17 +425,42 @@ void RegBankLegalizeHelper::applyMappingDst(
case Sgpr16:
case Sgpr32:
case Sgpr64:
+ case SgprP1:
+ case SgprP3:
+ case SgprP4:
+ case SgprP5:
case SgprV4S32:
case Vgpr32:
case Vgpr64:
case VgprP1:
+ case VgprP3:
+ case VgprP4:
+ case VgprP5:
case VgprV4S32: {
assert(Ty == getTyFromID(MethodIDs[OpIdx]));
assert(RB == getRBFromID(MethodIDs[OpIdx]));
break;
}
- // uniform in vcc/vgpr: scalars and vectors
+ // sgpr and vgpr B-types
+ case SgprB32:
+ case SgprB64:
+ case SgprB96:
+ case SgprB128:
+ case SgprB256:
+ case SgprB512:
+ case VgprB32:
+ case VgprB64:
+ case VgprB96:
+ case VgprB128:
+ case VgprB256:
+ case VgprB512: {
+ assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
+ assert(RB == getRBFromID(MethodIDs[OpIdx]));
+ break;
+ }
+
+ // uniform in vcc/vgpr: scalars, vectors and B-types
case UniInVcc: {
assert(Ty == S1);
assert(RB == SgprRB);
@@ -229,6 +477,17 @@ void RegBankLegalizeHelper::applyMappingDst(
AMDGPU::buildReadAnyLaneDst(B, MI, RBI);
break;
}
+ case UniInVgprB32:
+ case UniInVgprB64:
+ case UniInVgprB96:
+ case UniInVgprB128:
+ case UniInVgprB256:
+ case UniInVgprB512: {
+ assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
+ assert(RB == SgprRB);
+ AMDGPU::buildReadAnyLaneDst(B, MI, RBI);
+ break;
+ }
// sgpr trunc
case Sgpr32Trunc: {
@@ -279,16 +538,34 @@ void RegBankLegalizeHelper::applyMappingSrc(
case Sgpr16:
case Sgpr32:
case Sgpr64:
+ case SgprP1:
+ case SgprP3:
+ case SgprP4:
+ case SgprP5:
case SgprV4S32: {
assert(Ty == getTyFromID(MethodIDs[i]));
assert(RB == getRBFromID(MethodIDs[i]));
break;
}
+ // sgpr B-types
+ case SgprB32:
+ case SgprB64:
+ case SgprB96:
+ case SgprB128:
+ case SgprB256:
+ case SgprB512: {
+ assert(Ty == getBTyFromID(MethodIDs[i], Ty));
+ assert(RB == getRBFromID(MethodIDs[i]));
+ break;
+ }
// vgpr scalars, pointers and vectors
case Vgpr32:
case Vgpr64:
case VgprP1:
+ case VgprP3:
+ case VgprP4:
+ case VgprP5:
case VgprV4S32: {
assert(Ty == getTyFromID(MethodIDs[i]));
if (RB != VgprRB) {
@@ -298,6 +575,21 @@ void RegBankLegalizeHelper::applyMappingSrc(
}
break;
}
+ // vgpr B-types
+ case VgprB32:
+ case VgprB64:
+ case VgprB96:
+ case VgprB128:
+ case VgprB256:
+ case VgprB512: {
+ assert(Ty == getBTyFromID(MethodIDs[i], Ty));
+ if (RB != VgprRB) {
+ auto CopyToVgpr =
+ B.buildCopy(createVgpr(getBTyFromID(MethodIDs[i], Ty)), Reg);
+ Op.setReg(CopyToVgpr.getReg(0));
+ }
+ break;
+ }
// sgpr and vgpr scalars with extend
case Sgpr32AExt: {
@@ -372,7 +664,7 @@ void RegBankLegalizeHelper::applyMappingPHI(MachineInstr &MI) {
// We accept all types that can fit in some register class.
// Uniform G_PHIs have all sgpr registers.
// Divergent G_PHIs have vgpr dst but inputs can be sgpr or vgpr.
- if (Ty == LLT::scalar(32)) {
+ if (Ty == LLT::scalar(32) || Ty == LLT::pointer(4, 64)) {
return;
}
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.h b/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.h
index e23dfcebe3fe3f..c409df54519c5c 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.h
@@ -92,6 +92,7 @@ class RegBankLegalizeHelper {
SmallSet<Register, 4> &SGPROperandRegs);
LLT getTyFromID(RegBankLLTMapingApplyID ID);
+ LLT getBTyFromID(RegBankLLTMapingApplyID ID, LLT Ty);
const RegisterBank *getRBFromID(RegBankLLTMapingApplyID ID);
@@ -104,9 +105,9 @@ class RegBankLegalizeHelper {
const SmallVectorImpl<RegBankLLTMapingApplyID> &MethodIDs,
SmallSet<Register, 4> &SGPRWaterfallOperandRegs);
- unsigned setBufferOffsets(MachineIRBuilder &B, Register CombinedOffset,
- Register &VOffsetReg, Register &SOffsetReg,
- int64_t &InstOffsetVal, Align Alignment);
+ void splitLoad(MachineInstr &MI, ArrayRef<LLT> LLTBreakdown,
+ LLT MergeTy = LLT());
+ void widenLoad(MachineInstr &MI, LLT WideTy, LLT MergeTy = LLT());
void lower(MachineInstr &MI, const RegBankLLTMapping &Mapping,
SmallSet<Register, 4> &SGPRWaterfallOperandRegs);
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeRules.cpp b/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeRules.cpp
index 1266f99c79c395..895a596cf84f40 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeRules.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeRules.cpp
@@ -14,9 +14,11 @@
//===----------------------------------------------------------------------===//
#include "AMDGPURBLegalizeRules.h"
+#include "AMDGPUInstrInfo.h"
#include "GCNSubtarget.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
+#include "llvm/Support/AMDGPUAddrSpace.h"
using namespace llvm;
using namespace AMDGPU;
@@ -47,6 +49,24 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
return MRI.getType(Reg) == LLT::scalar(64);
case P1:
return MRI.getType(Reg) == LLT::pointer(1, 64);
+ case P3:
+ return MRI.getType(Reg) == LLT::pointer(3, 32);
+ case P4:
+ return MRI.getType(Reg) == LLT::pointer(4, 64);
+ case P5:
+ return MRI.getType(Reg) == LLT::pointer(5, 32);
+ case B32:
+ return MRI.getType(Reg).getSizeInBits() == 32;
+ case B64:
+ return MRI.getType(Reg).getSizeInBits() == 64;
+ case B96:
+ return MRI.getType(Reg).getSizeInBits() == 96;
+ case B128:
+ return MRI.getType(Reg).getSizeInBits() == 128;
+ case B256:
+ return MRI.getType(Reg).getSizeInBits() == 256;
+ case B512:
+ return MRI.getType(Reg).getSizeInBits() == 512;
case UniS1:
return MRI.getType(Reg) == LLT::scalar(1) && MUI.isUniform(Reg);
@@ -56,6 +76,26 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
return MRI.getType(Reg) == LLT::scalar(32) && MUI.isUniform(Reg);
case UniS64:
return MRI.getType(Reg) == LLT::scalar(64) && MUI.isUniform(Reg);
+ case UniP1:
+ return MRI.getType(Reg) == LLT::pointer(1, 64) && MUI.isUniform(Reg);
+ case UniP3:
+ return MRI.getType(Reg) == LLT::pointer(3, 32) && MUI.isUniform(Reg);
+ case UniP4:
+ return MRI.getType(Reg) == LLT::pointer(4, 64) && MUI.isUniform(Reg);
+ case UniP5:
+ return MRI.getType(Reg) == LLT::pointer(5, 32) && MUI.isUniform(Reg);
+ case UniB32:
+ return MRI.getType(Reg).getSizeInBits() == 32 && MUI.isUniform(Reg);
+ case UniB64:
+ return MRI.getType(Reg).getSizeInBits() == 64 && MUI.isUniform(Reg);
+ case UniB96:
+ return MRI.getType(Reg).getSizeInBits() == 96 && MUI.isUniform(Reg);
+ case UniB128:
+ return MRI.getType(Reg).getSizeInBits() == 128 && MUI.isUniform(Reg);
+ case UniB256:
+ return MRI.getType(Reg).getSizeInBits() == 256 && MUI.isUniform(Reg);
+ case UniB512:
+ return MRI.getType(Reg).getSizeInBits() == 512 && MUI.isUniform(Reg);
case DivS1:
return MRI.getType(Reg) == LLT::scalar(1) && MUI.isDivergent(Reg);
@@ -65,6 +105,24 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
return MRI.getType(Reg) == LLT::scalar(64) && MUI.isDivergent(Reg);
case DivP1:
return MRI.getType(Reg) == LLT::pointer(1, 64) && MUI.isDivergent(Reg);
+ case DivP3:
+ return MRI.getType(Reg) == LLT::pointer(3, 32) && MUI.isDivergent(Reg);
+ case DivP4:
+ return MRI.getType(Reg) == LLT::pointer(4, 64) && MUI.isDivergent(Reg);
+ case DivP5:
+ return MRI.getType(Reg) == LLT::pointer(5, 32) && MUI.isDivergent(Reg);
+ case DivB32:
+ return MRI.getType(Reg).getSizeInBits() == 32 && MUI.isDivergent(Reg);
+ case DivB64:
+ return MRI.getType(Reg).getSizeInBits() == 64 && MUI.isDivergent(Reg);
+ case DivB96:
+ return MRI.getType(Reg).getSizeInBits() == 96 && MUI.isDivergent(Reg);
+ case DivB128:
+ return MRI.getType(Reg).getSizeInBits() == 128 && MUI.isDivergent(Reg);
+ case DivB256:
+ return MRI.getType(Reg).getSizeInBits() == 256 && MUI.isDivergent(Reg);
+ case DivB512:
+ return MRI.getType(Reg).getSizeInBits() == 512 && MUI.isDivergent(Reg);
case _:
return true;
@@ -124,6 +182,22 @@ UniformityLLTOpPredicateID LLTToId(LLT Ty) {
return _;
}
+UniformityLLTOpPredicateID LLTToBId(LLT Ty) {
+ if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
+ Ty == LLT::pointer(3, 32) || Ty == LLT::pointer(5, 32) ||
+ Ty == LLT::pointer(6, 32))
+ return B32;
+ if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
+ Ty == LLT::fixed_vector(4, 16) || Ty == LLT::pointer(1, 64) ||
+ Ty == LLT::pointer(4, 64))
+ return B64;
+ if (Ty == LLT::fixed_vector(3, 32))
+ return B96;
+ if (Ty == LLT::fixed_vector(4, 32))
+ return B128;
+ return _;
+}
+
const RegBankLLTMapping &
SetOfRulesForOpcode::findMappingForMI(const MachineInstr &MI,
const MachineRegisterInfo &MRI,
@@ -134,7 +208,12 @@ SetOfRulesForOpcode::findMappingForMI(const MachineInstr &MI,
// returned which results in failure, does not search "Slow Rules".
if (FastTypes != No) {
Register Reg = MI.getOperand(0).getReg();
- int Slot = getFastPredicateSlot(LLTToId(MRI.getType(Reg)));
+ int Slot;
+ if (FastTypes == StandardB)
+ Slot = getFastPredicateSlot(LLTToBId(MRI.getType(Reg)));
+ else
+ Slot = getFastPredicateSlot(LLTToId(MRI.getType(Reg)));
+
if (Slot != -1) {
if (MUI.isUniform(Reg))
return Uni[Slot];
@@ -184,6 +263,19 @@ int SetOfRulesForOpcode::getFastPredicateSlot(
default:
return -1;
}
+ case StandardB:
+ switch (Ty) {
+ case B32:
+ return 0;
+ case B64:
+ return 1;
+ case B96:
+ return 2;
+ case B128:
+ return 3;
+ default:
+ return -1;
+ }
case Vector:
switch (Ty) {
case S32:
@@ -236,6 +328,127 @@ RegBankLegalizeRules::getRulesForOpc(MachineInstr &MI) const {
return GRules.at(GRulesAlias.at(Opc));
}
+// Syntactic sugar wrapper for predicate lambda that enables '&&', '||' and '!'.
+class Predicate {
+public:
+ struct Elt {
+ // Save formula composed of Pred, '&&', '||' and '!' as a jump table.
+ // Sink ! to Pred. For example !((A && !B) || C) -> (!A || B) && !C
+ // Sequences of && and || will be represented by jumps, for example:
+ // (A && B && ... X) or (A && B && ... X) || Y
+ // A == true jump to B
+ // A == false jump to end or Y, result is A(false) or Y
+ // (A || B || ... X) or (A || B || ... X) && Y
+ // A == true jump to end or Y, result is B(true) or Y
+ // A == false jump B
+ // Notice that when negating expression, we apply simply flip Neg on each
+ // Pred and swap TJumpOffset and FJumpOffset (&& becomes ||, || becomes &&)....
[truncated]
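The truncated hunk above is the Predicate wrapper: it stores a boolean formula over predicate lambdas as a flat jump table so that '&&', '||' and '!' compose without nesting closures. As a conceptual sketch only — this is not the patch's jump-table encoding, and all names here are illustrative — the same user-facing composition can be written with std::function:

```cpp
#include <functional>

namespace llvm { class MachineInstr; } // assumed forward declaration
using llvm::MachineInstr;

// Minimal combinator wrapper: same surface behavior as the patch's
// Predicate, but built on nested closures instead of a jump table.
class PredicateSketch {
  std::function<bool(const MachineInstr &)> Fn;

public:
  PredicateSketch(std::function<bool(const MachineInstr &)> F)
      : Fn(std::move(F)) {}

  bool operator()(const MachineInstr &MI) const { return Fn(MI); }

  PredicateSketch operator&&(const PredicateSketch &RHS) const {
    PredicateSketch L = *this, R = RHS;
    return PredicateSketch(
        [L, R](const MachineInstr &MI) { return L(MI) && R(MI); });
  }

  PredicateSketch operator||(const PredicateSketch &RHS) const {
    PredicateSketch L = *this, R = RHS;
    return PredicateSketch(
        [L, R](const MachineInstr &MI) { return L(MI) || R(MI); });
  }

  PredicateSketch operator!() const {
    PredicateSketch Self = *this;
    return PredicateSketch(
        [Self](const MachineInstr &MI) { return !Self(MI); });
  }
};
```

The patch avoids this closure nesting: it appends Elt entries to one SmallVector and encodes short-circuiting as true/false jump offsets, which is what the truncated comment describes.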
LLT V4S32 = LLT::fixed_vector(4, S32);
LLT V2S64 = LLT::fixed_vector(2, S64);

if (DstTy == LLT::fixed_vector(8, S32))
Can you rework this to be a function that returns the type to use for the load?
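One possible shape for such a helper, sketched under assumptions: the name getSplitLoadTys and the std::optional return type are mine, the include paths may vary by LLVM version, and the cases that need a distinct MergeTy are elided.

```cpp
#include <optional>
#include "llvm/ADT/SmallVector.h"
#include "llvm/CodeGenTypes/LowLevelType.h"
using namespace llvm;

// Hypothetical helper: map the load's destination type to the list of part
// types, so the lowering becomes one lookup plus one splitLoad call.
static std::optional<SmallVector<LLT, 4>> getSplitLoadTys(LLT DstTy) {
  LLT V4S32 = LLT::fixed_vector(4, 32);
  LLT V2S64 = LLT::fixed_vector(2, 64);
  if (DstTy == LLT::fixed_vector(8, 32))
    return SmallVector<LLT, 4>{V4S32, V4S32};
  if (DstTy == LLT::fixed_vector(16, 32))
    return SmallVector<LLT, 4>{V4S32, V4S32, V4S32, V4S32};
  if (DstTy == LLT::fixed_vector(4, 64))
    return SmallVector<LLT, 4>{V2S64, V2S64};
  // ... remaining cases would mirror the if-chain in the hunk above.
  return std::nullopt; // caller keeps its llvm_unreachable for this case
}
```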
ping
I have a bunch of comments, but apart from that the change LGTM
// A == true jump to B
// A == false jump to end or Y, result is A(false) or Y
// (A || B || ... X) or (A || B || ... X) && Y
// A == true jump to end or Y, result is B(true) or Y
Suggested change:
- // A == true jump to end or Y, result is B(true) or Y
+ // A == true jump to end or Y, result is A(true) or Y
unsigned FJumpOffset;
};

SmallVector<Elt, 8> Expression;
Make this and Elt private. Elt is quite subtle and should be hidden as an implementation detail.
Expression.push_back({Pred, false, 1, 1});
};

Predicate(SmallVectorImpl<Elt> &Expr) { Expression.swap(Expr); };
I find having a constructor that is destructive on its argument in this way to be quite surprising. It is against good patterns in modern C++. Better have it take an rvalue reference (&&-reference). Then its callers have to explicitly std::move, but that makes it more obvious what happens at the call sites. Also, this constructor should be private just like Elt.
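Combining this with the earlier point about Elt, the suggested shape might look like the following sketch. The std::function-over-MachineInstr predicate type is an assumption; only the push_back call and the member names come from the quoted hunks.

```cpp
class Predicate {
  // Hidden implementation detail, per the review.
  struct Elt {
    std::function<bool(const MachineInstr &)> Pred; // assumed signature
    bool Neg;
    unsigned TJumpOffset;
    unsigned FJumpOffset;
  };
  SmallVector<Elt, 8> Expression;

  // Private rvalue-reference constructor: internal callers must write
  // std::move, making the ownership transfer visible at the call site.
  Predicate(SmallVector<Elt, 8> &&Expr) : Expression(std::move(Expr)) {}

public:
  Predicate(std::function<bool(const MachineInstr &)> Pred) {
    Expression.push_back({std::move(Pred), false, 1, 1});
  }
};
```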
return Result;
};

Predicate operator!() {
Should be a const method.
return Predicate(NegExpression);
};

Predicate operator&&(Predicate &RHS) {
Similar to the constructor above, the baseline of this should be a const method that accepts a const reference as an argument. Do you have a good argument for the rvalue-argument overload below? If not, please just remove it. The same applies to operator||.
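Under those constraints, a const-correct operator&& could look like this sketch; the jump-offset fix-up is elided here because the full encoding is truncated above, so the body is a guess at the intent rather than the patch's code.

```cpp
Predicate Predicate::operator&&(const Predicate &RHS) const {
  // Copy both expressions; neither operand is modified.
  SmallVector<Elt, 8> Combined(Expression.begin(), Expression.end());
  Combined.append(RHS.Expression.begin(), RHS.Expression.end());
  // A real implementation would patch TJumpOffset/FJumpOffset here so a
  // false result in the LHS short-circuits past the RHS.
  return Predicate(std::move(Combined));
}
```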
@@ -290,7 +504,86 @@ RegBankLegalizeRules::RegBankLegalizeRules(const GCNSubtarget &_ST,
      .Any({{UniS64, S32}, {{Sgpr64}, {Sgpr32}, Ext32To64}})
      .Any({{DivS64, S32}, {{Vgpr64}, {Vgpr32}, Ext32To64}});

  addRulesForGOpcs({G_LOAD}).Any({{DivS32, DivP1}, {{Vgpr32}, {VgprP1}}});
  bool hasUnAlignedLoads = ST->getGeneration() >= AMDGPUSubtarget::GFX12;
"unaligned" is a single word:
bool hasUnAlignedLoads = ST->getGeneration() >= AMDGPUSubtarget::GFX12; | |
bool hasUnalignedLoads = ST->getGeneration() >= AMDGPUSubtarget::GFX12; |
if (Size / 128 == 2)
  splitLoad(MI, {B128, B128});
if (Size / 128 == 4)
  splitLoad(MI, {B128, B128, B128, B128});
I would expect this to be an else-if chain with an else llvm_unreachable at the end.
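That is, something along these lines, using the names from the quoted hunk:

```cpp
if (Size / 128 == 2)
  splitLoad(MI, {B128, B128});
else if (Size / 128 == 4)
  splitLoad(MI, {B128, B128, B128, B128});
else
  llvm_unreachable("unexpected B-type load size");
```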
auto isUL = !isAtomicMMO && isUniMMO && (isConst || !isVolatileMMO) &&
            (isConst || isInvMMO || isNoClobberMMO);
What's the logic behind the isConst || !isVolatileMMO part of this predicate? const && volatile doesn't make much sense to me, so why isn't this just !isVolatileMMO?
Reply: It is copied from the current implementation in AMDGPURegisterBankInfo::isScalarLoadLegal. isConst checks the address space in the MMO, which can be different from the address space of the pointer operand.
LGTM
Add IDs for bit widths that cover multiple LLTs: B32, B64, etc.
"Predicate" wrapper class for bool predicate functions used to
write pretty rules. Predicates can be combined using &&, || and !.
Lowering for splitting and widening loads.
Write rules for loads so as not to change existing MIR tests from
the old RegBankSelect.
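For illustration, a load rule combining the new B-type IDs with the Predicate combinators might read like the sketch below. The addRulesForGOpcs call mirrors the rule style quoted in the diff, but the predicate definitions, the MachineInstr-based predicate signature, and the B-type operand names in the rule are my assumptions, not lifted from the patch.

```cpp
// Hypothetical predicates over the load's memory operand.
Predicate isAlign4([](const MachineInstr &MI) {
  return (*MI.memoperands_begin())->getAlign().value() >= 4;
});
Predicate isAtomicPred([](const MachineInstr &MI) {
  return (*MI.memoperands_begin())->isAtomic();
});

// Combinators from the patch: build a compound condition once and reuse it
// when writing rules for scalar-unit-eligible loads.
Predicate isScalarLoadOK = isAlign4 && !isAtomicPred;

// Rule in the style quoted in the diff, now with B-type IDs: a divergent
// 32-bit load through a global (P1) pointer selects VGPRs.
addRulesForGOpcs({G_LOAD})
    .Any({{DivB32, DivP1}, {{VgprB32}, {VgprP1}}});
```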