Skip to content

[AMDGPULowerBufferFatPointers] Expand const exprs using fat pointers #95558

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion llvm/include/llvm/IR/ReplaceConstant.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,13 @@ class Function;
/// RemoveDeadConstants by default will remove all dead constants as
/// the final step of the function after replacement, when passed
/// false it will skip this final step.
///
/// If \p IncludeSelf is enabled, also convert the passed constants themselves
/// to instructions, rather than only their users.
bool convertUsersOfConstantsToInstructions(ArrayRef<Constant *> Consts,
Function *RestrictToFunc = nullptr,
bool RemoveDeadConstants = true);
bool RemoveDeadConstants = true,
bool IncludeSelf = false);

} // end namespace llvm

Expand Down
17 changes: 12 additions & 5 deletions llvm/lib/IR/ReplaceConstant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,20 @@ static SmallVector<Instruction *, 4> expandUser(BasicBlock::iterator InsertPt,

bool convertUsersOfConstantsToInstructions(ArrayRef<Constant *> Consts,
Function *RestrictToFunc,
bool RemoveDeadConstants) {
bool RemoveDeadConstants,
bool IncludeSelf) {
// Find all expandable direct users of Consts.
SmallVector<Constant *> Stack;
for (Constant *C : Consts)
for (User *U : C->users())
if (isExpandableUser(U))
Stack.push_back(cast<Constant>(U));
for (Constant *C : Consts) {
if (IncludeSelf) {
assert(isExpandableUser(C) && "One of the constants is not expandable");
Stack.push_back(C);
} else {
for (User *U : C->users())
if (isExpandableUser(U))
Stack.push_back(cast<Constant>(U));
}
}

// Include transitive users.
SetVector<Constant *> ExpandableUsers;
Expand Down
202 changes: 45 additions & 157 deletions llvm/lib/Target/AMDGPU/AMDGPULowerBufferFatPointers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/AtomicOrdering.h"
Expand Down Expand Up @@ -579,18 +580,14 @@ bool StoreFatPtrsAsIntsVisitor::visitStoreInst(StoreInst &SI) {
/// buffer fat pointer constant.
static std::pair<Constant *, Constant *>
splitLoweredFatBufferConst(Constant *C) {
if (auto *AZ = dyn_cast<ConstantAggregateZero>(C))
return std::make_pair(AZ->getStructElement(0), AZ->getStructElement(1));
if (auto *SC = dyn_cast<ConstantStruct>(C))
return std::make_pair(SC->getOperand(0), SC->getOperand(1));
llvm_unreachable("Conversion should've created a {p8, i32} struct");
assert(isSplitFatPtr(C->getType()) && "Not a split fat buffer pointer");
return std::make_pair(C->getAggregateElement(0u), C->getAggregateElement(1u));
}

namespace {
/// Handle the remapping of ptr addrspace(7) constants.
class FatPtrConstMaterializer final : public ValueMaterializer {
BufferFatPtrToStructTypeMap *TypeMap;
BufferFatPtrToIntTypeMap *IntTypeMap;
// An internal mapper that is used to recurse into the arguments of constants.
// While the documentation for `ValueMapper` specifies not to use it
// recursively, examination of the logic in mapValue() shows that it can
Expand All @@ -600,16 +597,12 @@ class FatPtrConstMaterializer final : public ValueMaterializer {

Constant *materializeBufferFatPtrConst(Constant *C);

const DataLayout &DL;

public:
// UnderlyingMap is the value map this materializer will be filling.
FatPtrConstMaterializer(BufferFatPtrToStructTypeMap *TypeMap,
ValueToValueMapTy &UnderlyingMap,
BufferFatPtrToIntTypeMap *IntTypeMap,
const DataLayout &DL)
: TypeMap(TypeMap), IntTypeMap(IntTypeMap),
InternalMapper(UnderlyingMap, RF_None, TypeMap, this), DL(DL) {}
ValueToValueMapTy &UnderlyingMap)
: TypeMap(TypeMap),
InternalMapper(UnderlyingMap, RF_None, TypeMap, this) {}
virtual ~FatPtrConstMaterializer() = default;

Value *materialize(Value *V) override;
Expand All @@ -632,10 +625,6 @@ Constant *FatPtrConstMaterializer::materializeBufferFatPtrConst(Constant *C) {
UndefValue::get(NewTy->getElementType(1))});
}

if (isa<GlobalValue>(C))
report_fatal_error("Global values containing ptr addrspace(7) (buffer "
"fat pointer) values are not supported");

if (auto *VC = dyn_cast<ConstantVector>(C)) {
if (Constant *S = VC->getSplatValue()) {
Constant *NewS = InternalMapper.mapConstant(*S);
Expand All @@ -661,147 +650,21 @@ Constant *FatPtrConstMaterializer::materializeBufferFatPtrConst(Constant *C) {
return ConstantStruct::get(NewTy, {RsrcVec, OffVec});
}

// Constant expressions. This code mirrors how we fix up the equivalent
// instructions later.
auto *CE = dyn_cast<ConstantExpr>(C);
if (!CE)
return nullptr;
if (auto *GEPO = dyn_cast<GEPOperator>(C)) {
Constant *RemappedPtr =
InternalMapper.mapConstant(*cast<Constant>(GEPO->getPointerOperand()));
auto [Rsrc, Off] = splitLoweredFatBufferConst(RemappedPtr);
Type *OffTy = Off->getType();
bool InBounds = GEPO->isInBounds();

MapVector<Value *, APInt> VariableOffs;
APInt NewConstOffVal = APInt::getZero(BufferOffsetWidth);
if (!GEPO->collectOffset(DL, BufferOffsetWidth, VariableOffs,
NewConstOffVal))
report_fatal_error(
"Scalable vector or unsized struct in fat pointer GEP");
Constant *OffAccum = nullptr;
for (auto [Arg, Multiple] : VariableOffs) {
Constant *NewArg = InternalMapper.mapConstant(*cast<Constant>(Arg));
NewArg = ConstantFoldIntegerCast(NewArg, OffTy, /*IsSigned=*/true, DL);
if (!Multiple.isOne()) {
if (Multiple.isPowerOf2()) {
NewArg = ConstantExpr::getShl(
NewArg, CE->getIntegerValue(OffTy, APInt(BufferOffsetWidth,
Multiple.logBase2())));
} else {
NewArg = ConstantExpr::getMul(NewArg,
CE->getIntegerValue(OffTy, Multiple));
}
}
if (OffAccum) {
OffAccum = ConstantExpr::getAdd(OffAccum, NewArg);
} else {
OffAccum = NewArg;
}
}
Constant *NewConstOff = CE->getIntegerValue(OffTy, NewConstOffVal);
if (OffAccum)
OffAccum = ConstantExpr::getAdd(OffAccum, NewConstOff);
else
OffAccum = NewConstOff;
bool HasNonNegativeOff = false;
if (auto *CI = dyn_cast<ConstantInt>(OffAccum)) {
HasNonNegativeOff = !CI->isNegative();
}
Constant *NewOff = ConstantExpr::getAdd(
Off, OffAccum, /*hasNUW=*/InBounds && HasNonNegativeOff,
/*hasNSW=*/false);
return ConstantStruct::get(NewTy, {Rsrc, NewOff});
}

if (auto *PI = dyn_cast<PtrToIntOperator>(CE)) {
Constant *Parts =
InternalMapper.mapConstant(*cast<Constant>(PI->getPointerOperand()));
auto [Rsrc, Off] = splitLoweredFatBufferConst(Parts);
// Here, we take advantage of the fact that ptrtoint has a built-in
// zero-extension behavior.
unsigned FatPtrWidth =
DL.getPointerSizeInBits(AMDGPUAS::BUFFER_FAT_POINTER);
Constant *RsrcInt = CE->getPtrToInt(Rsrc, SrcTy);
unsigned Width = SrcTy->getScalarSizeInBits();
Constant *Shift =
CE->getIntegerValue(SrcTy, APInt(Width, BufferOffsetWidth));
Constant *OffCast =
ConstantFoldIntegerCast(Off, SrcTy, /*IsSigned=*/false, DL);
Constant *RsrcHi = ConstantExpr::getShl(
RsrcInt, Shift, Width >= FatPtrWidth, Width > FatPtrWidth);
// This should be an or, but those got recently removed.
Constant *Result = ConstantExpr::getAdd(RsrcHi, OffCast, true, true);
return Result;
}
if (isa<GlobalValue>(C))
report_fatal_error("Global values containing ptr addrspace(7) (buffer "
"fat pointer) values are not supported");

if (CE->getOpcode() == Instruction::IntToPtr) {
auto *Arg = cast<Constant>(CE->getOperand(0));
unsigned FatPtrWidth =
DL.getPointerSizeInBits(AMDGPUAS::BUFFER_FAT_POINTER);
unsigned RsrcPtrWidth = DL.getPointerSizeInBits(AMDGPUAS::BUFFER_RESOURCE);
auto *WantedTy = Arg->getType()->getWithNewBitWidth(FatPtrWidth);
Arg = ConstantFoldIntegerCast(Arg, WantedTy, /*IsSigned=*/false, DL);

Constant *Shift =
CE->getIntegerValue(WantedTy, APInt(FatPtrWidth, BufferOffsetWidth));
Type *RsrcIntType = WantedTy->getWithNewBitWidth(RsrcPtrWidth);
Type *RsrcTy = NewTy->getElementType(0);
Type *OffTy = WantedTy->getWithNewBitWidth(BufferOffsetWidth);
Constant *RsrcInt = CE->getTrunc(
ConstantFoldBinaryOpOperands(Instruction::LShr, Arg, Shift, DL),
RsrcIntType);
Constant *Rsrc = CE->getIntToPtr(RsrcInt, RsrcTy);
Constant *Off = ConstantFoldIntegerCast(Arg, OffTy, /*isSigned=*/false, DL);

return ConstantStruct::get(NewTy, {Rsrc, Off});
}
if (isa<ConstantExpr>(C))
report_fatal_error("Constant exprs containing ptr addrspace(7) (buffer "
"fat pointer) values should have been expanded earlier");

if (auto *AC = dyn_cast<AddrSpaceCastOperator>(CE)) {
unsigned SrcAS = AC->getSrcAddressSpace();
unsigned DstAS = AC->getDestAddressSpace();
auto *Arg = cast<Constant>(AC->getPointerOperand());
auto *NewArg = InternalMapper.mapConstant(*Arg);
if (!NewArg)
return nullptr;
if (SrcAS == AMDGPUAS::BUFFER_FAT_POINTER &&
DstAS == AMDGPUAS::BUFFER_FAT_POINTER)
return NewArg;
if (SrcAS == AMDGPUAS::BUFFER_RESOURCE &&
DstAS == AMDGPUAS::BUFFER_FAT_POINTER) {
auto *NullOff = CE->getNullValue(NewTy->getElementType(1));
return ConstantStruct::get(NewTy, {NewArg, NullOff});
}
report_fatal_error(
"Unsupported address space cast for a buffer fat pointer");
}
return nullptr;
}

Value *FatPtrConstMaterializer::materialize(Value *V) {
Constant *C = dyn_cast<Constant>(V);
if (!C)
return nullptr;
if (auto *GEPO = dyn_cast<GEPOperator>(C)) {
// As a special case, adjust GEP constants that have a ptr addrspace(7) in
// their source types here, since the earlier local changes didn't handle
// this.
Type *SrcTy = GEPO->getSourceElementType();
Type *NewSrcTy = IntTypeMap->remapType(SrcTy);
if (SrcTy != NewSrcTy) {
SmallVector<Constant *> Ops;
Ops.reserve(GEPO->getNumOperands());
for (const Use &U : GEPO->operands())
Ops.push_back(cast<Constant>(U.get()));
auto *NewGEP = ConstantExpr::getGetElementPtr(
NewSrcTy, Ops[0], ArrayRef<Constant *>(Ops).slice(1),
GEPO->getNoWrapFlags(), GEPO->getInRange());
LLVM_DEBUG(dbgs() << "p7-getting GEP: " << *GEPO << " becomes " << *NewGEP
<< "\n");
Value *FurtherMap = materialize(NewGEP);
return FurtherMap ? FurtherMap : NewGEP;
}
}
// Structs and other types that happen to contain fat pointers get remapped
// by the mapValue() logic.
if (!isBufferFatPtrConst(C))
Expand Down Expand Up @@ -1782,14 +1645,9 @@ class AMDGPULowerBufferFatPointers : public ModulePass {
static bool containsBufferFatPointers(const Function &F,
BufferFatPtrToStructTypeMap *TypeMap) {
bool HasFatPointers = false;
for (const BasicBlock &BB : F) {
for (const Instruction &I : BB) {
for (const BasicBlock &BB : F)
for (const Instruction &I : BB)
HasFatPointers |= (I.getType() != TypeMap->remapType(I.getType()));
for (const Use &U : I.operands())
if (auto *C = dyn_cast<Constant>(U.get()))
HasFatPointers |= isBufferFatPtrConst(C);
}
}
return HasFatPointers;
}

Expand Down Expand Up @@ -1888,6 +1746,36 @@ bool AMDGPULowerBufferFatPointers::run(Module &M, const TargetMachine &TM) {
"buffer resource pointers (address space 8) instead.");
}

{
// Collect all constant exprs and aggregates referenced by any function.
SmallVector<Constant *, 8> Worklist;
for (Function &F : M.functions())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Braces

for (Instruction &I : instructions(F))
for (Value *Op : I.operands())
if (isa<ConstantExpr>(Op) || isa<ConstantAggregate>(Op))
Worklist.push_back(cast<Constant>(Op));

// Recursively look for any referenced buffer pointer constants.
SmallPtrSet<Constant *, 8> Visited;
SetVector<Constant *> BufferFatPtrConsts;
while (!Worklist.empty()) {
Constant *C = Worklist.pop_back_val();
if (!Visited.insert(C).second)
continue;
if (isBufferFatPtrOrVector(C->getType()))
BufferFatPtrConsts.insert(C);
for (Value *Op : C->operands())
if (isa<ConstantExpr>(Op) || isa<ConstantAggregate>(Op))
Worklist.push_back(cast<Constant>(Op));
}

// Expand all constant expressions using fat buffer pointers to
// instructions.
Changed |= convertUsersOfConstantsToInstructions(
BufferFatPtrConsts.getArrayRef(), /*RestrictToFunc=*/nullptr,
/*RemoveDeadConstants=*/false, /*IncludeSelf=*/true);
}

StoreFatPtrsAsIntsVisitor MemOpsRewrite(&IntTM, M.getContext());
for (Function &F : M.functions()) {
bool InterfaceChange = hasFatPointerInterface(F, &StructTM);
Expand All @@ -1903,7 +1791,7 @@ bool AMDGPULowerBufferFatPointers::run(Module &M, const TargetMachine &TM) {
SmallVector<Function *> Intrinsics;
// Keep one big map so as to memoize constants across functions.
ValueToValueMapTy CloneMap;
FatPtrConstMaterializer Materializer(&StructTM, CloneMap, &IntTM, DL);
FatPtrConstMaterializer Materializer(&StructTM, CloneMap);

ValueMapper LowerInFuncs(CloneMap, RF_None, &StructTM, &Materializer);
for (auto [F, InterfaceChange] : NeedsRemap) {
Expand Down
13 changes: 8 additions & 5 deletions llvm/test/CodeGen/AMDGPU/lower-buffer-fat-pointers-constants.ll
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ define ptr addrspace(7) @gep_p7_from_p7() {
define i160 @ptrtoint() {
; CHECK-LABEL: define i160 @ptrtoint
; CHECK-SAME: () #[[ATTR0]] {
; CHECK-NEXT: ret i160 add nuw nsw (i160 shl nuw (i160 ptrtoint (ptr addrspace(8) @buf to i160), i160 32), i160 12)
; CHECK-NEXT: [[TMP1:%.*]] = or i160 shl nuw (i160 ptrtoint (ptr addrspace(8) @buf to i160), i160 32), 12
; CHECK-NEXT: ret i160 [[TMP1]]
;
ret i160 ptrtoint(
ptr addrspace(7) getelementptr(
Expand All @@ -154,7 +155,8 @@ define i160 @ptrtoint() {
define i256 @ptrtoint_long() {
; CHECK-LABEL: define i256 @ptrtoint_long
; CHECK-SAME: () #[[ATTR0]] {
; CHECK-NEXT: ret i256 add nuw nsw (i256 shl nuw nsw (i256 ptrtoint (ptr addrspace(8) @buf to i256), i256 32), i256 12)
; CHECK-NEXT: [[TMP1:%.*]] = or i256 shl nuw nsw (i256 ptrtoint (ptr addrspace(8) @buf to i256), i256 32), 12
; CHECK-NEXT: ret i256 [[TMP1]]
;
ret i256 ptrtoint(
ptr addrspace(7) getelementptr(
Expand All @@ -165,7 +167,8 @@ define i256 @ptrtoint_long() {
define i64 @ptrtoint_short() {
; CHECK-LABEL: define i64 @ptrtoint_short
; CHECK-SAME: () #[[ATTR0]] {
; CHECK-NEXT: ret i64 add nuw nsw (i64 shl (i64 ptrtoint (ptr addrspace(8) @buf to i64), i64 32), i64 12)
; CHECK-NEXT: [[TMP1:%.*]] = or i64 shl (i64 ptrtoint (ptr addrspace(8) @buf to i64), i64 32), 12
; CHECK-NEXT: ret i64 [[TMP1]]
;
ret i64 ptrtoint(
ptr addrspace(7) getelementptr(
Expand All @@ -176,7 +179,7 @@ define i64 @ptrtoint_short() {
define i32 @ptrtoint_very_short() {
; CHECK-LABEL: define i32 @ptrtoint_very_short
; CHECK-SAME: () #[[ATTR0]] {
; CHECK-NEXT: ret i32 add nuw nsw (i32 shl (i32 ptrtoint (ptr addrspace(8) @buf to i32), i32 32), i32 12)
; CHECK-NEXT: ret i32 12
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this no longer depend on the global at all?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the same as #95543, but for constants. It previously produced a poison shift that shouldn't be there.

;
ret i32 ptrtoint(
ptr addrspace(7) getelementptr(
Expand Down Expand Up @@ -212,7 +215,7 @@ define <2 x ptr addrspace(7)> @inttoptr_vec() {
define i32 @fancy_zero() {
; CHECK-LABEL: define i32 @fancy_zero
; CHECK-SAME: () #[[ATTR0]] {
; CHECK-NEXT: ret i32 shl (i32 ptrtoint (ptr addrspace(8) @buf to i32), i32 32)
; CHECK-NEXT: ret i32 0
;
ret i32 ptrtoint (
ptr addrspace(7) addrspacecast (ptr addrspace(8) @buf to ptr addrspace(7))
Expand Down
Loading