Reapply "[AMDGPU] Handle memcpy()-like ops in LowerBufferFatPointers (#126621)" #129078

Merged: 1 commit, Feb 27, 2025
122 changes: 98 additions & 24 deletions llvm/lib/Target/AMDGPU/AMDGPULowerBufferFatPointers.cpp
@@ -45,16 +45,16 @@
//
// This pass proceeds in three main phases:
//
// ## Rewriting loads and stores of p7
// ## Rewriting loads and stores of p7 and memcpy()-like handling
//
// The first phase is to rewrite away all loads and stores of `ptr addrspace(7)`,
// including aggregates containing such pointers, to ones that use `i160`. This
// is handled by `StoreFatPtrsAsIntsVisitor` , which visits loads, stores, and
// allocas and, if the loaded or stored type contains `ptr addrspace(7)`,
// rewrites that type to one where the p7s are replaced by i160s, copying other
// parts of aggregates as needed. In the case of a store, each pointer is
// `ptrtoint`d to i160 before storing, and load integers are `inttoptr`d back.
// This same transformation is applied to vectors of pointers.
// is handled by `StoreFatPtrsAsIntsAndExpandMemcpyVisitor` , which visits
// loads, stores, and allocas and, if the loaded or stored type contains `ptr
// addrspace(7)`, rewrites that type to one where the p7s are replaced by i160s,
// copying other parts of aggregates as needed. In the case of a store, each
// pointer is `ptrtoint`d to i160 before storing, and load integers are
// `inttoptr`d back. This same transformation is applied to vectors of pointers.
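//
// For example (an illustrative sketch rather than literal pass output; the
// value names are invented), a store of a buffer fat pointer such as
//   store ptr addrspace(7) %fat, ptr %slot
// becomes
//   %fat.int = ptrtoint ptr addrspace(7) %fat to i160
//   store i160 %fat.int, ptr %slot
// and the matching load recreates the pointer with
//   %v.int = load i160, ptr %slot
//   %v = inttoptr i160 %v.int to ptr addrspace(7)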
//
// Such a transformation allows the later phases of the pass to not need
// to handle buffer fat pointers moving to and from memory, where we load
@@ -66,6 +66,10 @@
// Atomic operations on `ptr addrspace(7)` values are not supported, as the
// hardware does not include a 160-bit atomic.
//
// In order to save an extra O(N) walk over the function and to ensure that
// the contents type legalizer correctly splits up wide loads, memcpy()-like
// intrinsics are also unconditionally lowered into loops here.
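//
// For instance (an illustrative sketch; the exact shape of the emitted loop
// is left to the generic expansion utilities in LowerMemIntrinsics), a call
// such as
//   call void @llvm.memcpy.p7.p1.i32(ptr addrspace(7) %dst,
//                                    ptr addrspace(1) %src, i32 %len, i1 false)
// is replaced by an explicit copy loop that loads from %src and stores to
// %dst, so that the later contents type legalization only ever sees ordinary
// loads and stores.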
//
// ## Buffer contents type legalization
//
// The underlying buffer intrinsics only support types up to 128 bits long,
@@ -231,20 +235,24 @@
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/IR/ValueHandle.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/AMDGPUAddrSpace.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/AtomicOrdering.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
#include "llvm/Transforms/Utils/ValueMapper.h"

#define DEBUG_TYPE "amdgpu-lower-buffer-fat-pointers"
@@ -431,14 +439,16 @@ namespace {
/// marshalling costs when reading or storing these values, but since placing
/// such pointers into memory is an uncommon operation at best, we feel that
/// this cost is acceptable for better performance in the common case.
class StoreFatPtrsAsIntsVisitor
: public InstVisitor<StoreFatPtrsAsIntsVisitor, bool> {
class StoreFatPtrsAsIntsAndExpandMemcpyVisitor
: public InstVisitor<StoreFatPtrsAsIntsAndExpandMemcpyVisitor, bool> {
BufferFatPtrToIntTypeMap *TypeMap;

ValueToValueMapTy ConvertedForStore;

IRBuilder<> IRB;

const TargetMachine *TM;

// Convert all the buffer fat pointers within the input value to integers
// so that the value can be stored in memory.
Value *fatPtrsToInts(Value *V, Type *From, Type *To, const Twine &Name);
@@ -448,20 +458,27 @@ class StoreFatPtrsAsIntsVisitor
Value *intsToFatPtrs(Value *V, Type *From, Type *To, const Twine &Name);

public:
StoreFatPtrsAsIntsVisitor(BufferFatPtrToIntTypeMap *TypeMap, LLVMContext &Ctx)
: TypeMap(TypeMap), IRB(Ctx) {}
StoreFatPtrsAsIntsAndExpandMemcpyVisitor(BufferFatPtrToIntTypeMap *TypeMap,
LLVMContext &Ctx,
const TargetMachine *TM)
: TypeMap(TypeMap), IRB(Ctx), TM(TM) {}
bool processFunction(Function &F);

bool visitInstruction(Instruction &I) { return false; }
bool visitAllocaInst(AllocaInst &I);
bool visitLoadInst(LoadInst &LI);
bool visitStoreInst(StoreInst &SI);
bool visitGetElementPtrInst(GetElementPtrInst &I);

bool visitMemCpyInst(MemCpyInst &MCI);
bool visitMemMoveInst(MemMoveInst &MMI);
bool visitMemSetInst(MemSetInst &MSI);
bool visitMemSetPatternInst(MemSetPatternInst &MSPI);
};
} // namespace

Value *StoreFatPtrsAsIntsVisitor::fatPtrsToInts(Value *V, Type *From, Type *To,
const Twine &Name) {
Value *StoreFatPtrsAsIntsAndExpandMemcpyVisitor::fatPtrsToInts(
Value *V, Type *From, Type *To, const Twine &Name) {
if (From == To)
return V;
ValueToValueMapTy::iterator Find = ConvertedForStore.find(V);
Expand Down Expand Up @@ -498,8 +515,8 @@ Value *StoreFatPtrsAsIntsVisitor::fatPtrsToInts(Value *V, Type *From, Type *To,
return Ret;
}

Value *StoreFatPtrsAsIntsVisitor::intsToFatPtrs(Value *V, Type *From, Type *To,
const Twine &Name) {
Value *StoreFatPtrsAsIntsAndExpandMemcpyVisitor::intsToFatPtrs(
Value *V, Type *From, Type *To, const Twine &Name) {
if (From == To)
return V;
if (isBufferFatPtrOrVector(To)) {
@@ -531,18 +548,25 @@ Value *StoreFatPtrsAsIntsVisitor::intsToFatPtrs(Value *V, Type *From, Type *To,
return Ret;
}

bool StoreFatPtrsAsIntsVisitor::processFunction(Function &F) {
bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::processFunction(Function &F) {
bool Changed = false;
// The visitors will mutate GEPs and allocas, but will push loads and stores
// to the worklist to avoid invalidation.
// Process memcpy-like instructions after the main iteration because they can
// invalidate iterators.
SmallVector<WeakTrackingVH> CanBecomeLoops;
for (Instruction &I : make_early_inc_range(instructions(F))) {
Changed |= visit(I);
if (isa<MemTransferInst, MemSetInst, MemSetPatternInst>(I))
CanBecomeLoops.push_back(&I);
else
Changed |= visit(I);
}
for (WeakTrackingVH VH : make_early_inc_range(CanBecomeLoops)) {
Changed |= visit(cast<Instruction>(VH));
}
ConvertedForStore.clear();
return Changed;
}

bool StoreFatPtrsAsIntsVisitor::visitAllocaInst(AllocaInst &I) {
bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitAllocaInst(AllocaInst &I) {
Type *Ty = I.getAllocatedType();
Type *NewTy = TypeMap->remapType(Ty);
if (Ty == NewTy)
@@ -551,7 +575,8 @@ bool StoreFatPtrsAsIntsVisitor::visitAllocaInst(AllocaInst &I) {
return true;
}

bool StoreFatPtrsAsIntsVisitor::visitGetElementPtrInst(GetElementPtrInst &I) {
bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitGetElementPtrInst(
GetElementPtrInst &I) {
Type *Ty = I.getSourceElementType();
Type *NewTy = TypeMap->remapType(Ty);
if (Ty == NewTy)
@@ -563,7 +588,7 @@ bool StoreFatPtrsAsIntsVisitor::visitGetElementPtrInst(GetElementPtrInst &I) {
return true;
}

bool StoreFatPtrsAsIntsVisitor::visitLoadInst(LoadInst &LI) {
bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitLoadInst(LoadInst &LI) {
Type *Ty = LI.getType();
Type *IntTy = TypeMap->remapType(Ty);
if (Ty == IntTy)
@@ -581,7 +606,7 @@ bool StoreFatPtrsAsIntsVisitor::visitLoadInst(LoadInst &LI) {
return true;
}

bool StoreFatPtrsAsIntsVisitor::visitStoreInst(StoreInst &SI) {
bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitStoreInst(StoreInst &SI) {
Value *V = SI.getValueOperand();
Type *Ty = V->getType();
Type *IntTy = TypeMap->remapType(Ty);
@@ -597,6 +622,47 @@ bool StoreFatPtrsAsIntsVisitor::visitStoreInst(StoreInst &SI) {
return true;
}

bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitMemCpyInst(
MemCpyInst &MCI) {
// TODO: Allow memcpy.p7.p3 as a synonym for the direct-to-LDS copy, which'll
// need loop expansion here.
if (MCI.getSourceAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER &&
MCI.getDestAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
return false;
llvm::expandMemCpyAsLoop(&MCI,
TM->getTargetTransformInfo(*MCI.getFunction()));
MCI.eraseFromParent();
return true;
}

bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitMemMoveInst(
MemMoveInst &MMI) {
if (MMI.getSourceAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER &&
MMI.getDestAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
return false;
report_fatal_error(
"memmove() on buffer descriptors is not implemented because pointer "
"comparison on buffer descriptors isn't implemented\n");
}

bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitMemSetInst(
MemSetInst &MSI) {
if (MSI.getDestAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
return false;
llvm::expandMemSetAsLoop(&MSI);
MSI.eraseFromParent();
return true;
}

bool StoreFatPtrsAsIntsAndExpandMemcpyVisitor::visitMemSetPatternInst(
MemSetPatternInst &MSPI) {
if (MSPI.getDestAddressSpace() != AMDGPUAS::BUFFER_FAT_POINTER)
return false;
llvm::expandMemSetPatternAsLoop(&MSPI);
MSPI.eraseFromParent();
return true;
}

namespace {
/// Convert loads/stores of types that the buffer intrinsics can't handle into
/// one or more such loads/stores that consist of legal types.
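///
/// For example (an illustrative sketch; the real legalization handles many
/// more cases), a 256-bit load such as
///   %v = load <8 x i32>, ptr addrspace(7) %p
/// can be split into two <4 x i32> (128-bit) loads at byte offsets 0 and 16
/// whose results are reassembled into the original <8 x i32> value.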
Expand Down Expand Up @@ -1127,6 +1193,7 @@ bool LegalizeBufferContentTypesVisitor::visitStoreInst(StoreInst &SI) {

bool LegalizeBufferContentTypesVisitor::processFunction(Function &F) {
bool Changed = false;
// Note, memory transfer intrinsics won't
for (Instruction &I : make_early_inc_range(instructions(F))) {
Changed |= visit(I);
}
@@ -2084,6 +2151,12 @@ static bool isRemovablePointerIntrinsic(Intrinsic::ID IID) {
case Intrinsic::invariant_end:
case Intrinsic::launder_invariant_group:
case Intrinsic::strip_invariant_group:
case Intrinsic::memcpy:
case Intrinsic::memcpy_inline:
case Intrinsic::memmove:
case Intrinsic::memset:
case Intrinsic::memset_inline:
case Intrinsic::experimental_memset_pattern:
return true;
}
}
@@ -2353,7 +2426,8 @@ bool AMDGPULowerBufferFatPointers::run(Module &M, const TargetMachine &TM) {
/*RemoveDeadConstants=*/false, /*IncludeSelf=*/true);
}

StoreFatPtrsAsIntsVisitor MemOpsRewrite(&IntTM, M.getContext());
StoreFatPtrsAsIntsAndExpandMemcpyVisitor MemOpsRewrite(&IntTM, M.getContext(),
&TM);
LegalizeBufferContentTypesVisitor BufferContentsTypeRewrite(DL,
M.getContext());
for (Function &F : M.functions()) {