Skip to content

Commit 4444c1a

Browse files
itetyush-inteligcbot
authored andcommitted
Using cloneInstWithNewOps instead of switch in GenXVectorCombiner
1 parent 1a25c55 commit 4444c1a

File tree

4 files changed

+85
-33
lines changed

4 files changed

+85
-33
lines changed

IGC/VectorCompiler/include/vc/Utils/General/Types.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ llvm::Type *changeAddrSpace(llvm::Type *OrigTy, int AddrSpace);
3232
// Get addrspace of a pointer or a vector of pointers type.
3333
int getAddrSpace(llvm::Type *PtrOrPtrVec);
3434

35+
// calculates new return type for cast instructions
36+
// * trunc
37+
// * bitcast
38+
llvm::Type *getNewTypeForCast(llvm::Type *OldOutType, llvm::Type *OldInType,
39+
llvm::Type *NewInType);
40+
3541
// If \p Ty is degenerate vector type <1 x ElTy>,
3642
// ElTy is returned, otherwise original type \p Ty is returned.
3743
const llvm::Type &fixDegenerateVectorType(const llvm::Type &Ty);

IGC/VectorCompiler/lib/GenXCodeGen/GenXVectorCombiner.cpp

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ SPDX-License-Identifier: MIT
2727
#include "GenXIntrinsics.h"
2828
#include "GenXModule.h"
2929
#include "vc/GenXOpts/Utils/CMRegion.h"
30+
#include "vc/Utils/General/InstRebuilder.h"
3031

3132
#include "IGC/common/debug/DebugMacros.hpp"
3233

@@ -321,9 +322,9 @@ void GenXVectorCombiner::filterWorkList() {
321322
void GenXVectorCombiner::createNewInstruction(
322323
Instruction *InsteadOf, Instruction *Operation,
323324
const SmallVectorImpl<Value *> &Vals) {
324-
IGC_ASSERT_MESSAGE(InsteadOf && Operation, "Error: nullptr");
325-
IRBuilder<> Builder{InsteadOf};
325+
IGC_ASSERT_MESSAGE(InsteadOf && Operation, "Error: nullptr input");
326326
if (GenXIntrinsic::isGenXIntrinsic(Operation)) {
327+
IRBuilder<> Builder{InsteadOf};
327328
Function *Fn = nullptr;
328329
GenXIntrinsic::ID IdCode = GenXIntrinsic::getGenXIntrinsicID(Operation);
329330
Module *M = Operation->getParent()->getParent()->getParent();
@@ -341,26 +342,10 @@ void GenXVectorCombiner::createNewInstruction(
341342
InsteadOf->replaceAllUsesWith(CI);
342343
return;
343344
}
344-
unsigned OpCode = Operation->getOpcode();
345-
Value *NewInst = nullptr;
346-
switch (OpCode) {
347-
default:
348-
IGC_ASSERT_MESSAGE(false, "get unknown opcode");
349-
case Instruction::Add:
350-
NewInst = Builder.CreateAdd(Vals[0], Vals[1], VALUE_NAME("widenedAdd"));
351-
break;
352-
case Instruction::FAdd:
353-
NewInst = Builder.CreateFAdd(Vals[0], Vals[1], VALUE_NAME("widenedFAdd"));
354-
break;
355-
case Instruction::BitCast:
356-
NewInst = Builder.CreateBitCast(Vals[0], InsteadOf->getType(),
357-
VALUE_NAME("widenedBitcast"));
358-
break;
359-
case Instruction::Trunc:
360-
NewInst = Builder.CreateTrunc(Vals[0], InsteadOf->getType(),
361-
VALUE_NAME("widenedTrunc"));
362-
break;
363-
}
345+
Instruction *NewInst = vc::cloneInstWithNewOps(*Operation, Vals);
346+
NewInst->insertBefore(InsteadOf);
347+
NewInst->setDebugLoc(InsteadOf->getDebugLoc());
348+
NewInst->takeName(Operation);
364349
InsteadOf->replaceAllUsesWith(NewInst);
365350
}
366351

IGC/VectorCompiler/lib/Utils/General/InstRebuilder.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ SPDX-License-Identifier: MIT
1414

1515
#include <llvm/ADT/ArrayRef.h>
1616
#include <llvm/IR/InstVisitor.h>
17+
#include <llvm/IR/Instruction.h>
1718
#include <llvm/IR/IntrinsicInst.h>
1819

1920
#include <algorithm>
@@ -84,6 +85,13 @@ class cloneInstWithNewOpsImpl
8485
NewOperands.drop_front());
8586
}
8687

88+
Instruction *visitTrunc(TruncInst &Trunc) {
89+
Value &NewOp = getSingleNewOperand();
90+
Type *NewOutType = vc::getNewTypeForCast(
91+
Trunc.getType(), Trunc.getOperand(0)->getType(), NewOp.getType());
92+
return new TruncInst{&NewOp, NewOutType};
93+
}
94+
8795
Instruction *visitLoadInst(LoadInst &OrigLoad) {
8896
Value &Ptr = getSingleNewOperand();
8997
auto *NewLoad =
@@ -108,17 +116,9 @@ class cloneInstWithNewOpsImpl
108116
// type addrspace corresponds with this operand.
109117
CastInst *visitBitCastInst(BitCastInst &OrigCast) {
110118
Value &NewOp = getSingleNewOperand();
111-
// If the operand changed addrspace the bitcast type should change it too.
112-
if (OrigCast.getType()->isPtrOrPtrVectorTy())
113-
return visitPtrOrPtrVectorBitCastInst(OrigCast);
114-
return new BitCastInst{&NewOp, OrigCast.getType()};
115-
}
116-
117-
CastInst *visitPtrOrPtrVectorBitCastInst(BitCastInst &OrigCast) {
118-
Value &NewOp = getSingleNewOperand();
119-
auto NewOpAS = getAddrSpace(NewOp.getType());
120-
return new BitCastInst{&NewOp,
121-
changeAddrSpace(OrigCast.getType(), NewOpAS)};
119+
Type *NewOutType = vc::getNewTypeForCast(
120+
OrigCast.getType(), OrigCast.getOperand(0)->getType(), NewOp.getType());
121+
return new BitCastInst{&NewOp, NewOutType};
122122
}
123123

124124
CastInst *visitAddrSpaceCastInst(AddrSpaceCastInst &OrigCast) {

IGC/VectorCompiler/lib/Utils/General/Types.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ SPDX-License-Identifier: MIT
77
============================= end_copyright_notice ===========================*/
88

99
#include "vc/Utils/General/Types.h"
10+
#include "llvmWrapper/Support/TypeSize.h"
1011

1112
#include "Probe/Assertion.h"
1213

@@ -57,3 +58,63 @@ Type &vc::fixDegenerateVectorType(Type &Ty) {
5758
return const_cast<Type &>(
5859
fixDegenerateVectorType(static_cast<const Type &>(Ty)));
5960
}
61+
62+
// calculates new return type for cast instructions
63+
// * trunc
64+
// * bitcast
65+
// Expect that scalar type of instruction not changed and previous
66+
// combination of OldOutType & OldInType is valid
67+
Type *vc::getNewTypeForCast(Type *OldOutType, Type *OldInType,
68+
Type *NewInType) {
69+
IGC_ASSERT_MESSAGE(OldOutType && NewInType && OldInType,
70+
"Error: nullptr input");
71+
72+
bool NewInIsVec = isa<IGCLLVM::FixedVectorType>(NewInType);
73+
bool OldOutIsVec = isa<IGCLLVM::FixedVectorType>(OldOutType);
74+
bool OldInIsVec = isa<IGCLLVM::FixedVectorType>(OldInType);
75+
76+
bool NewInIsPtrOrVecPtr = NewInType->isPtrOrPtrVectorTy();
77+
bool OldOutIsPtrOrVecPtr = OldOutType->isPtrOrPtrVectorTy();
78+
bool OldInIsPtrOrVecPtr = OldInType->isPtrOrPtrVectorTy();
79+
80+
// only pointer to pointer
81+
IGC_ASSERT(NewInIsPtrOrVecPtr == OldOutIsPtrOrVecPtr &&
82+
NewInIsPtrOrVecPtr == OldInIsPtrOrVecPtr);
83+
// <2 x char> -> int : < 4 x char> -> ? forbidden
84+
IGC_ASSERT(OldOutIsVec == OldInIsVec && OldOutIsVec == NewInIsVec);
85+
Type *NewOutType = OldOutType;
86+
if (OldOutIsVec) {
87+
// <4 x char> -> <2 x int> : <8 x char> -> <4 x int>
88+
// <4 x char> -> <2 x int> : <2 x char> -> <1 x int>
89+
auto NewInEC = cast<IGCLLVM::FixedVectorType>(NewInType)->getNumElements();
90+
auto OldOutEC =
91+
cast<IGCLLVM::FixedVectorType>(OldOutType)->getNumElements();
92+
auto OldInEC = cast<IGCLLVM::FixedVectorType>(OldInType)->getNumElements();
93+
auto NewOutEC = OldOutEC * NewInEC / OldInEC;
94+
// <4 x char> -> <2 x int> : <5 x char> -> ? forbidden
95+
IGC_ASSERT_MESSAGE((OldOutEC * NewInEC) % OldInEC == 0,
96+
"Error: wrong combination of input/output");
97+
// element count changed, scalar type as previous
98+
NewOutType = IGCLLVM::FixedVectorType::get(
99+
OldOutType->getVectorElementType(), IGCLLVM::getElementCount(NewOutEC));
100+
}
101+
102+
IGC_ASSERT(NewOutType);
103+
104+
if (NewInIsPtrOrVecPtr) {
105+
// <4 x char*> -> <2 x half*> : < 2 x int*> - ? forbidden
106+
// char* -> half* : int* -> ? forbidden
107+
IGC_ASSERT_MESSAGE(OldInType->getScalarType()->getPointerElementType() ==
108+
NewInType->getScalarType()->getPointerElementType(),
109+
"Error: unexpected type change");
110+
// address space from new
111+
// element count calculated as for vector
112+
// element type expect address space similar
113+
auto AddressSpace = getAddrSpace(NewInType);
114+
return changeAddrSpace(NewOutType, AddressSpace);
115+
}
116+
// <4 x char> -> <2 x half> : < 2 x int> - ? forbiddeb
117+
IGC_ASSERT_MESSAGE(OldInType->getScalarType() == NewInType->getScalarType(),
118+
"Error: unexpected type change");
119+
return NewOutType;
120+
}

0 commit comments

Comments
 (0)