Skip to content

Commit e7db2ba

Browse files
skachkov-inteligcbot
authored and committed
Fold bitcasts to rdregion/wrregion by changing region parameters
This transformation reorders rdregion->bitcast to bitcast->rdregion and bitcast->wrregion to wrregion->bitcast, so they can be baled later.
1 parent c4ea9e6 commit e7db2ba

File tree

7 files changed

+227
-119
lines changed

7 files changed

+227
-119
lines changed

IGC/VectorCompiler/lib/GenXCodeGen/GenXLegalization.cpp

Lines changed: 22 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ class GenXLegalization : public FunctionPass {
377377
Fixed4 = nullptr;
378378
TwiceWidth = nullptr;
379379
}
380-
unsigned getExecSizeAllowedBits(Instruction *Inst);
380+
unsigned adjustTwiceWidthOrFixed4(const Bale &B);
381381
bool checkIfLongLongSupportNeeded(Instruction *Inst) const;
382382
void verifyLSCFence(const Instruction *Inst);
383383
void verifyLSCAtomic(const Instruction *Inst);
@@ -575,84 +575,30 @@ bool GenXLegalization::runOnFunction(Function &F) {
575575
return true;
576576
}
577577

578-
/***********************************************************************
579-
* getExecSizeAllowedBits : get bitmap of allowed execution sizes
580-
*
581-
* Enter: Inst = main instruction of bale
582-
*
583-
* Return: bit N set if execution size 1<<N is allowed
584-
*
585-
* Most instructions have a minimum width of 1. But some instructions,
586-
* such as dp4 and lrp, have a minimum width of 4, and legalization cannot
587-
* allow such an instruction to be split to a smaller width.
588-
*
589-
* This also sets up fields in GenXLegalization: Fixed4 is set to a use
590-
* that is a FIXED4 operand, and TwiceWidth is set to a use that is a
591-
* TWICEWIDTH operand.
592-
*/
593-
unsigned GenXLegalization::getExecSizeAllowedBits(Instruction *Inst) {
594-
595-
switch (Inst->getOpcode()) {
596-
default:
597-
break;
598-
case BinaryOperator::SDiv:
599-
case BinaryOperator::UDiv:
600-
case BinaryOperator::SRem:
601-
case BinaryOperator::URem:
602-
// If integer division IS supported.
603-
// Set maximum SIMD width to 16:
604-
// Recent HW does not support SIMD16/SIMD32 division, however,
605-
// finalizer splits such SIMD16 operations and we piggy-back
606-
// on this behavior.
607-
// If integer division IS NOT supported.
608-
// The expectation is for GenXEmulate pass to replace such operations
609-
// with emulation routines (which has no restriction on SIMD width)
610-
return ST->hasIntDivRem32() ? 0x1f : 0x3f;
611-
}
612-
613-
unsigned ID = GenXIntrinsic::getAnyIntrinsicID(Inst);
614-
switch (ID) {
615-
case GenXIntrinsic::genx_ssmad:
616-
case GenXIntrinsic::genx_sumad:
617-
case GenXIntrinsic::genx_usmad:
618-
case GenXIntrinsic::genx_uumad:
619-
case GenXIntrinsic::genx_ssmad_sat:
620-
case GenXIntrinsic::genx_sumad_sat:
621-
case GenXIntrinsic::genx_usmad_sat:
622-
case GenXIntrinsic::genx_uumad_sat:
623-
case Intrinsic::fma:
624-
// Do not emit simd32 mad for pre-ICLLP.
625-
return ST->isICLLPplus() ? 0x3f : 0x1f;
626-
default:
627-
break;
628-
}
629-
630-
if (CallInst *CI = dyn_cast<CallInst>(Inst)) {
631-
// We have a call instruction, so we can assume it is an intrinsic since
632-
// otherwise processInst would not have got as far as calling us as
633-
// a non-intrinsic call forces isSplittable() to be false.
634-
auto CalledF = CI->getCalledFunction();
635-
IGC_ASSERT(CalledF);
636-
GenXIntrinsicInfo II(GenXIntrinsic::getAnyIntrinsicID(CalledF));
637-
// While we have the intrinsic info, we also spot whether we have a FIXED4
638-
// operand and/or a TWICEWIDTH operand.
578+
unsigned GenXLegalization::adjustTwiceWidthOrFixed4(const Bale &B) {
579+
auto Main = B.getMainInst();
580+
if (!Main)
581+
return 0x3f;
582+
// Spot whether we have a FIXED4 operand and/or a TWICEWIDTH operand.
583+
if (GenXIntrinsic::isGenXIntrinsic(Main->Inst)) {
584+
GenXIntrinsicInfo II(GenXIntrinsic::getAnyIntrinsicID(Main->Inst));
639585
for (auto i = II.begin(), e = II.end(); i != e; ++i) {
640586
auto ArgInfo = *i;
641-
if (ArgInfo.isArgOrRet()) {
642-
switch (ArgInfo.getRestriction()) {
643-
case GenXIntrinsicInfo::FIXED4:
644-
Fixed4 = &CI->getOperandUse(ArgInfo.getArgIdx());
645-
break;
646-
case GenXIntrinsicInfo::TWICEWIDTH:
647-
TwiceWidth = &CI->getOperandUse(ArgInfo.getArgIdx());
648-
break;
649-
}
587+
if (!ArgInfo.isArgOrRet())
588+
continue;
589+
switch (ArgInfo.getRestriction()) {
590+
case GenXIntrinsicInfo::FIXED4:
591+
Fixed4 = &Main->Inst->getOperandUse(ArgInfo.getArgIdx());
592+
break;
593+
case GenXIntrinsicInfo::TWICEWIDTH:
594+
TwiceWidth = &Main->Inst->getOperandUse(ArgInfo.getArgIdx());
595+
break;
650596
}
651597
}
652-
return II.getExecSizeAllowedBits();
653598
}
654-
return 0x3f;
599+
return genx::getExecSizeAllowedBits(Main->Inst, ST);
655600
}
601+
656602
/***********************************************************************
657603
* checkIfLongLongSupportNeeded: checks if an instruction requires
658604
* target to support 64-bit integer operations
@@ -1382,9 +1328,7 @@ unsigned GenXLegalization::determineWidth(unsigned WholeWidth,
13821328
unsigned StartIdx) {
13831329
// Prepare to keep track of whether an instruction with a minimum width
13841330
// (e.g. dp4) would be split too small, and whether we need to unbale.
1385-
unsigned ExecSizeAllowedBits = 0x3f;
1386-
if (auto Main = B.getMainInst())
1387-
ExecSizeAllowedBits = getExecSizeAllowedBits(Main->Inst);
1331+
unsigned ExecSizeAllowedBits = adjustTwiceWidthOrFixed4(B);
13881332
unsigned MainInstMinWidth =
13891333
1 << countTrailingZeros(ExecSizeAllowedBits, ZB_Undefined);
13901334
// Determine the vector width that we need to split into.
@@ -2604,20 +2548,6 @@ static Value *createBitCastIfNeeded(Value *V, Type *NewTy,
26042548
return Inst;
26052549
}
26062550

2607-
// Get type that represents OldTy as vector of NewScalarType.
2608-
static Type *getNewVectorType(Type *OldTy, IntegerType *NewScalarType) {
2609-
IGC_ASSERT(OldTy->isIntOrIntVectorTy());
2610-
unsigned OldElemSize = OldTy->getScalarSizeInBits();
2611-
IGC_ASSERT(OldElemSize > 0);
2612-
auto *VTy = dyn_cast<IGCLLVM::FixedVectorType>(OldTy);
2613-
unsigned OldNumElems = VTy ? VTy->getNumElements() : 1;
2614-
unsigned NewElemSize = NewScalarType->getBitWidth();
2615-
if (OldElemSize * OldNumElems % NewElemSize)
2616-
return nullptr;
2617-
return IGCLLVM::FixedVectorType::get(NewScalarType,
2618-
OldElemSize * OldNumElems / NewElemSize);
2619-
}
2620-
26212551
/***********************************************************************
26222552
* transformMoveType : transform move bale to new integer type.
26232553
*
@@ -2671,16 +2601,13 @@ Instruction *GenXLegalization::transformMoveType(Bale *B, IntegerType *FromTy,
26712601
return nullptr;
26722602
Region DstRgn = Wr ? makeRegionFromBaleInfo(Wr, BaleInfo()) : Region(Rd);
26732603

2674-
// Check that dst and src regions can be changed on new type.
2675-
if (SrcRgn.Indirect || DstRgn.Indirect || DstRgn.Mask)
2676-
return nullptr;
26772604
// If destination region is not contiguous, changing element type can be achieved
26782605
// by conversion to 2D region, but we try to avoid it because such dst operands
26792606
// are not supported in HW and require additional code for emulation.
26802607
if (DstRgn.Stride != 1)
26812608
return nullptr;
2682-
Type *NewSrcTy = getNewVectorType(Src->getType(), ToTy),
2683-
*NewDstTy = getNewVectorType(Dst->getType(), ToTy);
2609+
Type *NewSrcTy = genx::changeVectorType(Src->getType(), ToTy),
2610+
*NewDstTy = genx::changeVectorType(Dst->getType(), ToTy);
26842611
if (!NewSrcTy || !NewDstTy)
26852612
return nullptr;
26862613

IGC/VectorCompiler/lib/GenXCodeGen/GenXPostLegalization.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ bool GenXPostLegalization::runOnFunction(Function &F)
138138
// Run the vector decomposer for this function.
139139
Modified |= VD.run(DT);
140140
// Cleanup region reads and writes.
141-
Modified |= simplifyRegionInsts(&F, DL);
141+
Modified |= simplifyRegionInsts(&F, DL, ST);
142142
// Cleanup redundant global loads.
143143
Modified |= cleanupLoads(&F);
144144
// Cleanup constant loads.

IGC/VectorCompiler/lib/GenXCodeGen/GenXRegionUtils.cpp

Lines changed: 120 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ SPDX-License-Identifier: MIT
2424
#include "llvm/IR/Constants.h"
2525
#include "llvm/IR/DerivedTypes.h"
2626
#include "llvm/IR/Function.h"
27+
#include "llvm/IR/IRBuilder.h"
2728
#include "llvm/IR/Instructions.h"
2829
#include "llvm/Support/Debug.h"
2930
#include <unordered_map>
30-
#include "Probe/Assertion.h"
3131

3232
#include "llvmWrapper/IR/DerivedTypes.h"
3333
#include "llvmWrapper/Support/TypeSize.h"
@@ -922,14 +922,66 @@ static Instruction* simplifyConstIndirectRegion(Instruction* Inst) {
922922
return Inst;
923923
}
924924

925-
static Value *simplifyRegionWrite(Instruction *Inst) {
926-
IGC_ASSERT(GenXIntrinsic::isWrRegion(Inst));
927-
Value *NewVal = Inst->getOperand(GenXIntrinsic::GenXRegion::NewValueOperandNum);
925+
// fold bitcast with wrregion:
926+
// ==> %oldval.cast = bitcast(%oldval)
927+
// %2 = bitcast(%1) %3 = wrregion(%oldval.cast, %1, ...)
928+
// %3 = wrregion(%oldval, %2, ...) %2 = bitcast(%3)
929+
// so it can be baled later.
930+
static Value *simplifyBitCastWithRegionWrite(Instruction *WrR,
931+
const DataLayout &DL,
932+
const GenXSubtarget &ST) {
933+
using namespace GenXIntrinsic::GenXRegion;
934+
IGC_ASSERT(GenXIntrinsic::isWrRegion(WrR));
935+
Value *NewVal = WrR->getOperand(NewValueOperandNum);
936+
auto *BCI = dyn_cast<BitCastInst>(NewVal);
937+
if (!BCI)
938+
return nullptr;
939+
if (WrR->hasOneUse() && GenXIntrinsic::isWritePredefReg(WrR->user_back()))
940+
return nullptr;
941+
auto *NewScalarTy = BCI->getSrcTy()->getScalarType();
942+
// Do not change register category to predicate.
943+
if (NewScalarTy->isIntegerTy(1))
944+
return nullptr;
945+
auto *OldVal = WrR->getOperand(OldValueOperandNum);
946+
if (GenXIntrinsic::isReadWritePredefReg(OldVal))
947+
return nullptr;
948+
auto *NewVecTy = genx::changeVectorType(OldVal->getType(), NewScalarTy);
949+
if (!NewVecTy)
950+
return nullptr;
951+
Region R = makeRegionFromBaleInfo(WrR, BaleInfo());
952+
if (!R.changeElementType(NewScalarTy, &DL))
953+
return nullptr;
954+
// Transformation is not profitable for 2D regions or if it will require
955+
// legalization.
956+
if (R.is2D() || R.NumElements > llvm::PowerOf2Floor(
957+
genx::getExecSizeAllowedBits(WrR, &ST)))
958+
return nullptr;
959+
IRBuilder<> IRB(WrR);
960+
auto *OldValCast =
961+
IRB.CreateBitCast(OldVal, NewVecTy, OldVal->getName() + ".cast");
962+
auto *NewWrR = R.createWrRegion(OldValCast, BCI->getOperand(0),
963+
WrR->getName(), WrR, WrR->getDebugLoc());
964+
auto *NewBCI = IRB.CreateBitCast(NewWrR, WrR->getType(), BCI->getName());
965+
return NewBCI;
966+
}
928967

968+
static Value *simplifyRegionWrite(Instruction *WrR) {
969+
using namespace GenXIntrinsic::GenXRegion;
970+
IGC_ASSERT(GenXIntrinsic::isWrRegion(WrR));
971+
Value *NewVal = WrR->getOperand(NewValueOperandNum);
972+
973+
// Replace C with B if R is the whole region
974+
// C = wrregion(A, B, R)
975+
if (std::none_of(
976+
WrR->user_begin(), WrR->user_end(),
977+
[](auto *U) { return GenXIntrinsic::isWritePredefReg(U); }) &&
978+
makeRegionFromBaleInfo(WrR, BaleInfo()).isWhole(WrR->getType()) &&
979+
NewVal->getType() == WrR->getType())
980+
return NewVal;
929981
// Replace C with A
930982
// C = wrregion(A, undef, R)
931983
if (isa<UndefValue>(NewVal))
932-
return Inst->getOperand(GenXIntrinsic::GenXRegion::OldValueOperandNum);
984+
return WrR->getOperand(OldValueOperandNum);
933985

934986
// When A and undef have the same type, replace C with A
935987
// B = rdregion(A, R)
@@ -941,29 +993,68 @@ static Value *simplifyRegionWrite(Instruction *Inst) {
941993
// C = wrregion(A, B, R)
942994
//
943995
if (GenXIntrinsic::isRdRegion(NewVal)) {
944-
Instruction *B = cast<Instruction>(NewVal);
945-
Region InnerR = makeRegionFromBaleInfo(B, BaleInfo());
946-
Region OuterR = makeRegionFromBaleInfo(Inst, BaleInfo());
996+
Instruction *RdR = cast<Instruction>(NewVal);
997+
Region InnerR = makeRegionFromBaleInfo(RdR, BaleInfo());
998+
Region OuterR = makeRegionFromBaleInfo(WrR, BaleInfo());
947999
if (OuterR != InnerR)
9481000
return nullptr;
9491001

950-
auto OldValB = B->getOperand(GenXIntrinsic::GenXRegion::OldValueOperandNum);
951-
if (GenXIntrinsic::isReadPredefReg(OldValB))
1002+
auto OldValRdR = RdR->getOperand(OldValueOperandNum);
1003+
if (GenXIntrinsic::isReadPredefReg(OldValRdR))
9521004
return nullptr;
953-
auto OldValC = Inst->getOperand(GenXIntrinsic::GenXRegion::OldValueOperandNum);
954-
if ((isa<UndefValue>(OldValC) &&
955-
OldValB->getType() == OldValC->getType()) ||
956-
OldValB == OldValC)
957-
return OldValB;
1005+
auto OldValWrR = WrR->getOperand(OldValueOperandNum);
1006+
if ((isa<UndefValue>(OldValWrR) &&
1007+
OldValRdR->getType() == OldValWrR->getType()) ||
1008+
OldValRdR == OldValWrR)
1009+
return OldValRdR;
9581010
}
959-
9601011
return nullptr;
9611012
}
9621013

1014+
// fold bitcast with rdregion:
1015+
// %2 = rdregion(%1, ...) ==> %3 = bitcast(%1)
1016+
// %3 = bitcast(%2) %2 = rdregion(%3, ...)
1017+
// so it can be baled later.
1018+
static Value *simplifyBitCastFromRegionRead(BitCastInst *BCI,
1019+
const DataLayout &DL,
1020+
const GenXSubtarget &ST) {
1021+
using namespace GenXIntrinsic::GenXRegion;
1022+
Instruction *RdR = dyn_cast<Instruction>(BCI->getOperand(0));
1023+
if (!RdR || !GenXIntrinsic::isRdRegion(RdR) || !RdR->hasOneUse())
1024+
return nullptr;
1025+
auto *OldVal = RdR->getOperand(OldValueOperandNum);
1026+
if (GenXIntrinsic::isReadPredefReg(OldVal))
1027+
return nullptr;
1028+
auto *NewScalarTy = BCI->getDestTy()->getScalarType();
1029+
// Do not change register category to predicate.
1030+
if (NewScalarTy->isIntegerTy(1))
1031+
return nullptr;
1032+
auto *NewVecTy = genx::changeVectorType(OldVal->getType(), NewScalarTy);
1033+
if (!NewVecTy)
1034+
return nullptr;
1035+
Region R = makeRegionFromBaleInfo(RdR, BaleInfo());
1036+
if (!R.changeElementType(NewScalarTy, &DL))
1037+
return nullptr;
1038+
// Transformation is not profitable for 2D regions or if it will require
1039+
// legalization.
1040+
if (R.is2D() || R.NumElements > llvm::PowerOf2Floor(
1041+
genx::getExecSizeAllowedBits(RdR, &ST)))
1042+
return nullptr;
1043+
auto *NewBCI =
1044+
IRBuilder<>(BCI).CreateBitCast(OldVal, NewVecTy, BCI->getName());
1045+
auto *NewRdR =
1046+
R.createRdRegion(NewBCI, RdR->getName(), BCI, RdR->getDebugLoc());
1047+
return NewRdR;
1048+
}
1049+
9631050
static Value *simplifyRegionRead(Instruction *Inst) {
9641051
IGC_ASSERT(GenXIntrinsic::isRdRegion(Inst));
9651052
Value *Input = Inst->getOperand(GenXIntrinsic::GenXRegion::OldValueOperandNum);
966-
if (isa<UndefValue>(Input))
1053+
if (!GenXIntrinsic::isReadPredefReg(Input) &&
1054+
makeRegionFromBaleInfo(Inst, BaleInfo()).isWhole(Input->getType()) &&
1055+
Input->getType() == Inst->getType())
1056+
return Input;
1057+
else if (isa<UndefValue>(Input))
9671058
return UndefValue::get(Inst->getType());
9681059
else if (auto C = dyn_cast<Constant>(Input)) {
9691060
if (auto Splat = C->getSplatValue()) {
@@ -990,7 +1081,8 @@ static Value *simplifyRegionRead(Instruction *Inst) {
9901081
}
9911082

9921083
// Simplify a region read or write.
993-
Value *llvm::genx::simplifyRegionInst(Instruction *Inst, const DataLayout *DL) {
1084+
Value *llvm::genx::simplifyRegionInst(Instruction *Inst, const DataLayout *DL,
1085+
const GenXSubtarget *ST) {
9941086
if (Inst->use_empty())
9951087
return nullptr;
9961088

@@ -1013,11 +1105,17 @@ Value *llvm::genx::simplifyRegionInst(Instruction *Inst, const DataLayout *DL) {
10131105
if (Constant *C = ConstantFoldGenX(Inst, *DL))
10141106
return C;
10151107

1108+
if (auto *BCI = dyn_cast<BitCastInst>(Inst); BCI && DL && ST)
1109+
return simplifyBitCastFromRegionRead(BCI, *DL, *ST);
10161110
ID = GenXIntrinsic::getGenXIntrinsicID(Inst);
10171111
switch (ID) {
10181112
case GenXIntrinsic::genx_wrregionf:
10191113
case GenXIntrinsic::genx_wrregioni:
1020-
return simplifyRegionWrite(Inst);
1114+
if (auto *Res = simplifyRegionWrite(Inst))
1115+
return Res;
1116+
if (DL && ST)
1117+
return simplifyBitCastWithRegionWrite(Inst, *DL, *ST);
1118+
break;
10211119
case GenXIntrinsic::genx_rdregionf:
10221120
case GenXIntrinsic::genx_rdregioni:
10231121
return simplifyRegionRead(Inst);
@@ -1027,12 +1125,13 @@ Value *llvm::genx::simplifyRegionInst(Instruction *Inst, const DataLayout *DL) {
10271125
return nullptr;
10281126
}
10291127

1030-
bool llvm::genx::simplifyRegionInsts(Function *F, const DataLayout *DL) {
1128+
bool llvm::genx::simplifyRegionInsts(Function *F, const DataLayout *DL,
1129+
const GenXSubtarget *ST) {
10311130
bool Changed = false;
10321131
for (auto &BB : F->getBasicBlockList()) {
10331132
for (auto I = BB.begin(); I != BB.end();) {
10341133
Instruction *Inst = &*I++;
1035-
if (auto V = simplifyRegionInst(Inst, DL)) {
1134+
if (auto V = simplifyRegionInst(Inst, DL, ST)) {
10361135
Inst->replaceAllUsesWith(V);
10371136
Inst->eraseFromParent();
10381137
Changed = true;

IGC/VectorCompiler/lib/GenXCodeGen/GenXRegionUtils.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,10 @@ inline raw_ostream &operator<<(raw_ostream &OS, const RdWrRegionSequence &RWS) {
162162
return OS;
163163
}
164164

165-
Value *simplifyRegionInst(Instruction *Inst, const DataLayout *DL);
166-
bool simplifyRegionInsts(Function *F, const DataLayout *DL);
165+
Value *simplifyRegionInst(Instruction *Inst, const DataLayout *DL = nullptr,
166+
const GenXSubtarget *ST = nullptr);
167+
bool simplifyRegionInsts(Function *F, const DataLayout *DL = nullptr,
168+
const GenXSubtarget *ST = nullptr);
167169

168170
bool cleanupLoads(Function *F);
169171

0 commit comments

Comments
 (0)