Skip to content

Commit 2005b26

Browse files
committed
AMDGPU: Propagate amdgpu-max-num-workgroups attribute
I'm not sure what the interpretation of 0 is supposed to be, AMDGPUUsage doesn't say.
1 parent cf4442e commit 2005b26

File tree

2 files changed

+380
-2
lines changed

2 files changed

+380
-2
lines changed

llvm/lib/Target/AMDGPU/AMDGPUAttributor.cpp

Lines changed: 152 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,11 @@ class AMDGPUInformationCache : public InformationCache {
179179
return {ST.getMinFlatWorkGroupSize(), ST.getMaxFlatWorkGroupSize()};
180180
}
181181

182+
SmallVector<unsigned> getMaxNumWorkGroups(const Function &F) {
183+
const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);
184+
return ST.getMaxNumWorkGroups(F);
185+
}
186+
182187
/// Get code object version.
183188
unsigned getCodeObjectVersion() const {
184189
return CodeObjectVersion;
@@ -821,6 +826,150 @@ AAAMDFlatWorkGroupSize::createForPosition(const IRPosition &IRP,
821826
"AAAMDFlatWorkGroupSize is only valid for function position");
822827
}
823828

829+
struct TupleDecIntegerRangeState : public AbstractState {
830+
DecIntegerState<uint32_t> X, Y, Z;
831+
832+
bool isValidState() const override {
833+
return X.isValidState() && Y.isValidState() && Z.isValidState();
834+
}
835+
836+
bool isAtFixpoint() const override {
837+
return X.isAtFixpoint() && Y.isAtFixpoint() && Z.isAtFixpoint();
838+
}
839+
840+
ChangeStatus indicateOptimisticFixpoint() override {
841+
return X.indicateOptimisticFixpoint() | Y.indicateOptimisticFixpoint() |
842+
Z.indicateOptimisticFixpoint();
843+
}
844+
845+
ChangeStatus indicatePessimisticFixpoint() override {
846+
return X.indicatePessimisticFixpoint() | Y.indicatePessimisticFixpoint() |
847+
Z.indicatePessimisticFixpoint();
848+
}
849+
850+
TupleDecIntegerRangeState operator^=(const TupleDecIntegerRangeState &Other) {
851+
X ^= Other.X;
852+
Y ^= Other.Y;
853+
Z ^= Other.Z;
854+
return *this;
855+
}
856+
857+
bool operator==(const TupleDecIntegerRangeState &Other) const {
858+
return X == Other.X && Y == Other.Y && Z == Other.Z;
859+
}
860+
861+
TupleDecIntegerRangeState &getAssumed() { return *this; }
862+
const TupleDecIntegerRangeState &getAssumed() const { return *this; }
863+
};
864+
865+
using AAAMDMaxNumWorkgroupsState =
866+
StateWrapper<TupleDecIntegerRangeState, AbstractAttribute, uint32_t>;
867+
868+
/// Propagate amdgpu-max-num-workgroups attribute.
869+
struct AAAMDMaxNumWorkgroups
870+
: public StateWrapper<TupleDecIntegerRangeState, AbstractAttribute> {
871+
using Base = StateWrapper<TupleDecIntegerRangeState, AbstractAttribute>;
872+
873+
AAAMDMaxNumWorkgroups(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
874+
875+
void initialize(Attributor &A) override {
876+
Function *F = getAssociatedFunction();
877+
auto &InfoCache = static_cast<AMDGPUInformationCache &>(A.getInfoCache());
878+
879+
SmallVector<unsigned> MaxNumWorkgroups = InfoCache.getMaxNumWorkGroups(*F);
880+
881+
// FIXME: What is the interpretation of 0?
882+
for (unsigned &Entry : MaxNumWorkgroups) {
883+
if (Entry == 0)
884+
Entry = std::numeric_limits<uint32_t>::max();
885+
}
886+
887+
X.takeKnownMinimum(MaxNumWorkgroups[0]);
888+
Y.takeKnownMinimum(MaxNumWorkgroups[1]);
889+
Z.takeKnownMinimum(MaxNumWorkgroups[2]);
890+
891+
if (AMDGPU::isEntryFunctionCC(F->getCallingConv()))
892+
indicatePessimisticFixpoint();
893+
}
894+
895+
ChangeStatus updateImpl(Attributor &A) override {
896+
ChangeStatus Change = ChangeStatus::UNCHANGED;
897+
898+
auto CheckCallSite = [&](AbstractCallSite CS) {
899+
Function *Caller = CS.getInstruction()->getFunction();
900+
LLVM_DEBUG(dbgs() << "[AAAMDMaxNumWorkgroups] Call " << Caller->getName()
901+
<< "->" << getAssociatedFunction()->getName() << '\n');
902+
903+
const auto *CallerInfo = A.getAAFor<AAAMDMaxNumWorkgroups>(
904+
*this, IRPosition::function(*Caller), DepClassTy::REQUIRED);
905+
if (!CallerInfo)
906+
return false;
907+
908+
Change |=
909+
clampStateAndIndicateChange(this->getState(), CallerInfo->getState());
910+
return true;
911+
};
912+
913+
bool AllCallSitesKnown = true;
914+
if (!A.checkForAllCallSites(CheckCallSite, *this, true, AllCallSitesKnown))
915+
return indicatePessimisticFixpoint();
916+
917+
return Change;
918+
}
919+
920+
/// Create an abstract attribute view for the position \p IRP.
921+
static AAAMDMaxNumWorkgroups &createForPosition(const IRPosition &IRP,
922+
Attributor &A);
923+
924+
ChangeStatus manifest(Attributor &A) override {
925+
Function *F = getAssociatedFunction();
926+
// TODO: Skip adding if worst case?
927+
LLVMContext &Ctx = F->getContext();
928+
SmallString<32> Buffer;
929+
raw_svector_ostream OS(Buffer);
930+
OS << X.getAssumed() << ',' << Y.getAssumed() << ',' << Z.getAssumed();
931+
932+
// TODO: Should annotate loads of the group size for this to do anything
933+
// useful.
934+
return A.manifestAttrs(
935+
getIRPosition(),
936+
{Attribute::get(Ctx, "amdgpu-max-num-workgroups", OS.str())},
937+
/* ForceReplace= */ true);
938+
}
939+
940+
const std::string getName() const override { return "AAAMDMaxNumWorkgroups"; }
941+
942+
const std::string getAsStr(Attributor *) const override {
943+
std::string Buffer = "AAAMDMaxNumWorkgroupsState[";
944+
raw_string_ostream OS(Buffer);
945+
OS << X.getAssumed() << ',' << Y.getAssumed() << ',' << Z.getAssumed()
946+
<< ']';
947+
return OS.str();
948+
}
949+
950+
const char *getIdAddr() const override { return &ID; }
951+
952+
/// This function should return true if the type of the \p AA is
953+
/// AAAMDMaxNumWorkgroups
954+
static bool classof(const AbstractAttribute *AA) {
955+
return (AA->getIdAddr() == &ID);
956+
}
957+
958+
void trackStatistics() const override {}
959+
960+
/// Unique ID (due to the unique address)
961+
static const char ID;
962+
};
963+
964+
const char AAAMDMaxNumWorkgroups::ID = 0;
965+
966+
AAAMDMaxNumWorkgroups &
967+
AAAMDMaxNumWorkgroups::createForPosition(const IRPosition &IRP, Attributor &A) {
968+
if (IRP.getPositionKind() == IRPosition::IRP_FUNCTION)
969+
return *new (A.Allocator) AAAMDMaxNumWorkgroups(IRP, A);
970+
llvm_unreachable("AAAMDMaxNumWorkgroups is only valid for function position");
971+
}
972+
824973
/// Propagate amdgpu-waves-per-eu attribute.
825974
struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
826975
AAAMDWavesPerEU(const IRPosition &IRP, Attributor &A)
@@ -1043,8 +1192,8 @@ static bool runImpl(Module &M, AnalysisGetter &AG, TargetMachine &TM,
10431192
DenseSet<const char *> Allowed(
10441193
{&AAAMDAttributes::ID, &AAUniformWorkGroupSize::ID,
10451194
&AAPotentialValues::ID, &AAAMDFlatWorkGroupSize::ID,
1046-
&AAAMDWavesPerEU::ID, &AAAMDGPUNoAGPR::ID, &AACallEdges::ID,
1047-
&AAPointerInfo::ID, &AAPotentialConstantValues::ID,
1195+
&AAAMDMaxNumWorkgroups::ID, &AAAMDWavesPerEU::ID, &AAAMDGPUNoAGPR::ID,
1196+
&AACallEdges::ID, &AAPointerInfo::ID, &AAPotentialConstantValues::ID,
10481197
&AAUnderlyingObjects::ID, &AAAddressSpace::ID, &AAIndirectCallInfo::ID,
10491198
&AAInstanceInfo::ID});
10501199

@@ -1068,6 +1217,7 @@ static bool runImpl(Module &M, AnalysisGetter &AG, TargetMachine &TM,
10681217
for (auto *F : Functions) {
10691218
A.getOrCreateAAFor<AAAMDAttributes>(IRPosition::function(*F));
10701219
A.getOrCreateAAFor<AAUniformWorkGroupSize>(IRPosition::function(*F));
1220+
A.getOrCreateAAFor<AAAMDMaxNumWorkgroups>(IRPosition::function(*F));
10711221
A.getOrCreateAAFor<AAAMDGPUNoAGPR>(IRPosition::function(*F));
10721222
CallingConv::ID CC = F->getCallingConv();
10731223
if (!AMDGPU::isEntryFunctionCC(CC)) {

0 commit comments

Comments
 (0)