@@ -198,6 +198,16 @@ class AMDGPUInformationCache : public InformationCache {
198
198
return ST.getWavesPerEU (F, FlatWorkGroupSize);
199
199
}
200
200
201
+ std::optional<std::pair<unsigned , unsigned >>
202
+ getWavesPerEUAttr (const Function &F) {
203
+ auto Val = AMDGPU::getIntegerPairAttribute (F, " amdgpu-waves-per-eu" );
204
+ if (Val && Val->second == 0 ) {
205
+ const GCNSubtarget &ST = TM.getSubtarget <GCNSubtarget>(F);
206
+ Val->second = ST.getMaxWavesPerEU ();
207
+ }
208
+ return Val;
209
+ }
210
+
201
211
std::pair<unsigned , unsigned >
202
212
getEffectiveWavesPerEU (const Function &F,
203
213
std::pair<unsigned , unsigned > WavesPerEU,
@@ -768,22 +778,6 @@ struct AAAMDSizeRangeAttribute
768
778
/* ForceReplace=*/ true );
769
779
}
770
780
771
- ChangeStatus emitAttributeIfNotDefault (Attributor &A, unsigned Min,
772
- unsigned Max) {
773
- // Don't add the attribute if it's the implied default.
774
- if (getAssumed ().getLower () == Min && getAssumed ().getUpper () - 1 == Max)
775
- return ChangeStatus::UNCHANGED;
776
-
777
- Function *F = getAssociatedFunction ();
778
- LLVMContext &Ctx = F->getContext ();
779
- SmallString<10 > Buffer;
780
- raw_svector_ostream OS (Buffer);
781
- OS << getAssumed ().getLower () << ' ,' << getAssumed ().getUpper () - 1 ;
782
- return A.manifestAttrs (getIRPosition (),
783
- {Attribute::get (Ctx, AttrName, OS.str ())},
784
- /* ForceReplace=*/ true );
785
- }
786
-
787
781
const std::string getAsStr (Attributor *) const override {
788
782
std::string Str;
789
783
raw_string_ostream OS (Str);
@@ -868,29 +862,44 @@ struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
868
862
AAAMDWavesPerEU (const IRPosition &IRP, Attributor &A)
869
863
: AAAMDSizeRangeAttribute(IRP, A, " amdgpu-waves-per-eu" ) {}
870
864
871
- bool isValidState () const override {
872
- return !Assumed.isEmptySet () && IntegerRangeState::isValidState ();
873
- }
874
-
875
865
void initialize (Attributor &A) override {
876
866
Function *F = getAssociatedFunction ();
877
867
auto &InfoCache = static_cast <AMDGPUInformationCache &>(A.getInfoCache ());
878
868
879
- if (const auto *AssumedGroupSize = A.getAAFor <AAAMDFlatWorkGroupSize>(
880
- *this , IRPosition::function (*F), DepClassTy::REQUIRED);
881
- AssumedGroupSize->isValidState ()) {
882
-
883
- unsigned Min, Max;
884
- std::tie (Min, Max) = InfoCache.getWavesPerEU (
885
- *F, {AssumedGroupSize->getAssumed ().getLower ().getZExtValue (),
886
- AssumedGroupSize->getAssumed ().getUpper ().getZExtValue () - 1 });
887
-
869
+ auto TakeRange = [&](std::pair<unsigned , unsigned > R) {
870
+ auto [Min, Max] = R;
888
871
ConstantRange Range (APInt (32 , Min), APInt (32 , Max + 1 ));
889
- intersectKnown (Range);
872
+ IntegerRangeState RangeState (Range);
873
+ clampStateAndIndicateChange (this ->getState (), RangeState);
874
+ indicateOptimisticFixpoint ();
875
+ };
876
+
877
+ // If the attribute exists, simply honor it.
878
+ if (auto Attr = InfoCache.getWavesPerEUAttr (*F)) {
879
+ TakeRange (*Attr);
880
+ return ;
890
881
}
891
882
892
- if (AMDGPU::isEntryFunctionCC (F->getCallingConv ()))
893
- indicatePessimisticFixpoint ();
883
+ // It's getting trickier here, different from AAAMDFlatWorkGroupSize. Since
884
+ // the calculation of waves per EU involves flat work group size, we can't
885
+ // simply use an assumed flat work group size as a start point, because the
886
+ // update of flat work group size is in an inverse direction of waves per
887
+ // EU. However, we can still do something if it is an entry function. Since
888
+ // an entry function is a terminal node, and flat work group size either
889
+ // from attribute or default will be used anyway, we can take that value and
890
+ // calculate the waves per EU based on it. This result can't be updated by
891
+ // any means, but that could still allow us to propagate it.
892
+ if (AMDGPU::isEntryFunctionCC (F->getCallingConv ())) {
893
+ std::pair<unsigned , unsigned > MaxWavesPerEURange{
894
+ 1U , InfoCache.getMaxWavesPerEU (*F)};
895
+ std::pair<unsigned , unsigned > FlatWorkGroupSize;
896
+ if (auto Attr = InfoCache.getFlatWorkGroupSizeAttr (*F))
897
+ FlatWorkGroupSize = *Attr;
898
+ else
899
+ FlatWorkGroupSize = InfoCache.getDefaultFlatWorkGroupSize (*F);
900
+ TakeRange (InfoCache.getEffectiveWavesPerEU (*F, MaxWavesPerEURange,
901
+ FlatWorkGroupSize));
902
+ }
894
903
}
895
904
896
905
ChangeStatus updateImpl (Attributor &A) override {
@@ -939,8 +948,8 @@ struct AAAMDWavesPerEU : public AAAMDSizeRangeAttribute {
939
948
ChangeStatus manifest (Attributor &A) override {
940
949
Function *F = getAssociatedFunction ();
941
950
auto &InfoCache = static_cast <AMDGPUInformationCache &>(A.getInfoCache ());
942
- unsigned Max = InfoCache. getMaxWavesPerEU (*F);
943
- return emitAttributeIfNotDefault ( A, 1 , Max );
951
+ return emitAttributeIfNotDefaultAfterClamp (
952
+ A, { 1U , InfoCache. getMaxWavesPerEU (*F)} );
944
953
}
945
954
946
955
// / See AbstractAttribute::getName()
0 commit comments