@@ -179,6 +179,11 @@ class AMDGPUInformationCache : public InformationCache {
179
179
return {ST.getMinFlatWorkGroupSize (), ST.getMaxFlatWorkGroupSize ()};
180
180
}
181
181
182
+ SmallVector<unsigned > getMaxNumWorkGroups (const Function &F) {
183
+ const GCNSubtarget &ST = TM.getSubtarget <GCNSubtarget>(F);
184
+ return ST.getMaxNumWorkGroups (F);
185
+ }
186
+
182
187
  /// Get the code object version (returns the cached CodeObjectVersion value).
  unsigned getCodeObjectVersion() const { return CodeObjectVersion; }
@@ -821,6 +826,145 @@ AAAMDFlatWorkGroupSize::createForPosition(const IRPosition &IRP,
821
826
" AAAMDFlatWorkGroupSize is only valid for function position" );
822
827
}
823
828
829
+ struct TupleDecIntegerRangeState : public AbstractState {
830
+ DecIntegerState<uint32_t > X, Y, Z;
831
+
832
+ bool isValidState () const override {
833
+ return X.isValidState () && Y.isValidState () && Z.isValidState ();
834
+ }
835
+
836
+ bool isAtFixpoint () const override {
837
+ return X.isAtFixpoint () && Y.isAtFixpoint () && Z.isAtFixpoint ();
838
+ }
839
+
840
+ ChangeStatus indicateOptimisticFixpoint () override {
841
+ return X.indicateOptimisticFixpoint () | Y.indicateOptimisticFixpoint () |
842
+ Z.indicateOptimisticFixpoint ();
843
+ }
844
+
845
+ ChangeStatus indicatePessimisticFixpoint () override {
846
+ return X.indicatePessimisticFixpoint () | Y.indicatePessimisticFixpoint () |
847
+ Z.indicatePessimisticFixpoint ();
848
+ }
849
+
850
+ TupleDecIntegerRangeState operator ^=(const TupleDecIntegerRangeState &Other) {
851
+ X ^= Other.X ;
852
+ Y ^= Other.Y ;
853
+ Z ^= Other.Z ;
854
+ return *this ;
855
+ }
856
+
857
+ bool operator ==(const TupleDecIntegerRangeState &Other) const {
858
+ return X == Other.X && Y == Other.Y && Z == Other.Z ;
859
+ }
860
+
861
+ TupleDecIntegerRangeState &getAssumed () { return *this ; }
862
+ const TupleDecIntegerRangeState &getAssumed () const { return *this ; }
863
+ };
864
+
865
+ using AAAMDMaxNumWorkgroupsState =
866
+ StateWrapper<TupleDecIntegerRangeState, AbstractAttribute, uint32_t >;
867
+
868
+ // / Propagate amdgpu-max-num-workgroups attribute.
869
+ struct AAAMDMaxNumWorkgroups
870
+ : public StateWrapper<TupleDecIntegerRangeState, AbstractAttribute> {
871
+ using Base = StateWrapper<TupleDecIntegerRangeState, AbstractAttribute>;
872
+
873
+ AAAMDMaxNumWorkgroups (const IRPosition &IRP, Attributor &A) : Base(IRP) {}
874
+
875
+ void initialize (Attributor &A) override {
876
+ Function *F = getAssociatedFunction ();
877
+ auto &InfoCache = static_cast <AMDGPUInformationCache &>(A.getInfoCache ());
878
+
879
+ SmallVector<unsigned > MaxNumWorkgroups = InfoCache.getMaxNumWorkGroups (*F);
880
+
881
+ X.takeKnownMinimum (MaxNumWorkgroups[0 ]);
882
+ Y.takeKnownMinimum (MaxNumWorkgroups[1 ]);
883
+ Z.takeKnownMinimum (MaxNumWorkgroups[2 ]);
884
+
885
+ if (AMDGPU::isEntryFunctionCC (F->getCallingConv ()))
886
+ indicatePessimisticFixpoint ();
887
+ }
888
+
889
+ ChangeStatus updateImpl (Attributor &A) override {
890
+ ChangeStatus Change = ChangeStatus::UNCHANGED;
891
+
892
+ auto CheckCallSite = [&](AbstractCallSite CS) {
893
+ Function *Caller = CS.getInstruction ()->getFunction ();
894
+ LLVM_DEBUG (dbgs () << " [AAAMDMaxNumWorkgroups] Call " << Caller->getName ()
895
+ << " ->" << getAssociatedFunction ()->getName () << ' \n ' );
896
+
897
+ const auto *CallerInfo = A.getAAFor <AAAMDMaxNumWorkgroups>(
898
+ *this , IRPosition::function (*Caller), DepClassTy::REQUIRED);
899
+ if (!CallerInfo || !CallerInfo->isValidState ())
900
+ return false ;
901
+
902
+ Change |=
903
+ clampStateAndIndicateChange (this ->getState (), CallerInfo->getState ());
904
+ return true ;
905
+ };
906
+
907
+ bool AllCallSitesKnown = true ;
908
+ if (!A.checkForAllCallSites (CheckCallSite, *this ,
909
+ /* RequireAllCallSites=*/ true ,
910
+ AllCallSitesKnown))
911
+ return indicatePessimisticFixpoint ();
912
+
913
+ return Change;
914
+ }
915
+
916
+ // / Create an abstract attribute view for the position \p IRP.
917
+ static AAAMDMaxNumWorkgroups &createForPosition (const IRPosition &IRP,
918
+ Attributor &A);
919
+
920
+ ChangeStatus manifest (Attributor &A) override {
921
+ Function *F = getAssociatedFunction ();
922
+ LLVMContext &Ctx = F->getContext ();
923
+ SmallString<32 > Buffer;
924
+ raw_svector_ostream OS (Buffer);
925
+ OS << X.getAssumed () << ' ,' << Y.getAssumed () << ' ,' << Z.getAssumed ();
926
+
927
+ // TODO: Should annotate loads of the group size for this to do anything
928
+ // useful.
929
+ return A.manifestAttrs (
930
+ getIRPosition (),
931
+ {Attribute::get (Ctx, " amdgpu-max-num-workgroups" , OS.str ())},
932
+ /* ForceReplace= */ true );
933
+ }
934
+
935
+ const std::string getName () const override { return " AAAMDMaxNumWorkgroups" ; }
936
+
937
+ const std::string getAsStr (Attributor *) const override {
938
+ std::string Buffer = " AAAMDMaxNumWorkgroupsState[" ;
939
+ raw_string_ostream OS (Buffer);
940
+ OS << X.getAssumed () << ' ,' << Y.getAssumed () << ' ,' << Z.getAssumed ()
941
+ << ' ]' ;
942
+ return OS.str ();
943
+ }
944
+
945
+ const char *getIdAddr () const override { return &ID; }
946
+
947
+ // / This function should return true if the type of the \p AA is
948
+ // / AAAMDMaxNumWorkgroups
949
+ static bool classof (const AbstractAttribute *AA) {
950
+ return (AA->getIdAddr () == &ID);
951
+ }
952
+
953
+ void trackStatistics () const override {}
954
+
955
+ // / Unique ID (due to the unique address)
956
+ static const char ID;
957
+ };
958
+
959
+ const char AAAMDMaxNumWorkgroups::ID = 0 ;
960
+
961
+ AAAMDMaxNumWorkgroups &
962
+ AAAMDMaxNumWorkgroups::createForPosition (const IRPosition &IRP, Attributor &A) {
963
+ if (IRP.getPositionKind () == IRPosition::IRP_FUNCTION)
964
+ return *new (A.Allocator ) AAAMDMaxNumWorkgroups (IRP, A);
965
+ llvm_unreachable (" AAAMDMaxNumWorkgroups is only valid for function position" );
966
+ }
967
+
824
968
// / Propagate amdgpu-waves-per-eu attribute.
825
969
struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
826
970
AAAMDWavesPerEU (const IRPosition &IRP, Attributor &A)
@@ -1046,8 +1190,8 @@ static bool runImpl(Module &M, AnalysisGetter &AG, TargetMachine &TM,
1046
1190
DenseSet<const char *> Allowed (
1047
1191
{&AAAMDAttributes::ID, &AAUniformWorkGroupSize::ID,
1048
1192
&AAPotentialValues::ID, &AAAMDFlatWorkGroupSize::ID,
1049
- &AAAMDWavesPerEU ::ID, &AAAMDGPUNoAGPR ::ID, &AACallEdges ::ID,
1050
- &AAPointerInfo::ID, &AAPotentialConstantValues::ID,
1193
+ &AAAMDMaxNumWorkgroups ::ID, &AAAMDWavesPerEU ::ID, &AAAMDGPUNoAGPR ::ID,
1194
+ &AACallEdges::ID, & AAPointerInfo::ID, &AAPotentialConstantValues::ID,
1051
1195
&AAUnderlyingObjects::ID, &AAAddressSpace::ID, &AAIndirectCallInfo::ID,
1052
1196
&AAInstanceInfo::ID});
1053
1197
@@ -1071,6 +1215,7 @@ static bool runImpl(Module &M, AnalysisGetter &AG, TargetMachine &TM,
1071
1215
for (auto *F : Functions) {
1072
1216
A.getOrCreateAAFor <AAAMDAttributes>(IRPosition::function (*F));
1073
1217
A.getOrCreateAAFor <AAUniformWorkGroupSize>(IRPosition::function (*F));
1218
+ A.getOrCreateAAFor <AAAMDMaxNumWorkgroups>(IRPosition::function (*F));
1074
1219
A.getOrCreateAAFor <AAAMDGPUNoAGPR>(IRPosition::function (*F));
1075
1220
CallingConv::ID CC = F->getCallingConv ();
1076
1221
if (!AMDGPU::isEntryFunctionCC (CC)) {
0 commit comments