@@ -215,6 +215,19 @@ class AMDGPUInformationCache : public InformationCache {
215
215
return ST.getWavesPerEU (F, FlatWorkGroupSize);
216
216
}
217
217
218
+ std::optional<std::pair<unsigned , unsigned >>
219
+ getWavesPerEUAttr (const Function &F) {
220
+ Attribute Attr = F.getFnAttribute (" amdgpu-waves-per-eu" );
221
+ if (!Attr.isStringAttribute ())
222
+ return std::nullopt;
223
+ auto Val = parseRangeAttribute (Attr.getValueAsString ());
224
+ if (Val && Val->second == 0 ) {
225
+ const GCNSubtarget &ST = TM.getSubtarget <GCNSubtarget>(F);
226
+ Val->second = ST.getMaxWavesPerEU ();
227
+ }
228
+ return Val;
229
+ }
230
+
218
231
std::pair<unsigned , unsigned >
219
232
getEffectiveWavesPerEU (const Function &F,
220
233
std::pair<unsigned , unsigned > WavesPerEU,
@@ -785,22 +798,6 @@ struct AAAMDSizeRangeAttribute
785
798
/* ForceReplace=*/ true );
786
799
}
787
800
788
- ChangeStatus emitAttributeIfNotDefault (Attributor &A, unsigned Min,
789
- unsigned Max) {
790
- // Don't add the attribute if it's the implied default.
791
- if (getAssumed ().getLower () == Min && getAssumed ().getUpper () - 1 == Max)
792
- return ChangeStatus::UNCHANGED;
793
-
794
- Function *F = getAssociatedFunction ();
795
- LLVMContext &Ctx = F->getContext ();
796
- SmallString<10 > Buffer;
797
- raw_svector_ostream OS (Buffer);
798
- OS << getAssumed ().getLower () << ' ,' << getAssumed ().getUpper () - 1 ;
799
- return A.manifestAttrs (getIRPosition (),
800
- {Attribute::get (Ctx, AttrName, OS.str ())},
801
- /* ForceReplace=*/ true );
802
- }
803
-
804
801
const std::string getAsStr (Attributor *) const override {
805
802
std::string Str;
806
803
raw_string_ostream OS (Str);
@@ -885,29 +882,44 @@ struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
885
882
AAAMDWavesPerEU (const IRPosition &IRP, Attributor &A)
886
883
: AAAMDSizeRangeAttribute(IRP, A, " amdgpu-waves-per-eu" ) {}
887
884
888
- bool isValidState () const override {
889
- return !Assumed.isEmptySet () && IntegerRangeState::isValidState ();
890
- }
891
-
892
885
void initialize (Attributor &A) override {
893
886
Function *F = getAssociatedFunction ();
894
887
auto &InfoCache = static_cast <AMDGPUInformationCache &>(A.getInfoCache ());
895
888
896
- if (const auto *AssumedGroupSize = A.getAAFor <AAAMDFlatWorkGroupSize>(
897
- *this , IRPosition::function (*F), DepClassTy::REQUIRED);
898
- AssumedGroupSize->isValidState ()) {
899
-
900
- unsigned Min, Max;
901
- std::tie (Min, Max) = InfoCache.getWavesPerEU (
902
- *F, {AssumedGroupSize->getAssumed ().getLower ().getZExtValue (),
903
- AssumedGroupSize->getAssumed ().getUpper ().getZExtValue () - 1 });
904
-
889
+ auto TakeRange = [&](std::pair<unsigned , unsigned > R) {
890
+ auto [Min, Max] = R;
905
891
ConstantRange Range (APInt (32 , Min), APInt (32 , Max + 1 ));
906
- intersectKnown (Range);
892
+ IntegerRangeState RangeState (Range);
893
+ clampStateAndIndicateChange (this ->getState (), RangeState);
894
+ indicateOptimisticFixpoint ();
895
+ };
896
+
897
+ // If the attribute exists, simple honor it.
898
+ if (auto Attr = InfoCache.getWavesPerEUAttr (*F)) {
899
+ TakeRange (*Attr);
900
+ return ;
907
901
}
908
902
909
- if (AMDGPU::isEntryFunctionCC (F->getCallingConv ()))
910
- indicatePessimisticFixpoint ();
903
+ // It's getting trickier here, different from AAAMDFlatWorkGroupSize. Since
904
+ // the calculation of waves per EU involves flat work group size, we can't
905
+ // simply use an assumed flat work group size as a start point, because the
906
+ // update of flat work group size is in an inverse direction of waves per
907
+ // EU. However, we can still do something if it is an entry function. Since
908
+ // an entry function is a terminal node, and flat work group size either
909
+ // from attribute or default will be used anyway, we can take that value and
910
+ // calculate the waves per EU based on it. This result can't be updated by
911
+ // no means, but that could still allow us to propagate it.
912
+ if (AMDGPU::isEntryFunctionCC (F->getCallingConv ())) {
913
+ std::pair<unsigned , unsigned > MaxWavesPerEURange{
914
+ 1U , InfoCache.getMaxWavesPerEU (*F)};
915
+ std::pair<unsigned , unsigned > FlatWorkGroupSize;
916
+ if (auto Attr = InfoCache.getFlatWorkGroupSizeAttr (*F))
917
+ FlatWorkGroupSize = *Attr;
918
+ else
919
+ FlatWorkGroupSize = InfoCache.getDefaultFlatWorkGroupSize (*F);
920
+ TakeRange (InfoCache.getEffectiveWavesPerEU (*F, MaxWavesPerEURange,
921
+ FlatWorkGroupSize));
922
+ }
911
923
}
912
924
913
925
ChangeStatus updateImpl (Attributor &A) override {
@@ -956,8 +968,8 @@ struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
956
968
ChangeStatus manifest (Attributor &A) override {
957
969
Function *F = getAssociatedFunction ();
958
970
auto &InfoCache = static_cast <AMDGPUInformationCache &>(A.getInfoCache ());
959
- unsigned Max = InfoCache. getMaxWavesPerEU (*F);
960
- return emitAttributeIfNotDefault ( A, 1 , Max );
971
+ return emitAttributeIfNotDefaultAfterClamp (
972
+ A, { 1U , InfoCache. getMaxWavesPerEU (*F)} );
961
973
}
962
974
963
975
// / See AbstractAttribute::getName()
0 commit comments