AMDGPU/GlobalISel: RegBankLegalize rules for load #112882


Merged
merged 1 commit into main from users/petar-avramovic/new-rbs-rb-load-rules on Jan 24, 2025

Conversation

@petar-avramovic (Collaborator) commented Oct 18, 2024

Add IDs for bit widths that cover multiple LLTs: B32, B64, etc.
Add a "Predicate" wrapper class for bool predicate functions, used to write readable rules; predicates can be combined using &&, || and !.
Add lowering for splitting and widening loads.
Write rules for loads so that existing MIR tests from the old regbankselect do not change.
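
As a rough, self-contained illustration of the predicate idea (this is not the patch's actual class, which stores the expression as a jump table and queries the MachineInstr and its memory operand), composing predicates with &&, || and ! can be pictured like this. The MemAccess struct and the simplified isUL rule at the end are stand-ins invented for the sketch; the real rule in the patch also involves isConst and isNoClobberMMO.

#include <functional>
#include <utility>

// Simplified stand-in for the memory-operand queries the real rules use.
struct MemAccess {
  bool Atomic = false, Volatile = false, Invariant = false, Uniform = false;
};

// Wraps a bool predicate so rules can be written as readable expressions.
class Predicate {
  std::function<bool(const MemAccess &)> Fn;

public:
  Predicate(std::function<bool(const MemAccess &)> F) : Fn(std::move(F)) {}
  bool operator()(const MemAccess &M) const { return Fn(M); }

  Predicate operator&&(const Predicate &RHS) const {
    return Predicate(
        [L = Fn, R = RHS.Fn](const MemAccess &M) { return L(M) && R(M); });
  }
  Predicate operator||(const Predicate &RHS) const {
    return Predicate(
        [L = Fn, R = RHS.Fn](const MemAccess &M) { return L(M) || R(M); });
  }
  Predicate operator!() const {
    return Predicate([F = Fn](const MemAccess &M) { return !F(M); });
  }
};

// A rule in the spirit of the patch: uniform, non-atomic, non-volatile,
// invariant loads are candidates for scalar (SGPR) selection.
const Predicate isAtomicMMO([](const MemAccess &M) { return M.Atomic; });
const Predicate isVolatileMMO([](const MemAccess &M) { return M.Volatile; });
const Predicate isInvMMO([](const MemAccess &M) { return M.Invariant; });
const Predicate isUniMMO([](const MemAccess &M) { return M.Uniform; });
const Predicate isUL = !isAtomicMMO && isUniMMO && !isVolatileMMO && isInvMMO;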

@llvmbot (Member) commented Oct 18, 2024

@llvm/pr-subscribers-llvm-globalisel

@llvm/pr-subscribers-backend-amdgpu

Author: Petar Avramovic (petar-avramovic)


Patch is 81.23 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/112882.diff

6 Files Affected:

  • (modified) llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.cpp (+297-5)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.h (+4-3)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPURBLegalizeRules.cpp (+300-7)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPURBLegalizeRules.h (+63-2)
  • (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect-load.mir (+271-49)
  • (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect-zextload.mir (+7-2)
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.cpp b/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.cpp
index a0f6ecedab7a83..f58f0a315096d2 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.cpp
@@ -37,6 +37,97 @@ bool RegBankLegalizeHelper::findRuleAndApplyMapping(MachineInstr &MI) {
   return true;
 }
 
+void RegBankLegalizeHelper::splitLoad(MachineInstr &MI,
+                                      ArrayRef<LLT> LLTBreakdown, LLT MergeTy) {
+  MachineFunction &MF = B.getMF();
+  assert(MI.getNumMemOperands() == 1);
+  MachineMemOperand &BaseMMO = **MI.memoperands_begin();
+  Register Dst = MI.getOperand(0).getReg();
+  const RegisterBank *DstRB = MRI.getRegBankOrNull(Dst);
+  Register BasePtrReg = MI.getOperand(1).getReg();
+  LLT PtrTy = MRI.getType(BasePtrReg);
+  const RegisterBank *PtrRB = MRI.getRegBankOrNull(BasePtrReg);
+  LLT OffsetTy = LLT::scalar(PtrTy.getSizeInBits());
+  SmallVector<Register, 4> LoadPartRegs;
+
+  unsigned ByteOffset = 0;
+  for (LLT PartTy : LLTBreakdown) {
+    Register BasePtrPlusOffsetReg;
+    if (ByteOffset == 0) {
+      BasePtrPlusOffsetReg = BasePtrReg;
+    } else {
+      BasePtrPlusOffsetReg = MRI.createVirtualRegister({PtrRB, PtrTy});
+      Register OffsetReg = MRI.createVirtualRegister({PtrRB, OffsetTy});
+      B.buildConstant(OffsetReg, ByteOffset);
+      B.buildPtrAdd(BasePtrPlusOffsetReg, BasePtrReg, OffsetReg);
+    }
+    MachineMemOperand *BasePtrPlusOffsetMMO =
+        MF.getMachineMemOperand(&BaseMMO, ByteOffset, PartTy);
+    Register PartLoad = MRI.createVirtualRegister({DstRB, PartTy});
+    B.buildLoad(PartLoad, BasePtrPlusOffsetReg, *BasePtrPlusOffsetMMO);
+    LoadPartRegs.push_back(PartLoad);
+    ByteOffset += PartTy.getSizeInBytes();
+  }
+
+  if (!MergeTy.isValid()) {
+    // Loads are of same size, concat or merge them together.
+    B.buildMergeLikeInstr(Dst, LoadPartRegs);
+  } else {
+    // Load(s) are not all of same size, need to unmerge them to smaller pieces
+    // of MergeTy type, then merge them all together in Dst.
+    SmallVector<Register, 4> MergeTyParts;
+    for (Register Reg : LoadPartRegs) {
+      if (MRI.getType(Reg) == MergeTy) {
+        MergeTyParts.push_back(Reg);
+      } else {
+        auto Unmerge = B.buildUnmerge(MergeTy, Reg);
+        for (unsigned i = 0; i < Unmerge->getNumOperands() - 1; ++i) {
+          Register UnmergeReg = Unmerge->getOperand(i).getReg();
+          MRI.setRegBank(UnmergeReg, *DstRB);
+          MergeTyParts.push_back(UnmergeReg);
+        }
+      }
+    }
+    B.buildMergeLikeInstr(Dst, MergeTyParts);
+  }
+  MI.eraseFromParent();
+}
+
+void RegBankLegalizeHelper::widenLoad(MachineInstr &MI, LLT WideTy,
+                                      LLT MergeTy) {
+  MachineFunction &MF = B.getMF();
+  assert(MI.getNumMemOperands() == 1);
+  MachineMemOperand &BaseMMO = **MI.memoperands_begin();
+  Register Dst = MI.getOperand(0).getReg();
+  const RegisterBank *DstRB = MRI.getRegBankOrNull(Dst);
+  Register BasePtrReg = MI.getOperand(1).getReg();
+
+  Register BasePtrPlusOffsetReg;
+  BasePtrPlusOffsetReg = BasePtrReg;
+
+  MachineMemOperand *BasePtrPlusOffsetMMO =
+      MF.getMachineMemOperand(&BaseMMO, 0, WideTy);
+  Register WideLoad = MRI.createVirtualRegister({DstRB, WideTy});
+  B.buildLoad(WideLoad, BasePtrPlusOffsetReg, *BasePtrPlusOffsetMMO);
+
+  if (WideTy.isScalar()) {
+    B.buildTrunc(Dst, WideLoad);
+  } else {
+    SmallVector<Register, 4> MergeTyParts;
+    unsigned NumEltsMerge =
+        MRI.getType(Dst).getSizeInBits() / MergeTy.getSizeInBits();
+    auto Unmerge = B.buildUnmerge(MergeTy, WideLoad);
+    for (unsigned i = 0; i < Unmerge->getNumOperands() - 1; ++i) {
+      Register UnmergeReg = Unmerge->getOperand(i).getReg();
+      MRI.setRegBank(UnmergeReg, *DstRB);
+      if (i < NumEltsMerge)
+        MergeTyParts.push_back(UnmergeReg);
+    }
+    B.buildMergeLikeInstr(Dst, MergeTyParts);
+  }
+  MI.eraseFromParent();
+}
+
 void RegBankLegalizeHelper::lower(MachineInstr &MI,
                                   const RegBankLLTMapping &Mapping,
                                   SmallSet<Register, 4> &WaterfallSGPRs) {
@@ -119,6 +210,53 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI,
     MI.eraseFromParent();
     break;
   }
+  case SplitLoad: {
+    LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+    LLT V8S16 = LLT::fixed_vector(8, S16);
+    LLT V4S32 = LLT::fixed_vector(4, S32);
+    LLT V2S64 = LLT::fixed_vector(2, S64);
+
+    if (DstTy == LLT::fixed_vector(8, S32))
+      splitLoad(MI, {V4S32, V4S32});
+    else if (DstTy == LLT::fixed_vector(16, S32))
+      splitLoad(MI, {V4S32, V4S32, V4S32, V4S32});
+    else if (DstTy == LLT::fixed_vector(4, S64))
+      splitLoad(MI, {V2S64, V2S64});
+    else if (DstTy == LLT::fixed_vector(8, S64))
+      splitLoad(MI, {V2S64, V2S64, V2S64, V2S64});
+    else if (DstTy == LLT::fixed_vector(16, S16))
+      splitLoad(MI, {V8S16, V8S16});
+    else if (DstTy == LLT::fixed_vector(32, S16))
+      splitLoad(MI, {V8S16, V8S16, V8S16, V8S16});
+    else if (DstTy == LLT::scalar(256))
+      splitLoad(MI, {LLT::scalar(128), LLT::scalar(128)});
+    else if (DstTy == LLT::scalar(96))
+      splitLoad(MI, {S64, S32}, S32);
+    else if (DstTy == LLT::fixed_vector(3, S32))
+      splitLoad(MI, {LLT::fixed_vector(2, S32), S32}, S32);
+    else if (DstTy == LLT::fixed_vector(6, S16))
+      splitLoad(MI, {LLT::fixed_vector(4, S16), LLT::fixed_vector(2, S16)},
+                LLT::fixed_vector(2, S16));
+    else {
+      MI.dump();
+      llvm_unreachable("SplitLoad type not supported\n");
+    }
+    break;
+  }
+  case WidenLoad: {
+    LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+    if (DstTy == LLT::scalar(96))
+      widenLoad(MI, LLT::scalar(128));
+    else if (DstTy == LLT::fixed_vector(3, S32))
+      widenLoad(MI, LLT::fixed_vector(4, S32), S32);
+    else if (DstTy == LLT::fixed_vector(6, S16))
+      widenLoad(MI, LLT::fixed_vector(8, S16), LLT::fixed_vector(2, S16));
+    else {
+      MI.dump();
+      llvm_unreachable("WidenLoad type not supported\n");
+    }
+    break;
+  }
   }
 
   // TODO: executeInWaterfallLoop(... WaterfallSGPRs)
@@ -142,13 +280,73 @@ LLT RegBankLegalizeHelper::getTyFromID(RegBankLLTMapingApplyID ID) {
   case Sgpr64:
   case Vgpr64:
     return LLT::scalar(64);
-
+  case SgprP1:
+  case VgprP1:
+    return LLT::pointer(1, 64);
+  case SgprP3:
+  case VgprP3:
+    return LLT::pointer(3, 32);
+  case SgprP4:
+  case VgprP4:
+    return LLT::pointer(4, 64);
+  case SgprP5:
+  case VgprP5:
+    return LLT::pointer(5, 32);
   case SgprV4S32:
   case VgprV4S32:
   case UniInVgprV4S32:
     return LLT::fixed_vector(4, 32);
-  case VgprP1:
-    return LLT::pointer(1, 64);
+  default:
+    return LLT();
+  }
+}
+
+LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMapingApplyID ID, LLT Ty) {
+  switch (ID) {
+  case SgprB32:
+  case VgprB32:
+  case UniInVgprB32:
+    if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
+        Ty == LLT::pointer(3, 32) || Ty == LLT::pointer(5, 32) ||
+        Ty == LLT::pointer(6, 32))
+      return Ty;
+    return LLT();
+  case SgprB64:
+  case VgprB64:
+  case UniInVgprB64:
+    if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
+        Ty == LLT::fixed_vector(4, 16) || Ty == LLT::pointer(0, 64) ||
+        Ty == LLT::pointer(1, 64) || Ty == LLT::pointer(4, 64))
+      return Ty;
+    return LLT();
+  case SgprB96:
+  case VgprB96:
+  case UniInVgprB96:
+    if (Ty == LLT::scalar(96) || Ty == LLT::fixed_vector(3, 32) ||
+        Ty == LLT::fixed_vector(6, 16))
+      return Ty;
+    return LLT();
+  case SgprB128:
+  case VgprB128:
+  case UniInVgprB128:
+    if (Ty == LLT::scalar(128) || Ty == LLT::fixed_vector(4, 32) ||
+        Ty == LLT::fixed_vector(2, 64))
+      return Ty;
+    return LLT();
+  case SgprB256:
+  case VgprB256:
+  case UniInVgprB256:
+    if (Ty == LLT::scalar(256) || Ty == LLT::fixed_vector(8, 32) ||
+        Ty == LLT::fixed_vector(4, 64) || Ty == LLT::fixed_vector(16, 16))
+      return Ty;
+    return LLT();
+  case SgprB512:
+  case VgprB512:
+  case UniInVgprB512:
+    if (Ty == LLT::scalar(512) || Ty == LLT::fixed_vector(16, 32) ||
+        Ty == LLT::fixed_vector(8, 64))
+      return Ty;
+    return LLT();
   default:
     return LLT();
   }
@@ -163,10 +361,26 @@ RegBankLegalizeHelper::getRBFromID(RegBankLLTMapingApplyID ID) {
   case Sgpr16:
   case Sgpr32:
   case Sgpr64:
+  case SgprP1:
+  case SgprP3:
+  case SgprP4:
+  case SgprP5:
   case SgprV4S32:
+  case SgprB32:
+  case SgprB64:
+  case SgprB96:
+  case SgprB128:
+  case SgprB256:
+  case SgprB512:
   case UniInVcc:
   case UniInVgprS32:
   case UniInVgprV4S32:
+  case UniInVgprB32:
+  case UniInVgprB64:
+  case UniInVgprB96:
+  case UniInVgprB128:
+  case UniInVgprB256:
+  case UniInVgprB512:
   case Sgpr32Trunc:
   case Sgpr32AExt:
   case Sgpr32AExtBoolInReg:
@@ -176,7 +390,16 @@ RegBankLegalizeHelper::getRBFromID(RegBankLLTMapingApplyID ID) {
   case Vgpr32:
   case Vgpr64:
   case VgprP1:
+  case VgprP3:
+  case VgprP4:
+  case VgprP5:
   case VgprV4S32:
+  case VgprB32:
+  case VgprB64:
+  case VgprB96:
+  case VgprB128:
+  case VgprB256:
+  case VgprB512:
     return VgprRB;
 
   default:
@@ -202,17 +425,42 @@ void RegBankLegalizeHelper::applyMappingDst(
     case Sgpr16:
     case Sgpr32:
     case Sgpr64:
+    case SgprP1:
+    case SgprP3:
+    case SgprP4:
+    case SgprP5:
     case SgprV4S32:
     case Vgpr32:
     case Vgpr64:
     case VgprP1:
+    case VgprP3:
+    case VgprP4:
+    case VgprP5:
     case VgprV4S32: {
       assert(Ty == getTyFromID(MethodIDs[OpIdx]));
       assert(RB == getRBFromID(MethodIDs[OpIdx]));
       break;
     }
 
-    // uniform in vcc/vgpr: scalars and vectors
+    // sgpr and vgpr B-types
+    case SgprB32:
+    case SgprB64:
+    case SgprB96:
+    case SgprB128:
+    case SgprB256:
+    case SgprB512:
+    case VgprB32:
+    case VgprB64:
+    case VgprB96:
+    case VgprB128:
+    case VgprB256:
+    case VgprB512: {
+      assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
+      assert(RB == getRBFromID(MethodIDs[OpIdx]));
+      break;
+    }
+
+    // uniform in vcc/vgpr: scalars, vectors and B-types
     case UniInVcc: {
       assert(Ty == S1);
       assert(RB == SgprRB);
@@ -229,6 +477,17 @@ void RegBankLegalizeHelper::applyMappingDst(
       AMDGPU::buildReadAnyLaneDst(B, MI, RBI);
       break;
     }
+    case UniInVgprB32:
+    case UniInVgprB64:
+    case UniInVgprB96:
+    case UniInVgprB128:
+    case UniInVgprB256:
+    case UniInVgprB512: {
+      assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
+      assert(RB == SgprRB);
+      AMDGPU::buildReadAnyLaneDst(B, MI, RBI);
+      break;
+    }
 
     // sgpr trunc
     case Sgpr32Trunc: {
@@ -279,16 +538,34 @@ void RegBankLegalizeHelper::applyMappingSrc(
     case Sgpr16:
     case Sgpr32:
     case Sgpr64:
+    case SgprP1:
+    case SgprP3:
+    case SgprP4:
+    case SgprP5:
     case SgprV4S32: {
       assert(Ty == getTyFromID(MethodIDs[i]));
       assert(RB == getRBFromID(MethodIDs[i]));
       break;
     }
+    // sgpr B-types
+    case SgprB32:
+    case SgprB64:
+    case SgprB96:
+    case SgprB128:
+    case SgprB256:
+    case SgprB512: {
+      assert(Ty == getBTyFromID(MethodIDs[i], Ty));
+      assert(RB == getRBFromID(MethodIDs[i]));
+      break;
+    }
 
     // vgpr scalars, pointers and vectors
     case Vgpr32:
     case Vgpr64:
     case VgprP1:
+    case VgprP3:
+    case VgprP4:
+    case VgprP5:
     case VgprV4S32: {
       assert(Ty == getTyFromID(MethodIDs[i]));
       if (RB != VgprRB) {
@@ -298,6 +575,21 @@ void RegBankLegalizeHelper::applyMappingSrc(
       }
       break;
     }
+    // vgpr B-types
+    case VgprB32:
+    case VgprB64:
+    case VgprB96:
+    case VgprB128:
+    case VgprB256:
+    case VgprB512: {
+      assert(Ty == getBTyFromID(MethodIDs[i], Ty));
+      if (RB != VgprRB) {
+        auto CopyToVgpr =
+            B.buildCopy(createVgpr(getBTyFromID(MethodIDs[i], Ty)), Reg);
+        Op.setReg(CopyToVgpr.getReg(0));
+      }
+      break;
+    }
 
     // sgpr and vgpr scalars with extend
     case Sgpr32AExt: {
@@ -372,7 +664,7 @@ void RegBankLegalizeHelper::applyMappingPHI(MachineInstr &MI) {
   // We accept all types that can fit in some register class.
   // Uniform G_PHIs have all sgpr registers.
   // Divergent G_PHIs have vgpr dst but inputs can be sgpr or vgpr.
-  if (Ty == LLT::scalar(32)) {
+  if (Ty == LLT::scalar(32) || Ty == LLT::pointer(4, 64)) {
     return;
   }
 
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.h b/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.h
index e23dfcebe3fe3f..c409df54519c5c 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.h
@@ -92,6 +92,7 @@ class RegBankLegalizeHelper {
                               SmallSet<Register, 4> &SGPROperandRegs);
 
   LLT getTyFromID(RegBankLLTMapingApplyID ID);
+  LLT getBTyFromID(RegBankLLTMapingApplyID ID, LLT Ty);
 
   const RegisterBank *getRBFromID(RegBankLLTMapingApplyID ID);
 
@@ -104,9 +105,9 @@ class RegBankLegalizeHelper {
                   const SmallVectorImpl<RegBankLLTMapingApplyID> &MethodIDs,
                   SmallSet<Register, 4> &SGPRWaterfallOperandRegs);
 
-  unsigned setBufferOffsets(MachineIRBuilder &B, Register CombinedOffset,
-                            Register &VOffsetReg, Register &SOffsetReg,
-                            int64_t &InstOffsetVal, Align Alignment);
+  void splitLoad(MachineInstr &MI, ArrayRef<LLT> LLTBreakdown,
+                 LLT MergeTy = LLT());
+  void widenLoad(MachineInstr &MI, LLT WideTy, LLT MergeTy = LLT());
 
   void lower(MachineInstr &MI, const RegBankLLTMapping &Mapping,
              SmallSet<Register, 4> &SGPRWaterfallOperandRegs);
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeRules.cpp b/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeRules.cpp
index 1266f99c79c395..895a596cf84f40 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeRules.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeRules.cpp
@@ -14,9 +14,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "AMDGPURBLegalizeRules.h"
+#include "AMDGPUInstrInfo.h"
 #include "GCNSubtarget.h"
 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
 #include "llvm/IR/IntrinsicsAMDGPU.h"
+#include "llvm/Support/AMDGPUAddrSpace.h"
 
 using namespace llvm;
 using namespace AMDGPU;
@@ -47,6 +49,24 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
     return MRI.getType(Reg) == LLT::scalar(64);
   case P1:
     return MRI.getType(Reg) == LLT::pointer(1, 64);
+  case P3:
+    return MRI.getType(Reg) == LLT::pointer(3, 32);
+  case P4:
+    return MRI.getType(Reg) == LLT::pointer(4, 64);
+  case P5:
+    return MRI.getType(Reg) == LLT::pointer(5, 32);
+  case B32:
+    return MRI.getType(Reg).getSizeInBits() == 32;
+  case B64:
+    return MRI.getType(Reg).getSizeInBits() == 64;
+  case B96:
+    return MRI.getType(Reg).getSizeInBits() == 96;
+  case B128:
+    return MRI.getType(Reg).getSizeInBits() == 128;
+  case B256:
+    return MRI.getType(Reg).getSizeInBits() == 256;
+  case B512:
+    return MRI.getType(Reg).getSizeInBits() == 512;
 
   case UniS1:
     return MRI.getType(Reg) == LLT::scalar(1) && MUI.isUniform(Reg);
@@ -56,6 +76,26 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
     return MRI.getType(Reg) == LLT::scalar(32) && MUI.isUniform(Reg);
   case UniS64:
     return MRI.getType(Reg) == LLT::scalar(64) && MUI.isUniform(Reg);
+  case UniP1:
+    return MRI.getType(Reg) == LLT::pointer(1, 64) && MUI.isUniform(Reg);
+  case UniP3:
+    return MRI.getType(Reg) == LLT::pointer(3, 32) && MUI.isUniform(Reg);
+  case UniP4:
+    return MRI.getType(Reg) == LLT::pointer(4, 64) && MUI.isUniform(Reg);
+  case UniP5:
+    return MRI.getType(Reg) == LLT::pointer(5, 32) && MUI.isUniform(Reg);
+  case UniB32:
+    return MRI.getType(Reg).getSizeInBits() == 32 && MUI.isUniform(Reg);
+  case UniB64:
+    return MRI.getType(Reg).getSizeInBits() == 64 && MUI.isUniform(Reg);
+  case UniB96:
+    return MRI.getType(Reg).getSizeInBits() == 96 && MUI.isUniform(Reg);
+  case UniB128:
+    return MRI.getType(Reg).getSizeInBits() == 128 && MUI.isUniform(Reg);
+  case UniB256:
+    return MRI.getType(Reg).getSizeInBits() == 256 && MUI.isUniform(Reg);
+  case UniB512:
+    return MRI.getType(Reg).getSizeInBits() == 512 && MUI.isUniform(Reg);
 
   case DivS1:
     return MRI.getType(Reg) == LLT::scalar(1) && MUI.isDivergent(Reg);
@@ -65,6 +105,24 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
     return MRI.getType(Reg) == LLT::scalar(64) && MUI.isDivergent(Reg);
   case DivP1:
     return MRI.getType(Reg) == LLT::pointer(1, 64) && MUI.isDivergent(Reg);
+  case DivP3:
+    return MRI.getType(Reg) == LLT::pointer(3, 32) && MUI.isDivergent(Reg);
+  case DivP4:
+    return MRI.getType(Reg) == LLT::pointer(4, 64) && MUI.isDivergent(Reg);
+  case DivP5:
+    return MRI.getType(Reg) == LLT::pointer(5, 32) && MUI.isDivergent(Reg);
+  case DivB32:
+    return MRI.getType(Reg).getSizeInBits() == 32 && MUI.isDivergent(Reg);
+  case DivB64:
+    return MRI.getType(Reg).getSizeInBits() == 64 && MUI.isDivergent(Reg);
+  case DivB96:
+    return MRI.getType(Reg).getSizeInBits() == 96 && MUI.isDivergent(Reg);
+  case DivB128:
+    return MRI.getType(Reg).getSizeInBits() == 128 && MUI.isDivergent(Reg);
+  case DivB256:
+    return MRI.getType(Reg).getSizeInBits() == 256 && MUI.isDivergent(Reg);
+  case DivB512:
+    return MRI.getType(Reg).getSizeInBits() == 512 && MUI.isDivergent(Reg);
 
   case _:
     return true;
@@ -124,6 +182,22 @@ UniformityLLTOpPredicateID LLTToId(LLT Ty) {
   return _;
 }
 
+UniformityLLTOpPredicateID LLTToBId(LLT Ty) {
+  if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
+      Ty == LLT::pointer(3, 32) || Ty == LLT::pointer(5, 32) ||
+      Ty == LLT::pointer(6, 32))
+    return B32;
+  if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
+      Ty == LLT::fixed_vector(4, 16) || Ty == LLT::pointer(1, 64) ||
+      Ty == LLT::pointer(4, 64))
+    return B64;
+  if (Ty == LLT::fixed_vector(3, 32))
+    return B96;
+  if (Ty == LLT::fixed_vector(4, 32))
+    return B128;
+  return _;
+}
+
 const RegBankLLTMapping &
 SetOfRulesForOpcode::findMappingForMI(const MachineInstr &MI,
                                       const MachineRegisterInfo &MRI,
@@ -134,7 +208,12 @@ SetOfRulesForOpcode::findMappingForMI(const MachineInstr &MI,
   // returned which results in failure, does not search "Slow Rules".
   if (FastTypes != No) {
     Register Reg = MI.getOperand(0).getReg();
-    int Slot = getFastPredicateSlot(LLTToId(MRI.getType(Reg)));
+    int Slot;
+    if (FastTypes == StandardB)
+      Slot = getFastPredicateSlot(LLTToBId(MRI.getType(Reg)));
+    else
+      Slot = getFastPredicateSlot(LLTToId(MRI.getType(Reg)));
+
     if (Slot != -1) {
       if (MUI.isUniform(Reg))
         return Uni[Slot];
@@ -184,6 +263,19 @@ int SetOfRulesForOpcode::getFastPredicateSlot(
     default:
       return -1;
     }
+  case StandardB:
+    switch (Ty) {
+    case B32:
+      return 0;
+    case B64:
+      return 1;
+    case B96:
+      return 2;
+    case B128:
+      return 3;
+    default:
+      return -1;
+    }
   case Vector:
     switch (Ty) {
     case S32:
@@ -236,6 +328,127 @@ RegBankLegalizeRules::getRulesForOpc(MachineInstr &MI) const {
   return GRules.at(GRulesAlias.at(Opc));
 }
 
+// Syntactic sugar wrapper for predicate lambda that enables '&&', '||' and '!'.
+class Predicate {
+public:
+  struct Elt {
+    // Save formula composed of Pred, '&&', '||' and '!' as a jump table.
+    // Sink ! to Pred. For example !((A && !B) || C) -> (!A || B) && !C
+    // Sequences of && and || will be represented by jumps, for example:
+    // (A && B && ... X) or (A && B && ... X) || Y
+    //   A == true jump to B
+    //   A == false jump to end or Y, result is A(false) or Y
+    // (A || B || ... X) or (A || B || ... X) && Y
+    //   A == true jump to end or Y, result is B(true) or Y
+    //   A == false jump B
+    // Notice that when negating expression, we apply simply flip Neg on each
+    // Pred and swap TJumpOffset and FJumpOffset (&& becomes ||, || becomes &&)....
[truncated]

LLT V4S32 = LLT::fixed_vector(4, S32);
LLT V2S64 = LLT::fixed_vector(2, S64);

if (DstTy == LLT::fixed_vector(8, S32))
Review comment from a Contributor:

Can you rework this to be a function that returns the type to use for the load?
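
One possible shape for such a helper (hypothetical, not code from the patch): derive the per-part LLT from the destination type and let the caller build the breakdown, rather than enumerating every destination type inline. Sketched as it would sit in AMDGPURBLegalizeHelper.cpp, where LLT is already available; the irregular cases (s96, <3 x s32>, <6 x s16>) would still need their own handling.

// Hypothetical helper: pick the LLT for each piece of a split load based on
// the destination type. Covers only the power-of-two cases.
static LLT getSplitLoadPartTy(LLT DstTy) {
  if (!DstTy.isVector())
    return LLT::scalar(128); // e.g. s256 splits into two s128 loads
  // Split vectors into 128-bit pieces that keep the element type, e.g.
  // <8 x s32> -> <4 x s32>, <16 x s16> -> <8 x s16>, <4 x s64> -> <2 x s64>.
  LLT EltTy = DstTy.getElementType();
  unsigned EltSize = EltTy.getSizeInBits();
  return LLT::fixed_vector(128 / EltSize, EltTy);
}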

@petar-avramovic changed the title from "AMDGPU/GlobalISel: RBLegalize rules for load" to "AMDGPU/GlobalISel: RegBankLegalize rules for load" on Oct 23, 2024
@petar-avramovic (Collaborator, Author):
ping

@nhaehnle (Collaborator) left a comment:

I have a bunch of comments, but apart from that the change LGTM

// A == true jump to B
// A == false jump to end or Y, result is A(false) or Y
// (A || B || ... X) or (A || B || ... X) && Y
// A == true jump to end or Y, result is B(true) or Y
nhaehnle:

Suggested change:
- // A == true jump to end or Y, result is B(true) or Y
+ // A == true jump to end or Y, result is A(true) or Y

unsigned FJumpOffset;
};

SmallVector<Elt, 8> Expression;
nhaehnle:

Make this and Elt private. Elt is quite subtle and should be hidden as an implementation detail.

Expression.push_back({Pred, false, 1, 1});
};

Predicate(SmallVectorImpl<Elt> &Expr) { Expression.swap(Expr); };
nhaehnle:

I find having a constructor that is destructive on its argument in this way to be quite surprising. It is against good patterns in modern C++. Better have it take an rvalue reference (&&-reference). Then its callers have to explicitly std::move, but that makes it more obvious what happens at the call sites.

Also, this constructor should be private just like Elt.

return Result;
};

Predicate operator!() {
nhaehnle:

Should be a const method.

return Predicate(NegExpression);
};

Predicate operator&&(Predicate &RHS) {
nhaehnle:

Similar to the constructor above, the baseline of this should be a const method that accepts a const reference as an argument.

Do you have a good argument for the rvalue-argument overload below? If not, please just remove it.

The same applies to operator||.
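
Putting the last few comments together, the class interface could end up shaped roughly like this. This is a sketch only: the stored predicate is assumed to be a std::function over MachineInstr (the actual field type is not visible in the truncated diff), and the operator bodies are omitted.

#include "llvm/ADT/SmallVector.h"
#include <functional>
#include <utility>

namespace llvm {
class MachineInstr;
} // namespace llvm

class Predicate {
  // Implementation details stay private, as requested in review.
  struct Elt {
    std::function<bool(const llvm::MachineInstr &)> Pred; // assumed type
    bool Neg;
    unsigned TJumpOffset;
    unsigned FJumpOffset;
  };

  llvm::SmallVector<Elt, 8> Expression;

  // Private rvalue-reference constructor: callers must std::move explicitly.
  Predicate(llvm::SmallVector<Elt, 8> &&Expr) : Expression(std::move(Expr)) {}

public:
  Predicate(std::function<bool(const llvm::MachineInstr &)> Pred) {
    Expression.push_back({std::move(Pred), false, 1, 1});
  }

  // Composition does not modify its operands, so the operators are const and
  // take const references. Bodies as in the patch, omitted here.
  Predicate operator!() const;
  Predicate operator&&(const Predicate &RHS) const;
  Predicate operator||(const Predicate &RHS) const;
};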

@@ -290,7 +504,86 @@ RegBankLegalizeRules::RegBankLegalizeRules(const GCNSubtarget &_ST,
.Any({{UniS64, S32}, {{Sgpr64}, {Sgpr32}, Ext32To64}})
.Any({{DivS64, S32}, {{Vgpr64}, {Vgpr32}, Ext32To64}});

addRulesForGOpcs({G_LOAD}).Any({{DivS32, DivP1}, {{Vgpr32}, {VgprP1}}});
bool hasUnAlignedLoads = ST->getGeneration() >= AMDGPUSubtarget::GFX12;
nhaehnle:

"unaligned" is a single word:

Suggested change:
- bool hasUnAlignedLoads = ST->getGeneration() >= AMDGPUSubtarget::GFX12;
+ bool hasUnalignedLoads = ST->getGeneration() >= AMDGPUSubtarget::GFX12;

Comment on lines 208 to 223
if (Size / 128 == 2)
splitLoad(MI, {B128, B128});
if (Size / 128 == 4)
splitLoad(MI, {B128, B128, B128, B128});
nhaehnle:

I would expect this to be an else-if chain with an else llvm_unreachable at the end.
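
The restructuring being asked for would look roughly like this (a fragment in the patch's context, reusing its Size, MI, splitLoad and B128 names):

// Else-if chain with an unreachable default, per the review comment.
if (Size / 128 == 2)
  splitLoad(MI, {B128, B128});
else if (Size / 128 == 4)
  splitLoad(MI, {B128, B128, B128, B128});
else
  llvm_unreachable("SplitLoad size not supported");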

Comment on lines +553 to +524
auto isUL = !isAtomicMMO && isUniMMO && (isConst || !isVolatileMMO) &&
(isConst || isInvMMO || isNoClobberMMO);
nhaehnle:

What's the logic behind the isConst || !isVolatileMMO part of this predicate? const && volatile doesn't make much sense to me, so why isn't this just !isVolatileMMO?

petar-avramovic (Author):

It is copied from the current implementation in AMDGPURegisterBankInfo::isScalarLoadLegal.
isConst checks the address space in the MMO, which can be different from the address space of the pointer operand.
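
For reference, a paraphrased sketch of what that check looks at: the address space recorded on the MachineMemOperand rather than the pointer operand's type, assuming the AMDGPUAS constants used elsewhere in the backend.

// Illustrative only: the MMO carries its own address space, and that is what
// decides whether the access is to constant memory.
const MachineMemOperand *MMO = *MI.memoperands_begin();
unsigned AS = MMO->getAddrSpace();
bool IsConstMMO =
    AS == AMDGPUAS::CONSTANT_ADDRESS || AS == AMDGPUAS::CONSTANT_ADDRESS_32BIT;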

@nhaehnle (Collaborator) left a comment:

LGTM

Base automatically changed from users/petar-avramovic/new-rbs-rb-legalize to main January 24, 2025 11:12
@petar-avramovic (Collaborator, Author) commented Jan 24, 2025

Merge activity

  • Jan 24, 6:34 AM EST: A user started a stack merge that includes this pull request via Graphite.
  • Jan 24, 6:36 AM EST: A user merged this pull request with Graphite.

@petar-avramovic merged commit 4831fa8 into main on Jan 24, 2025
5 of 7 checks passed
@petar-avramovic deleted the users/petar-avramovic/new-rbs-rb-load-rules branch on January 24, 2025 11:36