@@ -24,10 +24,10 @@ SPDX-License-Identifier: MIT
24
24
#include " llvm/IR/Constants.h"
25
25
#include " llvm/IR/DerivedTypes.h"
26
26
#include " llvm/IR/Function.h"
27
+ #include " llvm/IR/IRBuilder.h"
27
28
#include " llvm/IR/Instructions.h"
28
29
#include " llvm/Support/Debug.h"
29
30
#include < unordered_map>
30
- #include " Probe/Assertion.h"
31
31
32
32
#include " llvmWrapper/IR/DerivedTypes.h"
33
33
#include " llvmWrapper/Support/TypeSize.h"
@@ -922,14 +922,66 @@ static Instruction* simplifyConstIndirectRegion(Instruction* Inst) {
922
922
return Inst;
923
923
}
924
924
925
- static Value *simplifyRegionWrite (Instruction *Inst) {
926
- IGC_ASSERT (GenXIntrinsic::isWrRegion (Inst));
927
- Value *NewVal = Inst->getOperand (GenXIntrinsic::GenXRegion::NewValueOperandNum);
925
+ // fold bitcast with wrregion:
926
+ // ==> %oldval.cast = bitcast(%oldval)
927
+ // %2 = bitcast(%1) %3 = wrregion(%oldval.cast, %1, ...)
928
+ // %3 = wrregion(%oldval, %2, ...) %2 = bitcast(%3)
929
+ // so it can be baled later.
930
+ static Value *simplifyBitCastWithRegionWrite (Instruction *WrR,
931
+ const DataLayout &DL,
932
+ const GenXSubtarget &ST) {
933
+ using namespace GenXIntrinsic ::GenXRegion;
934
+ IGC_ASSERT (GenXIntrinsic::isWrRegion (WrR));
935
+ Value *NewVal = WrR->getOperand (NewValueOperandNum);
936
+ auto *BCI = dyn_cast<BitCastInst>(NewVal);
937
+ if (!BCI)
938
+ return nullptr ;
939
+ if (WrR->hasOneUse () && GenXIntrinsic::isWritePredefReg (WrR->user_back ()))
940
+ return nullptr ;
941
+ auto *NewScalarTy = BCI->getSrcTy ()->getScalarType ();
942
+ // Do not change register category to predicate.
943
+ if (NewScalarTy->isIntegerTy (1 ))
944
+ return nullptr ;
945
+ auto *OldVal = WrR->getOperand (OldValueOperandNum);
946
+ if (GenXIntrinsic::isReadWritePredefReg (OldVal))
947
+ return nullptr ;
948
+ auto *NewVecTy = genx::changeVectorType (OldVal->getType (), NewScalarTy);
949
+ if (!NewVecTy)
950
+ return nullptr ;
951
+ Region R = makeRegionFromBaleInfo (WrR, BaleInfo ());
952
+ if (!R.changeElementType (NewScalarTy, &DL))
953
+ return nullptr ;
954
+ // Transformation is not profitable for 2D regions or if it will require
955
+ // legalization.
956
+ if (R.is2D () || R.NumElements > llvm::PowerOf2Floor (
957
+ genx::getExecSizeAllowedBits (WrR, &ST)))
958
+ return nullptr ;
959
+ IRBuilder<> IRB (WrR);
960
+ auto *OldValCast =
961
+ IRB.CreateBitCast (OldVal, NewVecTy, OldVal->getName () + " .cast" );
962
+ auto *NewWrR = R.createWrRegion (OldValCast, BCI->getOperand (0 ),
963
+ WrR->getName (), WrR, WrR->getDebugLoc ());
964
+ auto *NewBCI = IRB.CreateBitCast (NewWrR, WrR->getType (), BCI->getName ());
965
+ return NewBCI;
966
+ }
928
967
968
+ static Value *simplifyRegionWrite (Instruction *WrR) {
969
+ using namespace GenXIntrinsic ::GenXRegion;
970
+ IGC_ASSERT (GenXIntrinsic::isWrRegion (WrR));
971
+ Value *NewVal = WrR->getOperand (NewValueOperandNum);
972
+
973
+ // Replace C with B if R - whole region
974
+ // C = wrregion(A, B, R)
975
+ if (std::none_of (
976
+ WrR->user_begin (), WrR->user_end (),
977
+ [](auto *U) { return GenXIntrinsic::isWritePredefReg (U); }) &&
978
+ makeRegionFromBaleInfo (WrR, BaleInfo ()).isWhole (WrR->getType ()) &&
979
+ NewVal->getType () == WrR->getType ())
980
+ return NewVal;
929
981
// Replace C with A
930
982
// C = wrregion(A, undef, R)
931
983
if (isa<UndefValue>(NewVal))
932
- return Inst ->getOperand (GenXIntrinsic::GenXRegion:: OldValueOperandNum);
984
+ return WrR ->getOperand (OldValueOperandNum);
933
985
934
986
// When A and undef have the same type, replace C with A
935
987
// B = rdregion(A, R)
@@ -941,29 +993,68 @@ static Value *simplifyRegionWrite(Instruction *Inst) {
941
993
// C = wrregion(A, B, R)
942
994
//
943
995
if (GenXIntrinsic::isRdRegion (NewVal)) {
944
- Instruction *B = cast<Instruction>(NewVal);
945
- Region InnerR = makeRegionFromBaleInfo (B , BaleInfo ());
946
- Region OuterR = makeRegionFromBaleInfo (Inst , BaleInfo ());
996
+ Instruction *RdR = cast<Instruction>(NewVal);
997
+ Region InnerR = makeRegionFromBaleInfo (RdR , BaleInfo ());
998
+ Region OuterR = makeRegionFromBaleInfo (WrR , BaleInfo ());
947
999
if (OuterR != InnerR)
948
1000
return nullptr ;
949
1001
950
- auto OldValB = B ->getOperand (GenXIntrinsic::GenXRegion:: OldValueOperandNum);
951
- if (GenXIntrinsic::isReadPredefReg (OldValB ))
1002
+ auto OldValRdR = RdR ->getOperand (OldValueOperandNum);
1003
+ if (GenXIntrinsic::isReadPredefReg (OldValRdR ))
952
1004
return nullptr ;
953
- auto OldValC = Inst ->getOperand (GenXIntrinsic::GenXRegion:: OldValueOperandNum);
954
- if ((isa<UndefValue>(OldValC ) &&
955
- OldValB ->getType () == OldValC ->getType ()) ||
956
- OldValB == OldValC )
957
- return OldValB ;
1005
+ auto OldValWrR = WrR ->getOperand (OldValueOperandNum);
1006
+ if ((isa<UndefValue>(OldValWrR ) &&
1007
+ OldValRdR ->getType () == OldValWrR ->getType ()) ||
1008
+ OldValRdR == OldValWrR )
1009
+ return OldValRdR ;
958
1010
}
959
-
960
1011
return nullptr ;
961
1012
}
962
1013
1014
+ // fold bitcast with rdregion:
1015
+ // %2 = rdregion(%1, ...) ==> %3 = bitcast(%1)
1016
+ // %3 = bitcast(%2) %2 = rdregion(%3, ...)
1017
+ // so it can be baled later.
1018
+ static Value *simplifyBitCastFromRegionRead (BitCastInst *BCI,
1019
+ const DataLayout &DL,
1020
+ const GenXSubtarget &ST) {
1021
+ using namespace GenXIntrinsic ::GenXRegion;
1022
+ Instruction *RdR = dyn_cast<Instruction>(BCI->getOperand (0 ));
1023
+ if (!RdR || !GenXIntrinsic::isRdRegion (RdR) || !RdR->hasOneUse ())
1024
+ return nullptr ;
1025
+ auto *OldVal = RdR->getOperand (OldValueOperandNum);
1026
+ if (GenXIntrinsic::isReadPredefReg (OldVal))
1027
+ return nullptr ;
1028
+ auto *NewScalarTy = BCI->getDestTy ()->getScalarType ();
1029
+ // Do not change register category to predicate.
1030
+ if (NewScalarTy->isIntegerTy (1 ))
1031
+ return nullptr ;
1032
+ auto *NewVecTy = genx::changeVectorType (OldVal->getType (), NewScalarTy);
1033
+ if (!NewVecTy)
1034
+ return nullptr ;
1035
+ Region R = makeRegionFromBaleInfo (RdR, BaleInfo ());
1036
+ if (!R.changeElementType (NewScalarTy, &DL))
1037
+ return nullptr ;
1038
+ // Transformation is not profitable for 2D regions or if it will require
1039
+ // legalization.
1040
+ if (R.is2D () || R.NumElements > llvm::PowerOf2Floor (
1041
+ genx::getExecSizeAllowedBits (RdR, &ST)))
1042
+ return nullptr ;
1043
+ auto *NewBCI =
1044
+ IRBuilder<>(BCI).CreateBitCast (OldVal, NewVecTy, BCI->getName ());
1045
+ auto *NewRdR =
1046
+ R.createRdRegion (NewBCI, RdR->getName (), BCI, RdR->getDebugLoc ());
1047
+ return NewRdR;
1048
+ }
1049
+
963
1050
static Value *simplifyRegionRead (Instruction *Inst) {
964
1051
IGC_ASSERT (GenXIntrinsic::isRdRegion (Inst));
965
1052
Value *Input = Inst->getOperand (GenXIntrinsic::GenXRegion::OldValueOperandNum);
966
- if (isa<UndefValue>(Input))
1053
+ if (!GenXIntrinsic::isReadPredefReg (Input) &&
1054
+ makeRegionFromBaleInfo (Inst, BaleInfo ()).isWhole (Input->getType ()) &&
1055
+ Input->getType () == Inst->getType ())
1056
+ return Input;
1057
+ else if (isa<UndefValue>(Input))
967
1058
return UndefValue::get (Inst->getType ());
968
1059
else if (auto C = dyn_cast<Constant>(Input)) {
969
1060
if (auto Splat = C->getSplatValue ()) {
@@ -990,7 +1081,8 @@ static Value *simplifyRegionRead(Instruction *Inst) {
990
1081
}
991
1082
992
1083
// Simplify a region read or write.
993
- Value *llvm::genx::simplifyRegionInst (Instruction *Inst, const DataLayout *DL) {
1084
+ Value *llvm::genx::simplifyRegionInst (Instruction *Inst, const DataLayout *DL,
1085
+ const GenXSubtarget *ST) {
994
1086
if (Inst->use_empty ())
995
1087
return nullptr ;
996
1088
@@ -1013,11 +1105,17 @@ Value *llvm::genx::simplifyRegionInst(Instruction *Inst, const DataLayout *DL) {
1013
1105
if (Constant *C = ConstantFoldGenX (Inst, *DL))
1014
1106
return C;
1015
1107
1108
+ if (auto *BCI = dyn_cast<BitCastInst>(Inst); BCI && DL && ST)
1109
+ return simplifyBitCastFromRegionRead (BCI, *DL, *ST);
1016
1110
ID = GenXIntrinsic::getGenXIntrinsicID (Inst);
1017
1111
switch (ID) {
1018
1112
case GenXIntrinsic::genx_wrregionf:
1019
1113
case GenXIntrinsic::genx_wrregioni:
1020
- return simplifyRegionWrite (Inst);
1114
+ if (auto *Res = simplifyRegionWrite (Inst))
1115
+ return Res;
1116
+ if (DL && ST)
1117
+ return simplifyBitCastWithRegionWrite (Inst, *DL, *ST);
1118
+ break ;
1021
1119
case GenXIntrinsic::genx_rdregionf:
1022
1120
case GenXIntrinsic::genx_rdregioni:
1023
1121
return simplifyRegionRead (Inst);
@@ -1027,12 +1125,13 @@ Value *llvm::genx::simplifyRegionInst(Instruction *Inst, const DataLayout *DL) {
1027
1125
return nullptr ;
1028
1126
}
1029
1127
1030
- bool llvm::genx::simplifyRegionInsts (Function *F, const DataLayout *DL) {
1128
+ bool llvm::genx::simplifyRegionInsts (Function *F, const DataLayout *DL,
1129
+ const GenXSubtarget *ST) {
1031
1130
bool Changed = false ;
1032
1131
for (auto &BB : F->getBasicBlockList ()) {
1033
1132
for (auto I = BB.begin (); I != BB.end ();) {
1034
1133
Instruction *Inst = &*I++;
1035
- if (auto V = simplifyRegionInst (Inst, DL)) {
1134
+ if (auto V = simplifyRegionInst (Inst, DL, ST )) {
1036
1135
Inst->replaceAllUsesWith (V);
1037
1136
Inst->eraseFromParent ();
1038
1137
Changed = true ;
0 commit comments