Skip to content

Commit 5ef768d

Browse files
authored
[AMDGPULowerBufferFatPointers] Expand const exprs using fat pointers (#95558)
Expand all constant expressions that use fat pointers upfront, so that the rewriting logic only has to deal with instructions and not the constant expression variants as well. My primary motivation is to remove the creation of illegal constant expressions (mul and shl) from this pass, but this also cuts down quite a bit on the amount of duplicate logic.
1 parent 9b933e9 commit 5ef768d

File tree

4 files changed

+70
-168
lines changed

4 files changed

+70
-168
lines changed

llvm/include/llvm/IR/ReplaceConstant.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,13 @@ class Function;
3030
/// RemoveDeadConstants by default will remove all dead constants as
3131
/// the final step of the function after replacement, when passed
3232
/// false it will skip this final step.
33+
///
34+
/// If \p IncludeSelf is enabled, also convert the passed constants themselves
35+
/// to instructions, rather than only their users.
3336
bool convertUsersOfConstantsToInstructions(ArrayRef<Constant *> Consts,
3437
Function *RestrictToFunc = nullptr,
35-
bool RemoveDeadConstants = true);
38+
bool RemoveDeadConstants = true,
39+
bool IncludeSelf = false);
3640

3741
} // end namespace llvm
3842

llvm/lib/IR/ReplaceConstant.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,20 @@ static SmallVector<Instruction *, 4> expandUser(BasicBlock::iterator InsertPt,
5151

5252
bool convertUsersOfConstantsToInstructions(ArrayRef<Constant *> Consts,
5353
Function *RestrictToFunc,
54-
bool RemoveDeadConstants) {
54+
bool RemoveDeadConstants,
55+
bool IncludeSelf) {
5556
// Find all expandable direct users of Consts.
5657
SmallVector<Constant *> Stack;
57-
for (Constant *C : Consts)
58-
for (User *U : C->users())
59-
if (isExpandableUser(U))
60-
Stack.push_back(cast<Constant>(U));
58+
for (Constant *C : Consts) {
59+
if (IncludeSelf) {
60+
assert(isExpandableUser(C) && "One of the constants is not expandable");
61+
Stack.push_back(C);
62+
} else {
63+
for (User *U : C->users())
64+
if (isExpandableUser(U))
65+
Stack.push_back(cast<Constant>(U));
66+
}
67+
}
6168

6269
// Include transitive users.
6370
SetVector<Constant *> ExpandableUsers;

llvm/lib/Target/AMDGPU/AMDGPULowerBufferFatPointers.cpp

Lines changed: 45 additions & 157 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@
215215
#include "llvm/IR/Metadata.h"
216216
#include "llvm/IR/Operator.h"
217217
#include "llvm/IR/PatternMatch.h"
218+
#include "llvm/IR/ReplaceConstant.h"
218219
#include "llvm/InitializePasses.h"
219220
#include "llvm/Pass.h"
220221
#include "llvm/Support/AtomicOrdering.h"
@@ -579,18 +580,14 @@ bool StoreFatPtrsAsIntsVisitor::visitStoreInst(StoreInst &SI) {
579580
/// buffer fat pointer constant.
580581
static std::pair<Constant *, Constant *>
581582
splitLoweredFatBufferConst(Constant *C) {
582-
if (auto *AZ = dyn_cast<ConstantAggregateZero>(C))
583-
return std::make_pair(AZ->getStructElement(0), AZ->getStructElement(1));
584-
if (auto *SC = dyn_cast<ConstantStruct>(C))
585-
return std::make_pair(SC->getOperand(0), SC->getOperand(1));
586-
llvm_unreachable("Conversion should've created a {p8, i32} struct");
583+
assert(isSplitFatPtr(C->getType()) && "Not a split fat buffer pointer");
584+
return std::make_pair(C->getAggregateElement(0u), C->getAggregateElement(1u));
587585
}
588586

589587
namespace {
590588
/// Handle the remapping of ptr addrspace(7) constants.
591589
class FatPtrConstMaterializer final : public ValueMaterializer {
592590
BufferFatPtrToStructTypeMap *TypeMap;
593-
BufferFatPtrToIntTypeMap *IntTypeMap;
594591
// An internal mapper that is used to recurse into the arguments of constants.
595592
// While the documentation for `ValueMapper` specifies not to use it
596593
// recursively, examination of the logic in mapValue() shows that it can
@@ -600,16 +597,12 @@ class FatPtrConstMaterializer final : public ValueMaterializer {
600597

601598
Constant *materializeBufferFatPtrConst(Constant *C);
602599

603-
const DataLayout &DL;
604-
605600
public:
606601
// UnderlyingMap is the value map this materializer will be filling.
607602
FatPtrConstMaterializer(BufferFatPtrToStructTypeMap *TypeMap,
608-
ValueToValueMapTy &UnderlyingMap,
609-
BufferFatPtrToIntTypeMap *IntTypeMap,
610-
const DataLayout &DL)
611-
: TypeMap(TypeMap), IntTypeMap(IntTypeMap),
612-
InternalMapper(UnderlyingMap, RF_None, TypeMap, this), DL(DL) {}
603+
ValueToValueMapTy &UnderlyingMap)
604+
: TypeMap(TypeMap),
605+
InternalMapper(UnderlyingMap, RF_None, TypeMap, this) {}
613606
virtual ~FatPtrConstMaterializer() = default;
614607

615608
Value *materialize(Value *V) override;
@@ -632,10 +625,6 @@ Constant *FatPtrConstMaterializer::materializeBufferFatPtrConst(Constant *C) {
632625
UndefValue::get(NewTy->getElementType(1))});
633626
}
634627

635-
if (isa<GlobalValue>(C))
636-
report_fatal_error("Global values containing ptr addrspace(7) (buffer "
637-
"fat pointer) values are not supported");
638-
639628
if (auto *VC = dyn_cast<ConstantVector>(C)) {
640629
if (Constant *S = VC->getSplatValue()) {
641630
Constant *NewS = InternalMapper.mapConstant(*S);
@@ -661,147 +650,21 @@ Constant *FatPtrConstMaterializer::materializeBufferFatPtrConst(Constant *C) {
661650
return ConstantStruct::get(NewTy, {RsrcVec, OffVec});
662651
}
663652

664-
// Constant expressions. This code mirrors how we fix up the equivalent
665-
// instructions later.
666-
auto *CE = dyn_cast<ConstantExpr>(C);
667-
if (!CE)
668-
return nullptr;
669-
if (auto *GEPO = dyn_cast<GEPOperator>(C)) {
670-
Constant *RemappedPtr =
671-
InternalMapper.mapConstant(*cast<Constant>(GEPO->getPointerOperand()));
672-
auto [Rsrc, Off] = splitLoweredFatBufferConst(RemappedPtr);
673-
Type *OffTy = Off->getType();
674-
bool InBounds = GEPO->isInBounds();
675-
676-
MapVector<Value *, APInt> VariableOffs;
677-
APInt NewConstOffVal = APInt::getZero(BufferOffsetWidth);
678-
if (!GEPO->collectOffset(DL, BufferOffsetWidth, VariableOffs,
679-
NewConstOffVal))
680-
report_fatal_error(
681-
"Scalable vector or unsized struct in fat pointer GEP");
682-
Constant *OffAccum = nullptr;
683-
for (auto [Arg, Multiple] : VariableOffs) {
684-
Constant *NewArg = InternalMapper.mapConstant(*cast<Constant>(Arg));
685-
NewArg = ConstantFoldIntegerCast(NewArg, OffTy, /*IsSigned=*/true, DL);
686-
if (!Multiple.isOne()) {
687-
if (Multiple.isPowerOf2()) {
688-
NewArg = ConstantExpr::getShl(
689-
NewArg, CE->getIntegerValue(OffTy, APInt(BufferOffsetWidth,
690-
Multiple.logBase2())));
691-
} else {
692-
NewArg = ConstantExpr::getMul(NewArg,
693-
CE->getIntegerValue(OffTy, Multiple));
694-
}
695-
}
696-
if (OffAccum) {
697-
OffAccum = ConstantExpr::getAdd(OffAccum, NewArg);
698-
} else {
699-
OffAccum = NewArg;
700-
}
701-
}
702-
Constant *NewConstOff = CE->getIntegerValue(OffTy, NewConstOffVal);
703-
if (OffAccum)
704-
OffAccum = ConstantExpr::getAdd(OffAccum, NewConstOff);
705-
else
706-
OffAccum = NewConstOff;
707-
bool HasNonNegativeOff = false;
708-
if (auto *CI = dyn_cast<ConstantInt>(OffAccum)) {
709-
HasNonNegativeOff = !CI->isNegative();
710-
}
711-
Constant *NewOff = ConstantExpr::getAdd(
712-
Off, OffAccum, /*hasNUW=*/InBounds && HasNonNegativeOff,
713-
/*hasNSW=*/false);
714-
return ConstantStruct::get(NewTy, {Rsrc, NewOff});
715-
}
716-
717-
if (auto *PI = dyn_cast<PtrToIntOperator>(CE)) {
718-
Constant *Parts =
719-
InternalMapper.mapConstant(*cast<Constant>(PI->getPointerOperand()));
720-
auto [Rsrc, Off] = splitLoweredFatBufferConst(Parts);
721-
// Here, we take advantage of the fact that ptrtoint has a built-in
722-
// zero-extension behavior.
723-
unsigned FatPtrWidth =
724-
DL.getPointerSizeInBits(AMDGPUAS::BUFFER_FAT_POINTER);
725-
Constant *RsrcInt = CE->getPtrToInt(Rsrc, SrcTy);
726-
unsigned Width = SrcTy->getScalarSizeInBits();
727-
Constant *Shift =
728-
CE->getIntegerValue(SrcTy, APInt(Width, BufferOffsetWidth));
729-
Constant *OffCast =
730-
ConstantFoldIntegerCast(Off, SrcTy, /*IsSigned=*/false, DL);
731-
Constant *RsrcHi = ConstantExpr::getShl(
732-
RsrcInt, Shift, Width >= FatPtrWidth, Width > FatPtrWidth);
733-
// This should be an or, but those got recently removed.
734-
Constant *Result = ConstantExpr::getAdd(RsrcHi, OffCast, true, true);
735-
return Result;
736-
}
653+
if (isa<GlobalValue>(C))
654+
report_fatal_error("Global values containing ptr addrspace(7) (buffer "
655+
"fat pointer) values are not supported");
737656

738-
if (CE->getOpcode() == Instruction::IntToPtr) {
739-
auto *Arg = cast<Constant>(CE->getOperand(0));
740-
unsigned FatPtrWidth =
741-
DL.getPointerSizeInBits(AMDGPUAS::BUFFER_FAT_POINTER);
742-
unsigned RsrcPtrWidth = DL.getPointerSizeInBits(AMDGPUAS::BUFFER_RESOURCE);
743-
auto *WantedTy = Arg->getType()->getWithNewBitWidth(FatPtrWidth);
744-
Arg = ConstantFoldIntegerCast(Arg, WantedTy, /*IsSigned=*/false, DL);
745-
746-
Constant *Shift =
747-
CE->getIntegerValue(WantedTy, APInt(FatPtrWidth, BufferOffsetWidth));
748-
Type *RsrcIntType = WantedTy->getWithNewBitWidth(RsrcPtrWidth);
749-
Type *RsrcTy = NewTy->getElementType(0);
750-
Type *OffTy = WantedTy->getWithNewBitWidth(BufferOffsetWidth);
751-
Constant *RsrcInt = CE->getTrunc(
752-
ConstantFoldBinaryOpOperands(Instruction::LShr, Arg, Shift, DL),
753-
RsrcIntType);
754-
Constant *Rsrc = CE->getIntToPtr(RsrcInt, RsrcTy);
755-
Constant *Off = ConstantFoldIntegerCast(Arg, OffTy, /*isSigned=*/false, DL);
756-
757-
return ConstantStruct::get(NewTy, {Rsrc, Off});
758-
}
657+
if (isa<ConstantExpr>(C))
658+
report_fatal_error("Constant exprs containing ptr addrspace(7) (buffer "
659+
"fat pointer) values should have been expanded earlier");
759660

760-
if (auto *AC = dyn_cast<AddrSpaceCastOperator>(CE)) {
761-
unsigned SrcAS = AC->getSrcAddressSpace();
762-
unsigned DstAS = AC->getDestAddressSpace();
763-
auto *Arg = cast<Constant>(AC->getPointerOperand());
764-
auto *NewArg = InternalMapper.mapConstant(*Arg);
765-
if (!NewArg)
766-
return nullptr;
767-
if (SrcAS == AMDGPUAS::BUFFER_FAT_POINTER &&
768-
DstAS == AMDGPUAS::BUFFER_FAT_POINTER)
769-
return NewArg;
770-
if (SrcAS == AMDGPUAS::BUFFER_RESOURCE &&
771-
DstAS == AMDGPUAS::BUFFER_FAT_POINTER) {
772-
auto *NullOff = CE->getNullValue(NewTy->getElementType(1));
773-
return ConstantStruct::get(NewTy, {NewArg, NullOff});
774-
}
775-
report_fatal_error(
776-
"Unsupported address space cast for a buffer fat pointer");
777-
}
778661
return nullptr;
779662
}
780663

781664
Value *FatPtrConstMaterializer::materialize(Value *V) {
782665
Constant *C = dyn_cast<Constant>(V);
783666
if (!C)
784667
return nullptr;
785-
if (auto *GEPO = dyn_cast<GEPOperator>(C)) {
786-
// As a special case, adjust GEP constants that have a ptr addrspace(7) in
787-
// their source types here, since the earlier local changes didn't handle
788-
// htis.
789-
Type *SrcTy = GEPO->getSourceElementType();
790-
Type *NewSrcTy = IntTypeMap->remapType(SrcTy);
791-
if (SrcTy != NewSrcTy) {
792-
SmallVector<Constant *> Ops;
793-
Ops.reserve(GEPO->getNumOperands());
794-
for (const Use &U : GEPO->operands())
795-
Ops.push_back(cast<Constant>(U.get()));
796-
auto *NewGEP = ConstantExpr::getGetElementPtr(
797-
NewSrcTy, Ops[0], ArrayRef<Constant *>(Ops).slice(1),
798-
GEPO->getNoWrapFlags(), GEPO->getInRange());
799-
LLVM_DEBUG(dbgs() << "p7-getting GEP: " << *GEPO << " becomes " << *NewGEP
800-
<< "\n");
801-
Value *FurtherMap = materialize(NewGEP);
802-
return FurtherMap ? FurtherMap : NewGEP;
803-
}
804-
}
805668
// Structs and other types that happen to contain fat pointers get remapped
806669
// by the mapValue() logic.
807670
if (!isBufferFatPtrConst(C))
@@ -1782,14 +1645,9 @@ class AMDGPULowerBufferFatPointers : public ModulePass {
17821645
static bool containsBufferFatPointers(const Function &F,
17831646
BufferFatPtrToStructTypeMap *TypeMap) {
17841647
bool HasFatPointers = false;
1785-
for (const BasicBlock &BB : F) {
1786-
for (const Instruction &I : BB) {
1648+
for (const BasicBlock &BB : F)
1649+
for (const Instruction &I : BB)
17871650
HasFatPointers |= (I.getType() != TypeMap->remapType(I.getType()));
1788-
for (const Use &U : I.operands())
1789-
if (auto *C = dyn_cast<Constant>(U.get()))
1790-
HasFatPointers |= isBufferFatPtrConst(C);
1791-
}
1792-
}
17931651
return HasFatPointers;
17941652
}
17951653

@@ -1888,6 +1746,36 @@ bool AMDGPULowerBufferFatPointers::run(Module &M, const TargetMachine &TM) {
18881746
"buffer resource pointers (address space 8) instead.");
18891747
}
18901748

1749+
{
1750+
// Collect all constant exprs and aggregates referenced by any function.
1751+
SmallVector<Constant *, 8> Worklist;
1752+
for (Function &F : M.functions())
1753+
for (Instruction &I : instructions(F))
1754+
for (Value *Op : I.operands())
1755+
if (isa<ConstantExpr>(Op) || isa<ConstantAggregate>(Op))
1756+
Worklist.push_back(cast<Constant>(Op));
1757+
1758+
// Recursively look for any referenced buffer pointer constants.
1759+
SmallPtrSet<Constant *, 8> Visited;
1760+
SetVector<Constant *> BufferFatPtrConsts;
1761+
while (!Worklist.empty()) {
1762+
Constant *C = Worklist.pop_back_val();
1763+
if (!Visited.insert(C).second)
1764+
continue;
1765+
if (isBufferFatPtrOrVector(C->getType()))
1766+
BufferFatPtrConsts.insert(C);
1767+
for (Value *Op : C->operands())
1768+
if (isa<ConstantExpr>(Op) || isa<ConstantAggregate>(Op))
1769+
Worklist.push_back(cast<Constant>(Op));
1770+
}
1771+
1772+
// Expand all constant expressions using fat buffer pointers to
1773+
// instructions.
1774+
Changed |= convertUsersOfConstantsToInstructions(
1775+
BufferFatPtrConsts.getArrayRef(), /*RestrictToFunc=*/nullptr,
1776+
/*RemoveDeadConstants=*/false, /*IncludeSelf=*/true);
1777+
}
1778+
18911779
StoreFatPtrsAsIntsVisitor MemOpsRewrite(&IntTM, M.getContext());
18921780
for (Function &F : M.functions()) {
18931781
bool InterfaceChange = hasFatPointerInterface(F, &StructTM);
@@ -1903,7 +1791,7 @@ bool AMDGPULowerBufferFatPointers::run(Module &M, const TargetMachine &TM) {
19031791
SmallVector<Function *> Intrinsics;
19041792
// Keep one big map so as to memoize constants across functions.
19051793
ValueToValueMapTy CloneMap;
1906-
FatPtrConstMaterializer Materializer(&StructTM, CloneMap, &IntTM, DL);
1794+
FatPtrConstMaterializer Materializer(&StructTM, CloneMap);
19071795

19081796
ValueMapper LowerInFuncs(CloneMap, RF_None, &StructTM, &Materializer);
19091797
for (auto [F, InterfaceChange] : NeedsRemap) {

llvm/test/CodeGen/AMDGPU/lower-buffer-fat-pointers-constants.ll

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,8 @@ define ptr addrspace(7) @gep_p7_from_p7() {
143143
define i160 @ptrtoint() {
144144
; CHECK-LABEL: define i160 @ptrtoint
145145
; CHECK-SAME: () #[[ATTR0]] {
146-
; CHECK-NEXT: ret i160 add nuw nsw (i160 shl nuw (i160 ptrtoint (ptr addrspace(8) @buf to i160), i160 32), i160 12)
146+
; CHECK-NEXT: [[TMP1:%.*]] = or i160 shl nuw (i160 ptrtoint (ptr addrspace(8) @buf to i160), i160 32), 12
147+
; CHECK-NEXT: ret i160 [[TMP1]]
147148
;
148149
ret i160 ptrtoint(
149150
ptr addrspace(7) getelementptr(
@@ -154,7 +155,8 @@ define i160 @ptrtoint() {
154155
define i256 @ptrtoint_long() {
155156
; CHECK-LABEL: define i256 @ptrtoint_long
156157
; CHECK-SAME: () #[[ATTR0]] {
157-
; CHECK-NEXT: ret i256 add nuw nsw (i256 shl nuw nsw (i256 ptrtoint (ptr addrspace(8) @buf to i256), i256 32), i256 12)
158+
; CHECK-NEXT: [[TMP1:%.*]] = or i256 shl nuw nsw (i256 ptrtoint (ptr addrspace(8) @buf to i256), i256 32), 12
159+
; CHECK-NEXT: ret i256 [[TMP1]]
158160
;
159161
ret i256 ptrtoint(
160162
ptr addrspace(7) getelementptr(
@@ -165,7 +167,8 @@ define i256 @ptrtoint_long() {
165167
define i64 @ptrtoint_short() {
166168
; CHECK-LABEL: define i64 @ptrtoint_short
167169
; CHECK-SAME: () #[[ATTR0]] {
168-
; CHECK-NEXT: ret i64 add nuw nsw (i64 shl (i64 ptrtoint (ptr addrspace(8) @buf to i64), i64 32), i64 12)
170+
; CHECK-NEXT: [[TMP1:%.*]] = or i64 shl (i64 ptrtoint (ptr addrspace(8) @buf to i64), i64 32), 12
171+
; CHECK-NEXT: ret i64 [[TMP1]]
169172
;
170173
ret i64 ptrtoint(
171174
ptr addrspace(7) getelementptr(
@@ -176,7 +179,7 @@ define i64 @ptrtoint_short() {
176179
define i32 @ptrtoint_very_short() {
177180
; CHECK-LABEL: define i32 @ptrtoint_very_short
178181
; CHECK-SAME: () #[[ATTR0]] {
179-
; CHECK-NEXT: ret i32 add nuw nsw (i32 shl (i32 ptrtoint (ptr addrspace(8) @buf to i32), i32 32), i32 12)
182+
; CHECK-NEXT: ret i32 12
180183
;
181184
ret i32 ptrtoint(
182185
ptr addrspace(7) getelementptr(
@@ -212,7 +215,7 @@ define <2 x ptr addrspace(7)> @inttoptr_vec() {
212215
define i32 @fancy_zero() {
213216
; CHECK-LABEL: define i32 @fancy_zero
214217
; CHECK-SAME: () #[[ATTR0]] {
215-
; CHECK-NEXT: ret i32 shl (i32 ptrtoint (ptr addrspace(8) @buf to i32), i32 32)
218+
; CHECK-NEXT: ret i32 0
216219
;
217220
ret i32 ptrtoint (
218221
ptr addrspace(7) addrspacecast (ptr addrspace(8) @buf to ptr addrspace(7))

0 commit comments

Comments
 (0)