Skip to content

Commit 44d81c0

Browse files
aparshin-intelsys_zuul
authored andcommitted
refactor GenXLowering for mul64
Change-Id: I158ac50a72edc170eee9727319740f374af350b1
1 parent c77111d commit 44d81c0

File tree

3 files changed

+120
-51
lines changed

3 files changed

+120
-51
lines changed

IGC/VectorCompiler/lib/GenXCodeGen/GenXLowering.cpp

Lines changed: 21 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,6 @@ using namespace genx;
138138
static cl::opt<bool>
139139
EnableGenXByteWidening("enable-genx-byte-widening", cl::init(true),
140140
cl::Hidden, cl::desc("Enable GenX byte widening."));
141-
142141
namespace {
143142

144143
// GenXLowering : legalize execution widths and GRF crossing
@@ -2391,64 +2390,35 @@ bool GenXLowering::lowerFCmpInst(FCmpInst *Inst) {
23912390

23922391
// Lower cmp instructions that GenX cannot deal with.
23932392
bool GenXLowering::lowerMul64(Instruction *Inst) {
2393+
2394+
LoHiSplitter SplitBuilder(*Inst);
2395+
if (!SplitBuilder.IsI64Operation())
2396+
return false;
2397+
23942398
IRBuilder<> Builder(Inst);
23952399
Builder.SetCurrentDebugLocation(Inst->getDebugLoc());
2396-
auto Src0 = Inst->getOperand(0);
2397-
auto Src1 = Inst->getOperand(1);
2398-
auto ETy = Src0->getType();
2399-
auto Len = 1;
2400-
if (ETy->isVectorTy()) {
2401-
Len = ETy->getVectorNumElements();
2402-
ETy = ETy->getVectorElementType();
2403-
}
2404-
if (!ETy->isIntegerTy() || ETy->getPrimitiveSizeInBits() != 64)
2405-
return false;
2406-
auto VTy = VectorType::get(ETy->getInt32Ty(Inst->getContext()), Len * 2);
2407-
// create src0 bitcast, then the low and high part
2408-
auto Src0V = Builder.CreateBitCast(Src0, VTy);
2409-
Region R(Inst);
2410-
R.Offset = 0;
2411-
R.Width = Len;
2412-
R.NumElements = Len;
2413-
R.Stride = 2;
2414-
R.VStride = 0;
2415-
auto Src0L = R.createRdRegion(Src0V, "", Inst, Inst->getDebugLoc());
2416-
R.Offset = 4;
2417-
auto Src0H = R.createRdRegion(Src0V, "", Inst, Inst->getDebugLoc());
2418-
// create src1 bitcast, then the low and high part
2419-
auto Src1V = Builder.CreateBitCast(Src1, VTy);
2420-
R.Offset = 0;
2421-
auto Src1L = R.createRdRegion(Src1V, "", Inst, Inst->getDebugLoc());
2422-
R.Offset = 4;
2423-
auto Src1H = R.createRdRegion(Src1V, "", Inst, Inst->getDebugLoc());
2400+
2401+
auto Src0 = SplitBuilder.splitOperand(0);
2402+
auto Src1 = SplitBuilder.splitOperand(1);
2403+
24242404
// create muls and adds
2425-
auto ResL = Builder.CreateMul(Src0L, Src1L);
2405+
auto *ResL = Builder.CreateMul(Src0.Lo, Src1.Lo);
24262406
// create the mulh intrinsic to the get the carry-part
2427-
Type *tys[2];
2428-
SmallVector<llvm::Value *, 2> args;
2429-
// build type-list
2430-
tys[0] = ResL->getType();
2431-
tys[1] = Src0L->getType();
2407+
Type *tys[2] = {ResL->getType(), Src0.Lo->getType()};
24322408
// build argument list
2433-
args.push_back(Src0L);
2434-
args.push_back(Src1L);
2435-
auto M = Inst->getParent()->getParent()->getParent();
2409+
SmallVector<llvm::Value *, 2> args{Src0.Lo, Src1.Lo};
2410+
auto *M = Inst->getModule();
24362411
Function *IntrinFunc =
24372412
GenXIntrinsic::getGenXDeclaration(M, GenXIntrinsic::genx_umulh, tys);
2438-
Instruction *Cari = CallInst::Create(IntrinFunc, args, "", Inst);
2439-
Cari->setDebugLoc(Inst->getDebugLoc());
2440-
auto Temp0 = Builder.CreateMul(Src0L, Src1H);
2441-
auto Temp1 = Builder.CreateAdd(Cari, Temp0);
2442-
auto Temp2 = Builder.CreateMul(Src0H, Src1L);
2443-
auto ResH = Builder.CreateAdd(Temp2, Temp1);
2444-
// create the write-regions
2445-
auto UndefV = UndefValue::get(VTy);
2446-
R.Offset = 0;
2447-
auto WrL = R.createWrRegion(UndefV, ResL, "WrLow", Inst, Inst->getDebugLoc());
2448-
R.Offset = 4;
2449-
auto WrH = R.createWrRegion(WrL, ResH, "WrHigh", Inst, Inst->getDebugLoc());
2413+
2414+
auto *Cari = Builder.CreateCall(IntrinFunc, args, ".cari");
2415+
auto *Temp0 = Builder.CreateMul(Src0.Lo, Src1.Hi);
2416+
auto *Temp1 = Builder.CreateAdd(Cari, Temp0);
2417+
auto *Temp2 = Builder.CreateMul(Src0.Hi, Src1.Lo);
2418+
auto *ResH = Builder.CreateAdd(Temp2, Temp1);
2419+
24502420
// create the bitcast to the destination-type
2451-
auto Replace = Builder.CreateBitCast(WrH, Inst->getType(), "mul64");
2421+
auto *Replace = SplitBuilder.combineSplit(*ResL, *ResH, "mul64");
24522422
Inst->replaceAllUsesWith(Replace);
24532423
ToErase.push_back(Inst);
24542424
return true;

IGC/VectorCompiler/lib/GenXCodeGen/GenXUtil.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -760,6 +760,71 @@ unsigned ShuffleVectorAnalyzer::getSerializeCost(unsigned i) {
760760
return Cost;
761761
}
762762

763+
LoHiSplitter::LoHiSplitter(Instruction &Inst, unsigned BaseOpIdx) : Inst(Inst) {
764+
765+
auto *Operand = Inst.getOperand(BaseOpIdx);
766+
ETy = Operand->getType();
767+
Len = 1;
768+
if (ETy->isVectorTy()) {
769+
Len = ETy->getVectorNumElements();
770+
ETy = ETy->getVectorElementType();
771+
}
772+
VI32Ty = VectorType::get(ETy->getInt32Ty(Inst.getContext()), Len * 2);
773+
}
774+
775+
Region LoHiSplitter::createSplitRegion(Type *Ty, LoHiSplitter::RegionType RT) {
776+
Region R(Ty);
777+
R.Width = Len;
778+
R.NumElements = Len;
779+
R.VStride = 0;
780+
// take every second element;
781+
R.Stride = 2;
782+
// offset is encoded in bytes
783+
R.Offset = (RT == RegionType::LoRegion) ? 0 : 4;
784+
return R;
785+
}
786+
787+
LoHiSplitter::Split LoHiSplitter::splitOperand(unsigned SourceIdx) {
788+
789+
const auto &DL = Inst.getDebugLoc();
790+
auto Name = Inst.getName();
791+
792+
assert(Inst.getNumOperands() > SourceIdx);
793+
auto *Src = Inst.getOperand(SourceIdx);
794+
assert(Src->getType()->getScalarType()->isIntegerTy(64));
795+
796+
auto *ShreddedSrc = new BitCastInst(Src, VI32Ty, Name + ".iv32cast", &Inst);
797+
ShreddedSrc->setDebugLoc(DL);
798+
799+
auto LoRegion = createSplitRegion(VI32Ty, RegionType::LoRegion);
800+
auto *L = LoRegion.createRdRegion(ShreddedSrc, Name + ".lsplit", &Inst, DL);
801+
802+
auto HiRegion = createSplitRegion(VI32Ty, RegionType::HiRegion);
803+
auto *H = HiRegion.createRdRegion(ShreddedSrc, Name + ".rsplit", &Inst, DL);
804+
805+
return {L, H};
806+
}
807+
Value *LoHiSplitter::combineSplit(Value &L, Value &H, const Twine &Name) {
808+
809+
const auto &DL = Inst.getDebugLoc();
810+
811+
assert(L.getType() == H.getType() && L.getType()->isVectorTy() &&
812+
L.getType()->getVectorElementType()->isIntegerTy(32));
813+
814+
// create the write-regions
815+
auto LoRegion = createSplitRegion(VI32Ty, RegionType::LoRegion);
816+
auto *UndefV = UndefValue::get(VI32Ty);
817+
auto *WrL = LoRegion.createWrRegion(UndefV, &L, "WrLow", &Inst, DL);
818+
819+
auto HiRegion = createSplitRegion(VI32Ty, RegionType::HiRegion);
820+
auto *WrH = HiRegion.createWrRegion(WrL, &H, "WrHigh", &Inst, DL);
821+
822+
auto *V64Ty = VectorType::get(ETy->getInt64Ty(Inst.getContext()), Len);
823+
auto *Result = new BitCastInst(WrH, V64Ty, Name, &Inst);
824+
Result->setDebugLoc(DL);
825+
return Result;
826+
}
827+
763828
/***********************************************************************
764829
* adjustPhiNodesForBlockRemoval : adjust phi nodes when removing a block
765830
*

IGC/VectorCompiler/lib/GenXCodeGen/GenXUtil.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,40 @@ class ShuffleVectorAnalyzer {
257257
OperandRegionInfo getMaskRegionPrefix(int StartIdx);
258258
};
259259

260+
// class for splitting i64 (both vector and scalar) to subregions of i32 vectors
261+
// Used in GenxLowering and emulation routines
262+
class LoHiSplitter {
263+
Instruction &Inst;
264+
265+
Type *ETy = nullptr;
266+
Type *VI32Ty = nullptr;
267+
size_t Len = 0;
268+
269+
enum class RegionType { LoRegion, HiRegion };
270+
Region createSplitRegion(Type *Ty, RegionType RT);
271+
272+
public:
273+
struct Split {
274+
Value *Lo;
275+
Value *Hi;
276+
};
277+
278+
// Instruction is used as an insertion point, debug location source and
279+
// as a source of operands to split.
280+
// If BaseOpIdx indexes a scalar/vector operand of i64 type, then
281+
// IsI64Operation shall return true
282+
LoHiSplitter(Instruction &Inst, unsigned BaseOpIdx = 0);
283+
284+
// Splitted Operand is expected to be a scalar/vector of i64 type
285+
Split splitOperand(unsigned SourceIdx);
286+
287+
// Combined values are expected to be a vector of i32 of the same size
288+
Value *combineSplit(Value &L, Value &H, const Twine &Name);
289+
290+
// convinence method for quick sanity checking
291+
bool IsI64Operation() { return ETy->isIntegerTy(64); }
292+
};
293+
260294
// adjustPhiNodesForBlockRemoval : adjust phi nodes when removing a block
261295
void adjustPhiNodesForBlockRemoval(BasicBlock *Succ, BasicBlock *BB);
262296

0 commit comments

Comments
 (0)