@@ -179,6 +179,11 @@ class AMDGPUInformationCache : public InformationCache {
179
179
return {ST.getMinFlatWorkGroupSize (), ST.getMaxFlatWorkGroupSize ()};
180
180
}
181
181
182
+ SmallVector<unsigned > getMaxNumWorkGroups (const Function &F) {
183
+ const GCNSubtarget &ST = TM.getSubtarget <GCNSubtarget>(F);
184
+ return ST.getMaxNumWorkGroups (F);
185
+ }
186
+
182
187
// / Get code object version.
183
188
unsigned getCodeObjectVersion () const {
184
189
return CodeObjectVersion;
@@ -821,6 +826,150 @@ AAAMDFlatWorkGroupSize::createForPosition(const IRPosition &IRP,
821
826
" AAAMDFlatWorkGroupSize is only valid for function position" );
822
827
}
823
828
829
+ struct TupleDecIntegerRangeState : public AbstractState {
830
+ DecIntegerState<uint32_t > X, Y, Z;
831
+
832
+ bool isValidState () const override {
833
+ return X.isValidState () && Y.isValidState () && Z.isValidState ();
834
+ }
835
+
836
+ bool isAtFixpoint () const override {
837
+ return X.isAtFixpoint () && Y.isAtFixpoint () && Z.isAtFixpoint ();
838
+ }
839
+
840
+ ChangeStatus indicateOptimisticFixpoint () override {
841
+ return X.indicateOptimisticFixpoint () | Y.indicateOptimisticFixpoint () |
842
+ Z.indicateOptimisticFixpoint ();
843
+ }
844
+
845
+ ChangeStatus indicatePessimisticFixpoint () override {
846
+ return X.indicatePessimisticFixpoint () | Y.indicatePessimisticFixpoint () |
847
+ Z.indicatePessimisticFixpoint ();
848
+ }
849
+
850
+ TupleDecIntegerRangeState operator ^=(const TupleDecIntegerRangeState &Other) {
851
+ X ^= Other.X ;
852
+ Y ^= Other.Y ;
853
+ Z ^= Other.Z ;
854
+ return *this ;
855
+ }
856
+
857
+ bool operator ==(const TupleDecIntegerRangeState &Other) const {
858
+ return X == Other.X && Y == Other.Y && Z == Other.Z ;
859
+ }
860
+
861
+ TupleDecIntegerRangeState &getAssumed () { return *this ; }
862
+ const TupleDecIntegerRangeState &getAssumed () const { return *this ; }
863
+ };
864
+
865
+ using AAAMDMaxNumWorkgroupsState =
866
+ StateWrapper<TupleDecIntegerRangeState, AbstractAttribute, uint32_t >;
867
+
868
+ // / Propagate amdgpu-max-num-workgroups attribute.
869
+ struct AAAMDMaxNumWorkgroups
870
+ : public StateWrapper<TupleDecIntegerRangeState, AbstractAttribute> {
871
+ using Base = StateWrapper<TupleDecIntegerRangeState, AbstractAttribute>;
872
+
873
+ AAAMDMaxNumWorkgroups (const IRPosition &IRP, Attributor &A) : Base(IRP) {}
874
+
875
+ void initialize (Attributor &A) override {
876
+ Function *F = getAssociatedFunction ();
877
+ auto &InfoCache = static_cast <AMDGPUInformationCache &>(A.getInfoCache ());
878
+
879
+ SmallVector<unsigned > MaxNumWorkgroups = InfoCache.getMaxNumWorkGroups (*F);
880
+
881
+ // FIXME: What is the interpretation of 0?
882
+ for (unsigned &Entry : MaxNumWorkgroups) {
883
+ if (Entry == 0 )
884
+ Entry = std::numeric_limits<uint32_t >::max ();
885
+ }
886
+
887
+ X.takeKnownMinimum (MaxNumWorkgroups[0 ]);
888
+ Y.takeKnownMinimum (MaxNumWorkgroups[1 ]);
889
+ Z.takeKnownMinimum (MaxNumWorkgroups[2 ]);
890
+
891
+ if (AMDGPU::isEntryFunctionCC (F->getCallingConv ()))
892
+ indicatePessimisticFixpoint ();
893
+ }
894
+
895
+ ChangeStatus updateImpl (Attributor &A) override {
896
+ ChangeStatus Change = ChangeStatus::UNCHANGED;
897
+
898
+ auto CheckCallSite = [&](AbstractCallSite CS) {
899
+ Function *Caller = CS.getInstruction ()->getFunction ();
900
+ LLVM_DEBUG (dbgs () << " [AAAMDMaxNumWorkgroups] Call " << Caller->getName ()
901
+ << " ->" << getAssociatedFunction ()->getName () << ' \n ' );
902
+
903
+ const auto *CallerInfo = A.getAAFor <AAAMDMaxNumWorkgroups>(
904
+ *this , IRPosition::function (*Caller), DepClassTy::REQUIRED);
905
+ if (!CallerInfo)
906
+ return false ;
907
+
908
+ Change |=
909
+ clampStateAndIndicateChange (this ->getState (), CallerInfo->getState ());
910
+ return true ;
911
+ };
912
+
913
+ bool AllCallSitesKnown = true ;
914
+ if (!A.checkForAllCallSites (CheckCallSite, *this , true , AllCallSitesKnown))
915
+ return indicatePessimisticFixpoint ();
916
+
917
+ return Change;
918
+ }
919
+
920
+ // / Create an abstract attribute view for the position \p IRP.
921
+ static AAAMDMaxNumWorkgroups &createForPosition (const IRPosition &IRP,
922
+ Attributor &A);
923
+
924
+ ChangeStatus manifest (Attributor &A) override {
925
+ Function *F = getAssociatedFunction ();
926
+ // TODO: Skip adding if worst case?
927
+ LLVMContext &Ctx = F->getContext ();
928
+ SmallString<32 > Buffer;
929
+ raw_svector_ostream OS (Buffer);
930
+ OS << X.getAssumed () << ' ,' << Y.getAssumed () << ' ,' << Z.getAssumed ();
931
+
932
+ // TODO: Should annotate loads of the group size for this to do anything
933
+ // useful.
934
+ return A.manifestAttrs (
935
+ getIRPosition (),
936
+ {Attribute::get (Ctx, " amdgpu-max-num-workgroups" , OS.str ())},
937
+ /* ForceReplace= */ true );
938
+ }
939
+
940
+ const std::string getName () const override { return " AAAMDMaxNumWorkgroups" ; }
941
+
942
+ const std::string getAsStr (Attributor *) const override {
943
+ std::string Buffer = " AAAMDMaxNumWorkgroupsState[" ;
944
+ raw_string_ostream OS (Buffer);
945
+ OS << X.getAssumed () << ' ,' << Y.getAssumed () << ' ,' << Z.getAssumed ()
946
+ << ' ]' ;
947
+ return OS.str ();
948
+ }
949
+
950
+ const char *getIdAddr () const override { return &ID; }
951
+
952
+ // / This function should return true if the type of the \p AA is
953
+ // / AAAMDMaxNumWorkgroups
954
+ static bool classof (const AbstractAttribute *AA) {
955
+ return (AA->getIdAddr () == &ID);
956
+ }
957
+
958
+ void trackStatistics () const override {}
959
+
960
+ // / Unique ID (due to the unique address)
961
+ static const char ID;
962
+ };
963
+
964
+ const char AAAMDMaxNumWorkgroups::ID = 0 ;
965
+
966
+ AAAMDMaxNumWorkgroups &
967
+ AAAMDMaxNumWorkgroups::createForPosition (const IRPosition &IRP, Attributor &A) {
968
+ if (IRP.getPositionKind () == IRPosition::IRP_FUNCTION)
969
+ return *new (A.Allocator ) AAAMDMaxNumWorkgroups (IRP, A);
970
+ llvm_unreachable (" AAAMDMaxNumWorkgroups is only valid for function position" );
971
+ }
972
+
824
973
// / Propagate amdgpu-waves-per-eu attribute.
825
974
struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
826
975
AAAMDWavesPerEU (const IRPosition &IRP, Attributor &A)
@@ -1043,8 +1192,8 @@ static bool runImpl(Module &M, AnalysisGetter &AG, TargetMachine &TM,
1043
1192
DenseSet<const char *> Allowed (
1044
1193
{&AAAMDAttributes::ID, &AAUniformWorkGroupSize::ID,
1045
1194
&AAPotentialValues::ID, &AAAMDFlatWorkGroupSize::ID,
1046
- &AAAMDWavesPerEU ::ID, &AAAMDGPUNoAGPR ::ID, &AACallEdges ::ID,
1047
- &AAPointerInfo::ID, &AAPotentialConstantValues::ID,
1195
+ &AAAMDMaxNumWorkgroups ::ID, &AAAMDWavesPerEU ::ID, &AAAMDGPUNoAGPR ::ID,
1196
+ &AACallEdges::ID, & AAPointerInfo::ID, &AAPotentialConstantValues::ID,
1048
1197
&AAUnderlyingObjects::ID, &AAAddressSpace::ID, &AAIndirectCallInfo::ID,
1049
1198
&AAInstanceInfo::ID});
1050
1199
@@ -1068,6 +1217,7 @@ static bool runImpl(Module &M, AnalysisGetter &AG, TargetMachine &TM,
1068
1217
for (auto *F : Functions) {
1069
1218
A.getOrCreateAAFor <AAAMDAttributes>(IRPosition::function (*F));
1070
1219
A.getOrCreateAAFor <AAUniformWorkGroupSize>(IRPosition::function (*F));
1220
+ A.getOrCreateAAFor <AAAMDMaxNumWorkgroups>(IRPosition::function (*F));
1071
1221
A.getOrCreateAAFor <AAAMDGPUNoAGPR>(IRPosition::function (*F));
1072
1222
CallingConv::ID CC = F->getCallingConv ();
1073
1223
if (!AMDGPU::isEntryFunctionCC (CC)) {
0 commit comments