Skip to content

Commit 664a226

Browse files
authored
AMDGPU: Propagate amdgpu-max-num-workgroups attribute (#113018)
I'm not sure what the interpretation of 0 is supposed to be; AMDGPUUsage doesn't say.
1 parent 084451c commit 664a226

File tree

2 files changed

+385
-2
lines changed

2 files changed

+385
-2
lines changed

llvm/lib/Target/AMDGPU/AMDGPUAttributor.cpp

Lines changed: 147 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,11 @@ class AMDGPUInformationCache : public InformationCache {
179179
return {ST.getMinFlatWorkGroupSize(), ST.getMaxFlatWorkGroupSize()};
180180
}
181181

182+
/// Look up the GCNSubtarget that applies to \p F and forward to its
/// getMaxNumWorkGroups query, returning the result unchanged.
SmallVector<unsigned> getMaxNumWorkGroups(const Function &F) {
  return TM.getSubtarget<GCNSubtarget>(F).getMaxNumWorkGroups(F);
}
186+
182187
/// Get code object version.
183188
unsigned getCodeObjectVersion() const { return CodeObjectVersion; }
184189

@@ -821,6 +826,145 @@ AAAMDFlatWorkGroupSize::createForPosition(const IRPosition &IRP,
821826
"AAAMDFlatWorkGroupSize is only valid for function position");
822827
}
823828

829+
struct TupleDecIntegerRangeState : public AbstractState {
830+
DecIntegerState<uint32_t> X, Y, Z;
831+
832+
bool isValidState() const override {
833+
return X.isValidState() && Y.isValidState() && Z.isValidState();
834+
}
835+
836+
bool isAtFixpoint() const override {
837+
return X.isAtFixpoint() && Y.isAtFixpoint() && Z.isAtFixpoint();
838+
}
839+
840+
ChangeStatus indicateOptimisticFixpoint() override {
841+
return X.indicateOptimisticFixpoint() | Y.indicateOptimisticFixpoint() |
842+
Z.indicateOptimisticFixpoint();
843+
}
844+
845+
ChangeStatus indicatePessimisticFixpoint() override {
846+
return X.indicatePessimisticFixpoint() | Y.indicatePessimisticFixpoint() |
847+
Z.indicatePessimisticFixpoint();
848+
}
849+
850+
TupleDecIntegerRangeState operator^=(const TupleDecIntegerRangeState &Other) {
851+
X ^= Other.X;
852+
Y ^= Other.Y;
853+
Z ^= Other.Z;
854+
return *this;
855+
}
856+
857+
bool operator==(const TupleDecIntegerRangeState &Other) const {
858+
return X == Other.X && Y == Other.Y && Z == Other.Z;
859+
}
860+
861+
TupleDecIntegerRangeState &getAssumed() { return *this; }
862+
const TupleDecIntegerRangeState &getAssumed() const { return *this; }
863+
};
864+
865+
using AAAMDMaxNumWorkgroupsState =
866+
StateWrapper<TupleDecIntegerRangeState, AbstractAttribute, uint32_t>;
867+
868+
/// Propagate amdgpu-max-num-workgroups attribute.
869+
struct AAAMDMaxNumWorkgroups
870+
: public StateWrapper<TupleDecIntegerRangeState, AbstractAttribute> {
871+
using Base = StateWrapper<TupleDecIntegerRangeState, AbstractAttribute>;
872+
873+
AAAMDMaxNumWorkgroups(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
874+
875+
void initialize(Attributor &A) override {
876+
Function *F = getAssociatedFunction();
877+
auto &InfoCache = static_cast<AMDGPUInformationCache &>(A.getInfoCache());
878+
879+
SmallVector<unsigned> MaxNumWorkgroups = InfoCache.getMaxNumWorkGroups(*F);
880+
881+
X.takeKnownMinimum(MaxNumWorkgroups[0]);
882+
Y.takeKnownMinimum(MaxNumWorkgroups[1]);
883+
Z.takeKnownMinimum(MaxNumWorkgroups[2]);
884+
885+
if (AMDGPU::isEntryFunctionCC(F->getCallingConv()))
886+
indicatePessimisticFixpoint();
887+
}
888+
889+
ChangeStatus updateImpl(Attributor &A) override {
890+
ChangeStatus Change = ChangeStatus::UNCHANGED;
891+
892+
auto CheckCallSite = [&](AbstractCallSite CS) {
893+
Function *Caller = CS.getInstruction()->getFunction();
894+
LLVM_DEBUG(dbgs() << "[AAAMDMaxNumWorkgroups] Call " << Caller->getName()
895+
<< "->" << getAssociatedFunction()->getName() << '\n');
896+
897+
const auto *CallerInfo = A.getAAFor<AAAMDMaxNumWorkgroups>(
898+
*this, IRPosition::function(*Caller), DepClassTy::REQUIRED);
899+
if (!CallerInfo || !CallerInfo->isValidState())
900+
return false;
901+
902+
Change |=
903+
clampStateAndIndicateChange(this->getState(), CallerInfo->getState());
904+
return true;
905+
};
906+
907+
bool AllCallSitesKnown = true;
908+
if (!A.checkForAllCallSites(CheckCallSite, *this,
909+
/*RequireAllCallSites=*/true,
910+
AllCallSitesKnown))
911+
return indicatePessimisticFixpoint();
912+
913+
return Change;
914+
}
915+
916+
/// Create an abstract attribute view for the position \p IRP.
917+
static AAAMDMaxNumWorkgroups &createForPosition(const IRPosition &IRP,
918+
Attributor &A);
919+
920+
ChangeStatus manifest(Attributor &A) override {
921+
Function *F = getAssociatedFunction();
922+
LLVMContext &Ctx = F->getContext();
923+
SmallString<32> Buffer;
924+
raw_svector_ostream OS(Buffer);
925+
OS << X.getAssumed() << ',' << Y.getAssumed() << ',' << Z.getAssumed();
926+
927+
// TODO: Should annotate loads of the group size for this to do anything
928+
// useful.
929+
return A.manifestAttrs(
930+
getIRPosition(),
931+
{Attribute::get(Ctx, "amdgpu-max-num-workgroups", OS.str())},
932+
/* ForceReplace= */ true);
933+
}
934+
935+
const std::string getName() const override { return "AAAMDMaxNumWorkgroups"; }
936+
937+
const std::string getAsStr(Attributor *) const override {
938+
std::string Buffer = "AAAMDMaxNumWorkgroupsState[";
939+
raw_string_ostream OS(Buffer);
940+
OS << X.getAssumed() << ',' << Y.getAssumed() << ',' << Z.getAssumed()
941+
<< ']';
942+
return OS.str();
943+
}
944+
945+
const char *getIdAddr() const override { return &ID; }
946+
947+
/// This function should return true if the type of the \p AA is
948+
/// AAAMDMaxNumWorkgroups
949+
static bool classof(const AbstractAttribute *AA) {
950+
return (AA->getIdAddr() == &ID);
951+
}
952+
953+
void trackStatistics() const override {}
954+
955+
/// Unique ID (due to the unique address)
956+
static const char ID;
957+
};
958+
959+
const char AAAMDMaxNumWorkgroups::ID = 0;
960+
961+
AAAMDMaxNumWorkgroups &
962+
AAAMDMaxNumWorkgroups::createForPosition(const IRPosition &IRP, Attributor &A) {
963+
if (IRP.getPositionKind() == IRPosition::IRP_FUNCTION)
964+
return *new (A.Allocator) AAAMDMaxNumWorkgroups(IRP, A);
965+
llvm_unreachable("AAAMDMaxNumWorkgroups is only valid for function position");
966+
}
967+
824968
/// Propagate amdgpu-waves-per-eu attribute.
825969
struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
826970
AAAMDWavesPerEU(const IRPosition &IRP, Attributor &A)
@@ -1046,8 +1190,8 @@ static bool runImpl(Module &M, AnalysisGetter &AG, TargetMachine &TM,
10461190
DenseSet<const char *> Allowed(
10471191
{&AAAMDAttributes::ID, &AAUniformWorkGroupSize::ID,
10481192
&AAPotentialValues::ID, &AAAMDFlatWorkGroupSize::ID,
1049-
&AAAMDWavesPerEU::ID, &AAAMDGPUNoAGPR::ID, &AACallEdges::ID,
1050-
&AAPointerInfo::ID, &AAPotentialConstantValues::ID,
1193+
&AAAMDMaxNumWorkgroups::ID, &AAAMDWavesPerEU::ID, &AAAMDGPUNoAGPR::ID,
1194+
&AACallEdges::ID, &AAPointerInfo::ID, &AAPotentialConstantValues::ID,
10511195
&AAUnderlyingObjects::ID, &AAAddressSpace::ID, &AAIndirectCallInfo::ID,
10521196
&AAInstanceInfo::ID});
10531197

@@ -1071,6 +1215,7 @@ static bool runImpl(Module &M, AnalysisGetter &AG, TargetMachine &TM,
10711215
for (auto *F : Functions) {
10721216
A.getOrCreateAAFor<AAAMDAttributes>(IRPosition::function(*F));
10731217
A.getOrCreateAAFor<AAUniformWorkGroupSize>(IRPosition::function(*F));
1218+
A.getOrCreateAAFor<AAAMDMaxNumWorkgroups>(IRPosition::function(*F));
10741219
A.getOrCreateAAFor<AAAMDGPUNoAGPR>(IRPosition::function(*F));
10751220
CallingConv::ID CC = F->getCallingConv();
10761221
if (!AMDGPU::isEntryFunctionCC(CC)) {

0 commit comments

Comments
 (0)