[AMDGPU] Handle natively unsupported types in addrspace(7) lowering #110572


Merged
krzysz00 merged 16 commits into llvm:main on Jan 20, 2025

Conversation

krzysz00
Contributor

The current lowering for ptr addrspace(7) assumed that the instruction selector can handle arbitrary LLVM types, which is not the case. Code generation can't deal with:

  • Values that aren't 8, 16, 32, 64, 96, or 128 bits long
  • Aggregates (this commit only handles arrays of scalars, more may come)
  • Vectors of more than one byte
  • 3-word values that aren't a vector of 3 32-bit values (for example, a <6 x half>)

This commit adds a buffer contents type legalizer that inserts the bitcasts, zero-extensions, and splits into subcomponents needed to convert a load or store operation into one that can be successfully lowered through code generation.
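
As a rough illustration of the kind of rewrite involved (a sketch with made-up value names, not output copied from the pass), a load of a 3-word type that isn't <3 x i32>:

  %v = load <6 x half>, ptr addrspace(7) %p

becomes, after contents-type legalization (the pointer itself is still split and converted into buffer intrinsics by the later phases):

  %v.loadable = load <3 x i32>, ptr addrspace(7) %p
  %v = bitcast <3 x i32> %v.loadable to <6 x half>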

In the long run, some of the involved bitcasts (though potentially not the buffer operation splitting) ought to be handled by the instruction legalizer, but SelectionDAG makes this difficult.

It also takes advantage of the new nuw flag on getelementptr when lowering GEPs to offset additions.
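
As a sketch of what that means (illustrative names, assuming the usual split of the fat pointer into a resource part and a 32-bit offset part):

  %q = getelementptr nuw i8, ptr addrspace(7) %p, i32 %idx

can be lowered so that the offset arithmetic carries the flag:

  %q.off = add nuw i32 %p.off, %idx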

We don't currently plumb through nsw on GEPs since that should likely be a separate change and would require declaring what we mean by "the address" in the context of the GEP guarantees.

@llvmbot
Member

llvmbot commented Sep 30, 2024

@llvm/pr-subscribers-llvm-globalisel
@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-backend-amdgpu

Author: Krzysztof Drewniak (krzysz00)

Changes

The current lowering for ptr addrspace(7) assumed that the instruction selector can handle arbitrary LLVM types, which is not the case. Code generation can't deal with:

  • Values that aren't 8, 16, 32, 64, 96, or 128 bits long
  • Aggregates (this commit only handles arrays of scalars, more may come)
  • Vectors of more than one byte
  • 3-word values that aren't a vector of 3 32-bit values (for example, a <6 x half>)

This commit adds a buffer contents type legalizer that inserts the bitcasts, zero-extensions, and splits into subcomponents needed to convert a load or store operation into one that can be successfully lowered through code generation.

In the long run, some of the involved bitcasts (though potentially not the buffer operation splitting) ought to be handled by the instruction legalizer, but SelectionDAG makes this difficult.

It also takes advantage of the new nuw flag on getelementptr when lowering GEPs to offset additions.

We don't currently plumb through nsw on GEPs since that should likely be a separate change and would require declaring what we mean by "the address" in the context of the GEP guarantees.


Patch is 307.63 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/110572.diff

9 Files Affected:

  • (modified) llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp (+40-16)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.h (+5-7)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPULowerBufferFatPointers.cpp (+455-8)
  • (modified) llvm/lib/Target/AMDGPU/SIRegisterInfo.td (+1-1)
  • (modified) llvm/lib/Transforms/Utils/Local.cpp (+12)
  • (added) llvm/test/CodeGen/AMDGPU/buffer-fat-pointers-contents-legalization.ll (+4871)
  • (modified) llvm/test/CodeGen/AMDGPU/lower-buffer-fat-pointers-calls.ll (+7-2)
  • (modified) llvm/test/CodeGen/AMDGPU/lower-buffer-fat-pointers-contents-legalization.ll (+346-84)
  • (modified) llvm/test/CodeGen/AMDGPU/lower-buffer-fat-pointers-unoptimized-debug-data.ll (+6-1)
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
index 271c8d45fd4a21..1da029444027e0 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
@@ -5794,8 +5794,9 @@ Register AMDGPULegalizerInfo::handleD16VData(MachineIRBuilder &B,
   return Reg;
 }
 
-Register AMDGPULegalizerInfo::fixStoreSourceType(
-  MachineIRBuilder &B, Register VData, bool IsFormat) const {
+Register AMDGPULegalizerInfo::fixStoreSourceType(MachineIRBuilder &B,
+                                                 Register VData, LLT MemTy,
+                                                 bool IsFormat) const {
   MachineRegisterInfo *MRI = B.getMRI();
   LLT Ty = MRI->getType(VData);
 
@@ -5805,6 +5806,10 @@ Register AMDGPULegalizerInfo::fixStoreSourceType(
   if (hasBufferRsrcWorkaround(Ty))
     return castBufferRsrcToV4I32(VData, B);
 
+  if (shouldBitcastLoadStoreType(ST, Ty, MemTy) || Ty.isPointerVector()) {
+    Ty = getBitcastRegisterType(Ty);
+    VData = B.buildBitcast(Ty, VData).getReg(0);
+  }
   // Fixup illegal register types for i8 stores.
   if (Ty == LLT::scalar(8) || Ty == S16) {
     Register AnyExt = B.buildAnyExt(LLT::scalar(32), VData).getReg(0);
@@ -5822,22 +5827,26 @@ Register AMDGPULegalizerInfo::fixStoreSourceType(
 }
 
 bool AMDGPULegalizerInfo::legalizeBufferStore(MachineInstr &MI,
-                                              MachineRegisterInfo &MRI,
-                                              MachineIRBuilder &B,
+                                              LegalizerHelper &Helper,
                                               bool IsTyped,
                                               bool IsFormat) const {
+  MachineIRBuilder &B = Helper.MIRBuilder;
+  MachineRegisterInfo &MRI = *B.getMRI();
+
   Register VData = MI.getOperand(1).getReg();
   LLT Ty = MRI.getType(VData);
   LLT EltTy = Ty.getScalarType();
   const bool IsD16 = IsFormat && (EltTy.getSizeInBits() == 16);
   const LLT S32 = LLT::scalar(32);
 
-  VData = fixStoreSourceType(B, VData, IsFormat);
-  castBufferRsrcArgToV4I32(MI, B, 2);
-  Register RSrc = MI.getOperand(2).getReg();
-
   MachineMemOperand *MMO = *MI.memoperands_begin();
   const int MemSize = MMO->getSize().getValue();
+  LLT MemTy = MMO->getMemoryType();
+
+  VData = fixStoreSourceType(B, VData, MemTy, IsFormat);
+
+  castBufferRsrcArgToV4I32(MI, B, 2);
+  Register RSrc = MI.getOperand(2).getReg();
 
   unsigned ImmOffset;
 
@@ -5930,10 +5939,13 @@ static void buildBufferLoad(unsigned Opc, Register LoadDstReg, Register RSrc,
 }
 
 bool AMDGPULegalizerInfo::legalizeBufferLoad(MachineInstr &MI,
-                                             MachineRegisterInfo &MRI,
-                                             MachineIRBuilder &B,
+                                             LegalizerHelper &Helper,
                                              bool IsFormat,
                                              bool IsTyped) const {
+  MachineIRBuilder &B = Helper.MIRBuilder;
+  MachineRegisterInfo &MRI = *B.getMRI();
+  GISelChangeObserver &Observer = Helper.Observer;
+
   // FIXME: Verifier should enforce 1 MMO for these intrinsics.
   MachineMemOperand *MMO = *MI.memoperands_begin();
   const LLT MemTy = MMO->getMemoryType();
@@ -5982,9 +5994,21 @@ bool AMDGPULegalizerInfo::legalizeBufferLoad(MachineInstr &MI,
   // Make addrspace 8 pointers loads into 4xs32 loads here, so the rest of the
   // logic doesn't have to handle that case.
   if (hasBufferRsrcWorkaround(Ty)) {
+    Observer.changingInstr(MI);
     Ty = castBufferRsrcFromV4I32(MI, B, MRI, 0);
+    Observer.changedInstr(MI);
     Dst = MI.getOperand(0).getReg();
+    B.setInsertPt(B.getMBB(), MI);
   }
+  if (shouldBitcastLoadStoreType(ST, Ty, MemTy) || Ty.isPointerVector()) {
+    Ty = getBitcastRegisterType(Ty);
+    Observer.changingInstr(MI);
+    Helper.bitcastDst(MI, Ty, 0);
+    Observer.changedInstr(MI);
+    Dst = MI.getOperand(0).getReg();
+    B.setInsertPt(B.getMBB(), MI);
+  }
+
   LLT EltTy = Ty.getScalarType();
   const bool IsD16 = IsFormat && (EltTy.getSizeInBits() == 16);
   const bool Unpacked = ST.hasUnpackedD16VMem();
@@ -7364,17 +7388,17 @@ bool AMDGPULegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
   case Intrinsic::amdgcn_raw_ptr_buffer_store:
   case Intrinsic::amdgcn_struct_buffer_store:
   case Intrinsic::amdgcn_struct_ptr_buffer_store:
-    return legalizeBufferStore(MI, MRI, B, false, false);
+    return legalizeBufferStore(MI, Helper, false, false);
   case Intrinsic::amdgcn_raw_buffer_store_format:
   case Intrinsic::amdgcn_raw_ptr_buffer_store_format:
   case Intrinsic::amdgcn_struct_buffer_store_format:
   case Intrinsic::amdgcn_struct_ptr_buffer_store_format:
-    return legalizeBufferStore(MI, MRI, B, false, true);
+    return legalizeBufferStore(MI, Helper, false, true);
   case Intrinsic::amdgcn_raw_tbuffer_store:
   case Intrinsic::amdgcn_raw_ptr_tbuffer_store:
   case Intrinsic::amdgcn_struct_tbuffer_store:
   case Intrinsic::amdgcn_struct_ptr_tbuffer_store:
-    return legalizeBufferStore(MI, MRI, B, true, true);
+    return legalizeBufferStore(MI, Helper, true, true);
   case Intrinsic::amdgcn_raw_buffer_load:
   case Intrinsic::amdgcn_raw_ptr_buffer_load:
   case Intrinsic::amdgcn_raw_atomic_buffer_load:
@@ -7383,17 +7407,17 @@ bool AMDGPULegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
   case Intrinsic::amdgcn_struct_ptr_buffer_load:
   case Intrinsic::amdgcn_struct_atomic_buffer_load:
   case Intrinsic::amdgcn_struct_ptr_atomic_buffer_load:
-    return legalizeBufferLoad(MI, MRI, B, false, false);
+    return legalizeBufferLoad(MI, Helper, false, false);
   case Intrinsic::amdgcn_raw_buffer_load_format:
   case Intrinsic::amdgcn_raw_ptr_buffer_load_format:
   case Intrinsic::amdgcn_struct_buffer_load_format:
   case Intrinsic::amdgcn_struct_ptr_buffer_load_format:
-    return legalizeBufferLoad(MI, MRI, B, true, false);
+    return legalizeBufferLoad(MI, Helper, true, false);
   case Intrinsic::amdgcn_raw_tbuffer_load:
   case Intrinsic::amdgcn_raw_ptr_tbuffer_load:
   case Intrinsic::amdgcn_struct_tbuffer_load:
   case Intrinsic::amdgcn_struct_ptr_tbuffer_load:
-    return legalizeBufferLoad(MI, MRI, B, true, true);
+    return legalizeBufferLoad(MI, Helper, true, true);
   case Intrinsic::amdgcn_raw_buffer_atomic_swap:
   case Intrinsic::amdgcn_raw_ptr_buffer_atomic_swap:
   case Intrinsic::amdgcn_struct_buffer_atomic_swap:
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.h b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.h
index 84470dc75b60ef..86c15197805d23 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.h
@@ -195,15 +195,13 @@ class AMDGPULegalizerInfo final : public LegalizerInfo {
 
   Register handleD16VData(MachineIRBuilder &B, MachineRegisterInfo &MRI,
                           Register Reg, bool ImageStore = false) const;
-  Register fixStoreSourceType(MachineIRBuilder &B, Register VData,
+  Register fixStoreSourceType(MachineIRBuilder &B, Register VData, LLT MemTy,
                               bool IsFormat) const;
 
-  bool legalizeBufferStore(MachineInstr &MI, MachineRegisterInfo &MRI,
-                           MachineIRBuilder &B, bool IsTyped,
-                           bool IsFormat) const;
-  bool legalizeBufferLoad(MachineInstr &MI, MachineRegisterInfo &MRI,
-                          MachineIRBuilder &B, bool IsFormat,
-                          bool IsTyped) const;
+  bool legalizeBufferStore(MachineInstr &MI, LegalizerHelper &Helper,
+                           bool IsTyped, bool IsFormat) const;
+  bool legalizeBufferLoad(MachineInstr &MI, LegalizerHelper &Helper,
+                          bool IsFormat, bool IsTyped) const;
   bool legalizeBufferAtomic(MachineInstr &MI, MachineIRBuilder &B,
                             Intrinsic::ID IID) const;
 
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULowerBufferFatPointers.cpp b/llvm/lib/Target/AMDGPU/AMDGPULowerBufferFatPointers.cpp
index 787747e6055805..831474c192526f 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULowerBufferFatPointers.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULowerBufferFatPointers.cpp
@@ -86,6 +86,25 @@
 // This phase also records intrinsics so that they can be remangled or deleted
 // later.
 //
+// ## Buffer contents type legalization
+//
+// The underlying buffer intrinsics only support types up to 128 bits long,
+// and don't support complex types. If buffer operations were
+// standard pointer operations that could be represented as MIR-level loads,
+// this would be handled by the various legalization schemes in instruction
+// selection. However, because we have to do the conversion from `load` and
+// `store` to intrinsics at LLVM IR level, we must perform that legalization
+// ourselves.
+//
+// This involves a combination of
+// - Converting arrays to vectors where possible
+// - Zero-extending things to fill a whole number of bytes
+// - Casting values of types that don't neatly correspond to supported machine
+//   values (for example, an i96 or i256) into ones that would work
+//   (like <3 x i32> and <8 x i32>, respectively)
+// - Splitting values that are too long (such as aforementioned <8 x i32>) into
+//   multiple operations.
 //
 // ## Splitting pointer structs
 //
@@ -218,6 +237,7 @@
 #include "llvm/IR/ReplaceConstant.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
+#include "llvm/Support/Alignment.h"
 #include "llvm/Support/AtomicOrdering.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -551,7 +571,6 @@ bool StoreFatPtrsAsIntsVisitor::visitLoadInst(LoadInst &LI) {
   auto *NLI = cast<LoadInst>(LI.clone());
   NLI->mutateType(IntTy);
   NLI = IRB.Insert(NLI);
-  copyMetadataForLoad(*NLI, LI);
   NLI->takeName(&LI);
 
   Value *CastBack = intsToFatPtrs(NLI, IntTy, Ty, NLI->getName());
@@ -576,6 +595,434 @@ bool StoreFatPtrsAsIntsVisitor::visitStoreInst(StoreInst &SI) {
   return true;
 }
 
+namespace {
+/// Convert loads/stores of types that the buffer intrinsics can't handle into
+/// one or more such loads/stores that consist of legal types.
+///
+/// Do this by
+/// 1. Converting arrays of non-aggregate, byte-sized types into their
+/// corresponding vectors
+/// 2. Bitcasting unsupported types, namely overly-long scalars and byte
+/// vectors, into vectors of supported types.
+/// 3. Splitting up excessively long reads/writes into multiple operations.
+///
+/// Note that this doesn't handle complex data structures, but, in the future,
+/// the aggregate load splitter from SROA could be refactored to allow for that
+/// case.
+class LegalizeBufferContentTypesVisitor
+    : public InstVisitor<LegalizeBufferContentTypesVisitor, bool> {
+  friend class InstVisitor<LegalizeBufferContentTypesVisitor, bool>;
+
+  IRBuilder<> IRB;
+
+  const DataLayout &DL;
+
+  /// If T is [N x U], where U is a scalar type, return the vector type
+  /// <N x U>, otherwise, return T.
+  Type *scalarArrayTypeAsVector(Type *MaybeArrayType);
+  Value *arrayToVector(Value *V, Type *TargetType, StringRef Name);
+  Value *vectorToArray(Value *V, Type *OrigType, StringRef Name);
+
+  /// Break up the loads of a struct into the loads of its components
+
+  /// Convert a vector or scalar type that can't be operated on by buffer
+  /// intrinsics to one that would be legal through bitcasts and/or truncation.
+  /// Uses the wider of i32, i16, or i8 where possible.
+  Type *legalNonAggregateFor(Type *T);
+  Value *makeLegalNonAggregate(Value *V, Type *TargetType, StringRef Name);
+  Value *makeIllegalNonAggregate(Value *V, Type *OrigType, StringRef Name);
+
+  struct Slice {
+    unsigned Offset;
+    unsigned Length;
+    Slice(unsigned Offset, unsigned Length) : Offset(Offset), Length(Length) {}
+  };
+  // Return the [offset, length] pairs into which `T` needs to be cut to form
+  // legal buffer load or store operations. Clears `Slices`. Creates an empty
+  // `Slices` for non-vector inputs and creates one slice if no slicing will be
+  // needed.
+  void getSlices(Type *T, SmallVectorImpl<Slice> &Slices);
+
+  Value *extractSlice(Value *Vec, Slice S, StringRef Name);
+  Value *insertSlice(Value *Whole, Value *Part, Slice S, StringRef Name);
+
+  // In most cases, return `LegalType`. However, when given an input that would
+  // normally be a legal type for the buffer intrinsics to return but that isn't
+  // hooked up through SelectionDAG, return a type of the same width that can be
+  // used with the relevant intrinsics. Specifically, handle the cases:
+  // - <1 x T> => T for all T
+  // - <N x i8> <=> i16, i32, 2xi32, 4xi32 (as needed)
+  // - <N x T> where T is under 32 bits and the total size is 96 bits <=> <3 x
+  // i32>
+  Type *intrinsicTypeFor(Type *LegalType);
+
+  bool visitInstruction(Instruction &I) { return false; }
+  bool visitLoadInst(LoadInst &LI);
+  bool visitStoreInst(StoreInst &SI);
+
+public:
+  LegalizeBufferContentTypesVisitor(const DataLayout &DL, LLVMContext &Ctx)
+      : IRB(Ctx), DL(DL) {}
+  bool processFunction(Function &F);
+};
+} // namespace
+
+Type *LegalizeBufferContentTypesVisitor::scalarArrayTypeAsVector(Type *T) {
+  ArrayType *AT = dyn_cast<ArrayType>(T);
+  if (!AT)
+    return T;
+  Type *ET = AT->getElementType();
+  if (!ET->isSingleValueType() || isa<VectorType>(ET))
+    report_fatal_error(
+        "loading non-scalar arrays from buffer fat pointers is unimplemented");
+  if (!DL.typeSizeEqualsStoreSize(AT))
+    report_fatal_error(
+        "loading padded arrays from buffer fat pinters is unimplemented");
+  return FixedVectorType::get(ET, AT->getNumElements());
+}
+
+Value *LegalizeBufferContentTypesVisitor::arrayToVector(Value *V,
+                                                        Type *TargetType,
+                                                        StringRef Name) {
+  Value *VectorRes = PoisonValue::get(TargetType);
+  auto *VT = cast<FixedVectorType>(TargetType);
+  unsigned EC = VT->getNumElements();
+  for (auto I : iota_range<unsigned>(0, EC, /*Inclusive=*/false)) {
+    Value *Elem = IRB.CreateExtractValue(V, I, Name + ".elem." + Twine(I));
+    VectorRes = IRB.CreateInsertElement(VectorRes, Elem, I,
+                                        Name + ".as.vec." + Twine(I));
+  }
+  return VectorRes;
+}
+
+Value *LegalizeBufferContentTypesVisitor::vectorToArray(Value *V,
+                                                        Type *OrigType,
+                                                        StringRef Name) {
+  Value *ArrayRes = PoisonValue::get(OrigType);
+  ArrayType *AT = cast<ArrayType>(OrigType);
+  unsigned EC = AT->getNumElements();
+  for (auto I : iota_range<unsigned>(0, EC, /*Inclusive=*/false)) {
+    Value *Elem = IRB.CreateExtractElement(V, I, Name + ".elem." + Twine(I));
+    ArrayRes = IRB.CreateInsertValue(ArrayRes, Elem, I,
+                                     Name + ".as.array." + Twine(I));
+  }
+  return ArrayRes;
+}
+
+Type *LegalizeBufferContentTypesVisitor::legalNonAggregateFor(Type *T) {
+  TypeSize Size = DL.getTypeStoreSizeInBits(T);
+  // Implicitly zero-extend to the next byte if needed
+  if (!DL.typeSizeEqualsStoreSize(T))
+    T = IRB.getIntNTy(Size.getFixedValue());
+  auto *VT = dyn_cast<VectorType>(T);
+  Type *ElemTy = T;
+  if (VT) {
+    ElemTy = VT->getElementType();
+  }
+  if (isa<PointerType>(ElemTy))
+    return T; // Pointers are always big enough
+  unsigned ElemSize = DL.getTypeSizeInBits(ElemTy).getFixedValue();
+  if (isPowerOf2_32(ElemSize) && ElemSize >= 16 && ElemSize <= 128) {
+    // [vectors of] anything that's 16/32/64/128 bits can be cast and split into
+    // legal buffer operations.
+    return T;
+  }
+  Type *BestVectorElemType = nullptr;
+  if (Size.isKnownMultipleOf(32))
+    BestVectorElemType = IRB.getInt32Ty();
+  else if (Size.isKnownMultipleOf(16))
+    BestVectorElemType = IRB.getInt16Ty();
+  else
+    BestVectorElemType = IRB.getInt8Ty();
+  unsigned NumCastElems =
+      Size.getFixedValue() / BestVectorElemType->getIntegerBitWidth();
+  if (NumCastElems == 1)
+    return BestVectorElemType;
+  return FixedVectorType::get(BestVectorElemType, NumCastElems);
+}
+
+Value *LegalizeBufferContentTypesVisitor::makeLegalNonAggregate(
+    Value *V, Type *TargetType, StringRef Name) {
+  Type *SourceType = V->getType();
+  if (DL.getTypeSizeInBits(SourceType) != DL.getTypeSizeInBits(TargetType)) {
+    Type *ShortScalarTy =
+        IRB.getIntNTy(DL.getTypeSizeInBits(SourceType).getFixedValue());
+    Type *ByteScalarTy =
+        IRB.getIntNTy(DL.getTypeSizeInBits(TargetType).getFixedValue());
+    Value *AsScalar = IRB.CreateBitCast(V, ShortScalarTy, Name + ".as.scalar");
+    Value *Zext = IRB.CreateZExt(AsScalar, ByteScalarTy, Name + ".zext");
+    V = Zext;
+    SourceType = ByteScalarTy;
+  }
+  if (SourceType == TargetType)
+    return V;
+  return IRB.CreateBitCast(V, TargetType, Name + ".legal");
+}
+
+Value *LegalizeBufferContentTypesVisitor::makeIllegalNonAggregate(
+    Value *V, Type *OrigType, StringRef Name) {
+  Type *LegalType = V->getType();
+  if (DL.getTypeSizeInBits(LegalType) != DL.getTypeSizeInBits(OrigType)) {
+    Type *ShortScalarTy =
+        IRB.getIntNTy(DL.getTypeSizeInBits(OrigType).getFixedValue());
+    Type *ByteScalarTy =
+        IRB.getIntNTy(DL.getTypeSizeInBits(LegalType).getFixedValue());
+    Value *AsScalar = IRB.CreateBitCast(V, ByteScalarTy, Name + ".bytes.cast");
+    Value *Trunc = IRB.CreateTrunc(AsScalar, ShortScalarTy, Name + ".trunc");
+    if (OrigType != ShortScalarTy)
+      return IRB.CreateBitCast(Trunc, OrigType, Name + ".orig");
+    return Trunc;
+  }
+  if (LegalType == OrigType)
+    return V;
+  return IRB.CreateBitCast(V, OrigType, Name + ".real.ty");
+}
+
+Type *LegalizeBufferContentTypesVisitor::intrinsicTypeFor(Type *LegalType) {
+  auto *VT = dyn_cast<FixedVectorType>(LegalType);
+  if (!VT)
+    return LegalType;
+  Type *ET = VT->getElementType();
+  if (VT->getNumElements() == 1)
+    return ET;
+  if (DL.getTypeSizeInBits(LegalType) == 96 && DL.getTypeSizeInBits(ET) < 32)
+    return FixedVectorType::get(IRB.getInt32Ty(), 3);
+  if (ET->isIntegerTy(8)) {
+    switch (VT->getNumElements()) {
+    default:
+      return LegalType; // Let it crash later
+    case 1:
+      return IRB.getInt8Ty();
+    case 2:
+      return IRB.getInt16Ty();
+    case 4:
+      return IRB.getInt32Ty();
+    case 8:
+      return FixedVectorType::get(IRB.getInt32Ty(), 2);
+    case 16:
+      return FixedVectorType::get(IRB.getInt32Ty(), 4);
+    }
+  }
+  return LegalType;
+}
+
+void LegalizeBufferContentTypesVisitor::getSlices(
+    Type *T, SmallVectorImpl<Slice> &Slices) {
+  Slices.clear();
+  auto *VT = dyn_cast<FixedVectorType>(T);
+  if (!VT)
+    return;
+
+  unsigned ElemBitWidth =
+      DL.getTypeSizeInBits(VT->getElementType()).getFixedValue();
+
+  unsigned ElemsPer4Words = 128 / ElemBitWidth;
+  unsigned ElemsPer2Words = ElemsPer4Words / 2;
+  unsigned ElemsPerWord = ElemsPer2Words / 2;
+  unsigned ElemsPerShort = ElemsPerWord / 2;
+  unsigned ElemsPerByte = ElemsPerShort / 2;
+  // If the elements evenly pack into 32-bit words, we can use 3-word stores,
+  // such as for <6 x bfloat> or <3 x i32>, but we can't do this for, for
+  // example, <3 x i64>, since that's not slicing.
+  unsigned ElemsPer3Words = ElemsPerWord * 3;
+
+  unsigned TotalElems = VT->getNumElements();
+  unsigned Off = 0;
+  auto TrySlice = [&](unsigned MaybeLen) {
+    if (MaybeLen > 0 && Off + MaybeLen <= TotalElems) {
+      Slices.emplace_back(/*Offset=*/Off, /*Length=*/MaybeLen);
+      Off += MaybeLen;
+      return true;
+    }
+    return false;
+  };
+  while (Off < TotalElems) {
+    TrySlice(ElemsPer4Words) || TrySlice(ElemsPer3Words) ||
+        TrySlice(ElemsPer2Words) || TrySlice(ElemsPerWord) ||
+        TrySlice(ElemsPerShort) || TrySlice(ElemsPerByte);
+  }
+}
+
+Value *LegalizeBufferContentTypesVisitor::extractSlice(Value *Vec, Slice S,
+                                                       StringRef Name) {
+  if (S.Length == 1)
+    return IRB.CreateExtractElement(Vec, S.Offset,
+                                    Name + ".slice." + Twine(S.Offset));
+  SmallVector<int> Mask = llvm::to_vector(llvm::iota_range<int>(
+      S.Offset, S.Offset + S.Length, /*Inclusive=*/false));
+  return IRB.CreateShuffleVector(Vec, Mask, Name + ".slice." + Twin...
[truncated]

Contributor

@arsenm arsenm left a comment

Add some tests for the error unsupported case, and add some scalable vectors to make sure those don't blow up

// Implicitly zero-extend to the next byte if needed
if (!DL.typeSizeEqualsStoreSize(T))
  T = IRB.getIntNTy(Size.getFixedValue());
auto *VT = dyn_cast<VectorType>(T);
Contributor

Cast to FixedVectorType, it's assumed below anyway. I would expect scalable vectors to be ignored and pass through as-is

@krzysz00 krzysz00 force-pushed the buffer-load-all branch 2 times, most recently from 9ea6241 to 5a728c4 Compare October 3, 2024 20:42
@krzysz00 krzysz00 requested a review from arsenm October 4, 2024 17:45
The current lowering for ptr addrspace(7) assumed that the instruction
selector can handle arbitrary LLVM types, which is not the case. Code
generation can't deal with
- Values that aren't 8, 16, 32, 64, 96, or 128 bits long
- Aggregates (this commit only handles arrays of scalars, more may come)
- Vectors of more than one byte
- 3-word values that aren't a vector of 3 32-bit values (for example, a
  <6 x half>)

This commit adds a buffer contents type legalizer that adds the needed
bitcasts, zero-extensions, and splits into subcomponents needed to convert a
load or store operation into one that can be successfully lowered through
code generation.

In the long run, some of the involved bitcasts (though potentially not
the buffer operation splitting) ought to be handled by the instruction
legalizer, but SelectionDAG makes this difficult.

It also takes advantage of the new `nuw` flag on `getelementptr` when
lowering GEPs to offset additions.

We don't currently plumb through `nsw` on GEPs since that should likely
be a separate change and would require declaring what we mean by
"the address" in the context of the GEP guarantees.
return VectorRes;
}

Value *LegalizeBufferContentTypesVisitor::vectorToArray(Value *V,
Contributor

This seems like the backwards direction to go?

Contributor Author

How so?

Contributor

I mean we'd rather have IR vectors than IR arrays. They're directly codegenable

Contributor

Yeah I have the same question

Contributor Author

Ok, but, suppose that someone has written

%x = load [2 x i32], ptr addrspace(7) %p

We need to codegen this, and the underlying buffer intrinsics don't support arrays

So, we rewrite this to

%x.vec = load <2 x i32>, ptr addrspace(7) %p
%x = [the pile of extractelement and insertvalue instructions needed to turn the <2 x i32> %x.vec into a [2 x i32]]

vectorToArray() is responsible for recovering an array from a loaded vector value, just like arrayToVector() takes an array that will be stored and turns it into its corresponding vector.
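
Spelled out, that pile would look roughly like this (names follow the .elem./.as.array. convention used in vectorToArray(), but this is a sketch, not literal pass output):

%x.vec = load <2 x i32>, ptr addrspace(7) %p
%x.elem.0 = extractelement <2 x i32> %x.vec, i64 0
%x.as.array.0 = insertvalue [2 x i32] poison, i32 %x.elem.0, 0
%x.elem.1 = extractelement <2 x i32> %x.vec, i64 1
%x = insertvalue [2 x i32] %x.as.array.0, i32 %x.elem.1, 1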

(As to why arrays are supported ... LLPC had handling for it, and they've probably got good reason)

Contributor

So this is for restoring all the uses later

Contributor Author

Yep

auto *NLI = cast<LoadInst>(OrigLI.clone());
NLI->mutateType(LoadableType);
NLI = IRB.Insert(NLI);
NLI->setName(Name + ".loadable");
Contributor

Just take name?

Contributor Author

We've got one later, and this'll let us put the takeName on any terminal bitcasts and the like

Type *ElemTy = AT->getElementType();
TypeSize AllocSize = DL.getTypeAllocSize(ElemTy);
if (!(ElemTy->isSingleValueType() &&
      DL.getTypeSizeInBits(ElemTy) == 8 * AllocSize &&
Contributor

getTypeStoreSize

Contributor Author

I'm trying to catch stuff like [8 x i4] here, which, if I've understood the ABI right, stores differently than <8 x i4>. And ... are there cases where the alloc size and the store size aren't equal?

Contributor

Non-byte vector elements have never been handled consistently. The alloc size and store size are different when the alignment is higher. The signature example would be that 3-element vectors have a 4-element allocation size.

Contributor Author

Ah

So what's the condition for being able to replace a store of a [N x T] with a store of a <N x T> (or a load)?

Contributor

Ignoring the non-byte element case, where I'm not sure what the bit layout is, you can always do it if you don't rely on the natural alignment. If you use an explicit alignment it should be fine
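
Concretely (a sketch based on the default data layout rules, not text from the thread): <3 x i32> has a store size of 12 bytes but a natural alignment and alloc size of 16 bytes, while [3 x i32] is 12 bytes with 4-byte alignment. So a rewrite like

%a = load [3 x i32], ptr addrspace(7) %p
  ==>
%a = load <3 x i32>, ptr addrspace(7) %p, align 4

only stays equivalent because the explicit align 4 preserves the array's alignment instead of picking up the vector's natural 16-byte alignment.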

Contributor Author

And since, last I checked, we don't have hard alignment requirements ... I'll simplify this


github-actions bot commented Oct 31, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@krzysz00 krzysz00 requested a review from arsenm November 4, 2024 19:23
return VectorRes;
}

Value *LegalizeBufferContentTypesVisitor::vectorToArray(Value *V,
Contributor

I mean we'd rather have IR vectors than IR arrays. They're directly codegenable

@krzysz00
Contributor Author

krzysz00 commented Nov 5, 2024

Re the arrays, vectorToArray is for having loaded a vector and turning it back into an array so that you can RAUW in that array.

@krzysz00
Contributor Author

krzysz00 commented Nov 5, 2024

Maybe arrayFromVector would make it clearer?

@krzysz00 krzysz00 requested a review from arsenm November 6, 2024 17:44
@krzysz00 krzysz00 requested a review from shiltian November 20, 2024 17:53
@krzysz00
Contributor Author

krzysz00 commented Jan 7, 2025

@arsenm Are there any major comments I've missed?

target triple = "amdgcn--"

;;; Legal types. These are natively supported, no casts should be performed.

define i8 @load_i8(ptr addrspace(8) %buf) {
define i8 @load_i8(ptr addrspace(8) inreg %buf) {
Contributor

Adding all these inregs is a separate change

Contributor Author

... Oh, whoops, missed this comment. Should I revert this off?

@krzysz00 krzysz00 merged commit 3805355 into llvm:main Jan 20, 2025
8 checks passed
krzysz00 added a commit to krzysz00/llvm-project that referenced this pull request Jan 20, 2025
shiltian pushed a commit that referenced this pull request Jan 20, 2025
github-actions bot pushed a commit to arm/arm-toolchain that referenced this pull request Jan 20, 2025
searlmc1 pushed a commit to ROCm/llvm-project that referenced this pull request Mar 24, 2025
…lvm#110572)

The current lowering for ptr addrspace(7) assumed that the instruction
selector can handle arbitrary LLVM types, which is not the case. Code
generation can't deal with
- Values that aren't 8, 16, 32, 64, 96, or 128 bits long
- Aggregates (this commit only handles arrays of scalars, more may come)
- Vectors of more than one byte
- 3-word values that aren't a vector of 3 32-bit values (for example, a
<6 x half>)

This commit adds a buffer contents type legalizer that adds the needed
bitcasts, zero-extensions, and splits into subcomponents needed to
convert a load or store operation into one that can be successfully
lowered through code generation.

In the long run, some of the involved bitcasts (though potentially not
the buffer operation splitting) ought to be handled by the instruction
legalizer, but SelectionDAG makes this difficult.

It also takes advantage of the new `nuw` flag on `getelementptr` when
lowering GEPs to offset additions.

We don't currently plumb through `nsw` on GEPs since that should likely
be a separate change and would require declaring what we mean by "the
address" in the context of the GEP guarantees.
searlmc1 pushed a commit to ROCm/llvm-project that referenced this pull request Mar 24, 2025