Skip to content

Commit 603d683

Browse files
ViacheslavRbigcbot
authored andcommitted
Illegal type Legalizer improvement
Improve support of int types with width greater 64 bit in PeepholeTypeLegalizer.
1 parent 023bbde commit 603d683

File tree

3 files changed

+286
-42
lines changed

3 files changed

+286
-42
lines changed

IGC/Compiler/Legalizer/PeepholeTypeLegalizer.cpp

Lines changed: 185 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -68,20 +68,23 @@ bool PeepholeTypeLegalizer::runOnFunction(Function& F) {
6868
ctx->platform.WaDisableD64ScratchMessage()) &&
6969
ctx->getModuleMetaData()->compOpt.UseScratchSpacePrivateMemory;
7070

71+
NonBitcastInstructionsLegalized = false;
72+
CastInst_ZExtWithIntermediateIllegalsEliminated = false;
73+
CastInst_TruncWithIntermediateIllegalsEliminated = false;
74+
Bitcast_BitcastWithIntermediateIllegalsEliminated = false;
75+
7176
IGCLLVM::IRBuilder<> builder(F.getContext());
7277
m_builder = &builder;
7378

7479
Changed = false;
7580
visit(F);
76-
if (Changed) {
77-
NonBitcastInstructionsLegalized = true;
78-
visit(F);
79-
CastInst_ZExtWithIntermediateIllegalsEliminated = true;
80-
visit(F);
81-
CastInst_TruncWithIntermediateIllegalsEliminated = true;
82-
visit(F);
83-
Bitcast_BitcastWithIntermediateIllegalsEliminated = true;
84-
}
81+
NonBitcastInstructionsLegalized = true;
82+
visit(F);
83+
CastInst_ZExtWithIntermediateIllegalsEliminated = true;
84+
visit(F);
85+
CastInst_TruncWithIntermediateIllegalsEliminated = true;
86+
visit(F);
87+
Bitcast_BitcastWithIntermediateIllegalsEliminated = true;
8588
return Changed;
8689
}
8790

@@ -110,37 +113,39 @@ void PeepholeTypeLegalizer::visitInstruction(Instruction& I) {
110113
return;
111114

112115
if (!I.getOperand(0)->getType()->isIntOrIntVectorTy() &&
113-
!dyn_cast<ExtractElementInst>(&I))
116+
!isa<ExtractElementInst>(&I))
114117
return; // Legalization for int types only or for extractelements
115118

116119
m_builder->SetInsertPoint(&I);
117120

118121
//Depending on the phase of legalization pass, call appropriate function
119122
if (!NonBitcastInstructionsLegalized) { // LEGALIZE ALUs first
120-
if (dyn_cast<PHINode>(&I)) {
123+
if (isa<PHINode>(&I)) {
121124
legalizePhiInstruction(I); // phi nodes and all incoming values
122125
}
123-
else if (dyn_cast<UnaryInstruction>(&I)) {
126+
else if (isa<UnaryInstruction>(&I)) {
124127
legalizeUnaryInstruction(I); // pointercast &/or load
125128
}
126-
else if (dyn_cast<ICmpInst>(&I) || dyn_cast<BinaryOperator>(&I) || dyn_cast<SelectInst>(&I)) {
129+
else if (isa<ICmpInst>(&I) || isa<BinaryOperator>(&I) || isa<SelectInst>(&I)) {
127130
legalizeBinaryOperator(I); // Bitwise and Arithmetic Operations
128131
}
129-
else if (dyn_cast<ExtractElementInst>(&I)) {
132+
else if (isa<ExtractElementInst>(&I)) {
130133
legalizeExtractElement(I);
131134
}
132135
}
133136
else if (!CastInst_ZExtWithIntermediateIllegalsEliminated) { // Eliminate intermediate ILLEGAL operands in bitcast-zext or trunc-zext pairs
134-
if (dyn_cast<ZExtInst>(&I))
137+
if (isa<ZExtInst>(&I))
135138
cleanupZExtInst(I);
136139
}
137140
else if (!CastInst_TruncWithIntermediateIllegalsEliminated) { // Eliminate intermediate ILLEGAL operands in bitcast-zext or trunc-zext pairs
138-
if (dyn_cast<TruncInst>(&I))
141+
if (isa<TruncInst>(&I))
139142
cleanupTruncInst(I);
140143
}
141144
else if (!Bitcast_BitcastWithIntermediateIllegalsEliminated) { // Eliminate redundant bitcast-bitcast pairs and eliminate intermediate ILLEGAL operands in bitcast-bitcast pairs with src == dest OR src != dest
142-
if (dyn_cast<BitCastInst>(&I))
145+
if (isa<BitCastInst>(&I))
143146
cleanupBitCastInst(I);
147+
if (isa<TruncInst>(&I))
148+
cleanupBitCastTruncInst(I);
144149
}
145150
}
146151

@@ -779,6 +784,7 @@ void PeepholeTypeLegalizer::legalizeUnaryInstruction(Instruction& I) {
779784
// %4 = extractelement %1, 1
780785
// %5 = insertelement %3, %4, 1
781786
// %6 = bitcast <2 x i64> %5 to i128
787+
782788
unsigned dstSize = I.getType()->getScalarSizeInBits();
783789
unsigned srcSize = I.getOperand(0)->getType()->getScalarSizeInBits();
784790

@@ -801,18 +807,20 @@ void PeepholeTypeLegalizer::legalizeUnaryInstruction(Instruction& I) {
801807
return;
802808
}
803809

804-
unsigned numSrcElements = srcSize / promotedInt;
805-
unsigned numDstElements = dstSize / promotedInt;
810+
unsigned numSrcElements = static_cast<unsigned>(I.getOperand(0)->getType()->getPrimitiveSizeInBits() / promotedInt);
811+
unsigned numDstElements = static_cast<unsigned>(I.getType()->getPrimitiveSizeInBits() / promotedInt);
806812
Type* srcVecTy = IGCLLVM::FixedVectorType::get(Type::getIntNTy(I.getContext(), promotedInt), numSrcElements);
807813
Type* dstVecTy = IGCLLVM::FixedVectorType::get(Type::getIntNTy(I.getContext(), promotedInt), numDstElements);
808814

809815
// Bitcast the illegal src type to a legal vector
810816
Value* srcVec = m_builder->CreateBitCast(I.getOperand(0), srcVecTy);
811817
Value* dstVec = UndefValue::get(dstVecTy);
818+
unsigned numElements = I.getType()->isVectorTy() ? (unsigned)cast<IGCLLVM::FixedVectorType>(I.getType())->getNumElements() : 1;
812819

813820
for (unsigned i = 0; i < numDstElements; i++)
814821
{
815-
Value* v = m_builder->CreateExtractElement(srcVec, m_builder->getInt32(i));
822+
Value* v = m_builder->CreateExtractElement(srcVec, m_builder->getInt32((i / (numDstElements / numElements)) *
823+
(numSrcElements / numElements) + (i % (numDstElements / numElements))));
816824
dstVec = m_builder->CreateInsertElement(dstVec, v, m_builder->getInt32(i));
817825
}
818826
// Cast back to original dst type
@@ -907,6 +915,7 @@ void PeepholeTypeLegalizer::cleanupZExtInst(Instruction& I) {
907915
}
908916
else {
909917
// this is a place holder, but DO NOT expect to need an implementation for this case.
918+
IGC_ASSERT_MESSAGE(0, "Not yet implemented");
910919
}
911920
}
912921
else {
@@ -942,6 +951,7 @@ void PeepholeTypeLegalizer::cleanupZExtInst(Instruction& I) {
942951
}
943952
else { // (promoteToInt*quotient != Src1width) case
944953
// No support yet
954+
IGC_ASSERT_MESSAGE(0, "Not yet implemented");
945955
}
946956
}
947957
break;
@@ -1135,23 +1145,26 @@ void PeepholeTypeLegalizer::cleanupTruncInst(Instruction& I) {
11351145
}
11361146
else
11371147
{
1138-
for (Value::use_iterator UI = I.use_begin(), UE = I.use_end(); UI != UE; ++UI) {
1139-
if (TruncInst* useTrunc = dyn_cast<TruncInst>(UI->getUser()))
1140-
{
1141-
IGC_ASSERT(I.getType()->getScalarSizeInBits() > useTrunc->getType()->getScalarSizeInBits());
1142-
auto newTrunc = dyn_cast<TruncInst>(m_builder->CreateTrunc(I.getOperand(0), useTrunc->getType()));
1143-
useTrunc->replaceAllUsesWith(newTrunc);
1144-
// Commented out for now because it breaks use_iterator
1145-
// useTrunc->eraseFromParent();
1146-
cleanupTruncInst(*newTrunc);
1147-
Changed = true;
1148-
}
1148+
if (TruncInst* prevTrunc = dyn_cast<TruncInst>(I.getOperand(0)))
1149+
{
1150+
//Example:
1151+
//%1 = trunc i96 %in to i65
1152+
//%out = trunc i65 %1 to i64
1153+
//=>
1154+
//%out = trunc i96 %in to i64
1155+
IGC_ASSERT(prevTrunc->getType()->getScalarSizeInBits() > I.getType()->getScalarSizeInBits());
1156+
auto* newTrunc = cast<TruncInst>(m_builder->CreateTrunc(prevTrunc->getOperand(0), I.getType()));
1157+
I.replaceAllUsesWith(newTrunc);
1158+
Changed = true;
11491159
}
11501160
}
11511161

11521162
if (I.use_empty())
11531163
{
1164+
Instruction* prevInst = dyn_cast<Instruction>(I.getOperand(0));
11541165
I.eraseFromParent();
1166+
if (prevInst && prevInst->use_empty())
1167+
prevInst->eraseFromParent();
11551168
Changed = true;
11561169
}
11571170

@@ -1160,23 +1173,22 @@ void PeepholeTypeLegalizer::cleanupTruncInst(Instruction& I) {
11601173

11611174
void PeepholeTypeLegalizer::cleanupBitCastInst(Instruction& I) {
11621175

1163-
/*
1164-
Need to handle:
1165-
1. bitcast
1166-
2. bitcast addrspace*
1167-
1168-
a. bitcast iSrc , iILLEGAL
1169-
bitcast iILLEGAL, iSrc
1170-
b. bitcast iSrc, iILLEGAL
1171-
bitcast iILLEGAL, iLEGAL
1172-
*/
1173-
11741176
Instruction* prevInst = dyn_cast<Instruction>(I.getOperand(0));
11751177
if (!prevInst)
11761178
return;
11771179
switch (prevInst->getOpcode()) {
11781180
case Instruction::BitCast:
11791181
{
1182+
/*
1183+
Need to handle:
1184+
1. bitcast
1185+
2. bitcast addrspace*
1186+
1187+
a. bitcast iSrc , iILLEGAL
1188+
bitcast iILLEGAL, iSrc
1189+
b. bitcast iSrc, iILLEGAL
1190+
bitcast iILLEGAL, iLEGAL
1191+
*/
11801192
Type* srcType = prevInst->getOperand(0)->getType();
11811193
Type* dstType = I.getType();
11821194
if (srcType == dstType)
@@ -1229,7 +1241,138 @@ void PeepholeTypeLegalizer::cleanupBitCastInst(Instruction& I) {
12291241
}
12301242
break;
12311243
}
1244+
case Instruction::ZExt:
1245+
{
1246+
Type* srcType = prevInst->getOperand(0)->getType();
1247+
Type* midType = prevInst->getType();
1248+
Type* dstType = I.getType();
1249+
if (isLegalInteger(srcType->getScalarSizeInBits())
1250+
&& !isLegalInteger(midType->getScalarSizeInBits())
1251+
&& isLegalInteger(dstType->getScalarSizeInBits()))
1252+
{
1253+
m_builder->SetInsertPoint(&I);
1254+
1255+
IGC_ASSERT_MESSAGE(midType->getScalarSizeInBits() % 8 == 0, "Unexpected type");
1256+
int interimTypeBitWidth = DL->getLargestLegalIntTypeSizeInBits();
1257+
for (; interimTypeBitWidth >= 8; interimTypeBitWidth /= 2)
1258+
{
1259+
if (srcType->getScalarSizeInBits() % interimTypeBitWidth == 0
1260+
&& midType->getScalarSizeInBits() % interimTypeBitWidth == 0)
1261+
break;
1262+
}
1263+
Value* newInVecValue = prevInst->getOperand(0);
1264+
if (srcType->getScalarSizeInBits() != interimTypeBitWidth)
1265+
{
1266+
Type* newInVecType = IGCLLVM::FixedVectorType::get(Type::getIntNTy(I.getContext(), interimTypeBitWidth),
1267+
static_cast<unsigned>(srcType->getPrimitiveSizeInBits() / interimTypeBitWidth));
1268+
newInVecValue = m_builder->CreateBitCast(newInVecValue, newInVecType);
1269+
}
1270+
Value* newExtVec = UndefValue::get(IGCLLVM::FixedVectorType::get(Type::getIntNTy(I.getContext(),
1271+
interimTypeBitWidth), static_cast<unsigned>(dstType->getPrimitiveSizeInBits() / interimTypeBitWidth)));
1272+
unsigned numElements = srcType->isVectorTy() ? (unsigned)cast<IGCLLVM::FixedVectorType>(srcType)->getNumElements() : 1;
1273+
unsigned newInQuotient = srcType->getScalarSizeInBits() / interimTypeBitWidth;
1274+
unsigned extQuotient = static_cast<unsigned>(dstType->getPrimitiveSizeInBits() / numElements / interimTypeBitWidth);
1275+
auto zero = ConstantInt::get(IntegerType::get(I.getContext(), interimTypeBitWidth), 0, false);
1276+
for (unsigned i = 0; i < numElements; i++) {
1277+
for (unsigned k = 0; k < newInQuotient; k++) {
1278+
Value* extractedVal = m_builder->CreateExtractElement(newInVecValue, m_builder->getInt32(newInQuotient * i + k));
1279+
newExtVec = m_builder->CreateInsertElement(newExtVec, extractedVal, m_builder->getInt32(extQuotient * i + k));
1280+
}
1281+
for (unsigned k = newInQuotient; k < extQuotient; k++) {
1282+
newExtVec = m_builder->CreateInsertElement(newExtVec, zero, m_builder->getInt32(extQuotient * i + k));
1283+
}
1284+
}
1285+
if (dstType->getScalarSizeInBits() != newExtVec->getType()->getScalarSizeInBits())
1286+
{
1287+
newExtVec = m_builder->CreateBitCast(newExtVec, dstType);
1288+
}
1289+
I.replaceAllUsesWith(newExtVec);
1290+
I.eraseFromParent();
1291+
if (prevInst->use_empty())
1292+
{
1293+
prevInst->eraseFromParent();
1294+
}
1295+
Changed = true;
1296+
}
1297+
else
1298+
{
1299+
IGC_ASSERT_MESSAGE(isLegalInteger(srcType->getScalarSizeInBits())
1300+
&& isLegalInteger(midType->getScalarSizeInBits()),
1301+
"Unexpected illegal type width");
1302+
}
1303+
break;
1304+
}
12321305
default:
1306+
IGC_ASSERT_MESSAGE(isLegalInteger(I.getOperand(0)->getType()->getScalarSizeInBits()),
1307+
"Unexpected illegal type width");
12331308
break;
12341309
}
12351310
}
1311+
1312+
void PeepholeTypeLegalizer::cleanupBitCastTruncInst(Instruction& I) {
1313+
1314+
BitCastInst* bitCastInst = dyn_cast<BitCastInst>(I.getOperand(0));
1315+
if (!bitCastInst)
1316+
return;
1317+
1318+
if (isLegalInteger(bitCastInst->getOperand(0)->getType()->getScalarSizeInBits())
1319+
&& !isLegalInteger(bitCastInst->getType()->getScalarSizeInBits()) &&
1320+
isLegalInteger(I.getType()->getScalarSizeInBits()))
1321+
{
1322+
/*
1323+
Example:
1324+
%2 = bitcast <3 x i32> %0 to <2 x i48>
1325+
%3 = trunc <2 x i48> %2 to <2 x i16>
1326+
=>
1327+
%2 = bitcast <3 x i32> %0 to <6 x i16>
1328+
%3 = extractelement <6 x i16> %2, i32 0
1329+
%4 = insertelement <2 x i16> undef, i16 %3, i32 0
1330+
%5 = extractelement <6 x i16> %2, i32 3
1331+
%6 = insertelement <2 x i16> %4, i16 %5, i32 1
1332+
*/
1333+
1334+
m_builder->SetInsertPoint(&I);
1335+
1336+
Type* srcType = bitCastInst->getOperand(0)->getType();
1337+
Type* midType = bitCastInst->getType();
1338+
Type* dstType = I.getType();
1339+
1340+
IGC_ASSERT_MESSAGE(midType->getScalarSizeInBits() % 8 == 0, "Unexpected type");
1341+
int interimTypeBitWidth = DL->getLargestLegalIntTypeSizeInBits();
1342+
for (; interimTypeBitWidth >= 8; interimTypeBitWidth /= 2)
1343+
{
1344+
if (midType->getScalarSizeInBits() % interimTypeBitWidth == 0
1345+
&& dstType->getScalarSizeInBits() % interimTypeBitWidth == 0)
1346+
break;
1347+
}
1348+
Value* newInVecValue = bitCastInst->getOperand(0);
1349+
if (srcType->getScalarSizeInBits() != interimTypeBitWidth)
1350+
{
1351+
Type* newInVecType = IGCLLVM::FixedVectorType::get(Type::getIntNTy(I.getContext(),
1352+
interimTypeBitWidth), static_cast<unsigned>(midType->getPrimitiveSizeInBits() / interimTypeBitWidth));
1353+
newInVecValue = m_builder->CreateBitCast(newInVecValue, newInVecType);
1354+
}
1355+
Value* newTruncVec = UndefValue::get(IGCLLVM::FixedVectorType::get(Type::getIntNTy(I.getContext(),
1356+
interimTypeBitWidth), static_cast<unsigned>(dstType->getPrimitiveSizeInBits() / interimTypeBitWidth)));
1357+
unsigned numElements = dstType->isVectorTy() ? (unsigned)cast<IGCLLVM::FixedVectorType>(dstType)->getNumElements() : 1;
1358+
unsigned newInQuotient = midType->getScalarSizeInBits() / interimTypeBitWidth;
1359+
unsigned truncQuotient = static_cast<unsigned>(dstType->getPrimitiveSizeInBits() / numElements / interimTypeBitWidth);
1360+
for (unsigned i = 0; i < numElements; i++) {
1361+
for (unsigned k = 0; k < truncQuotient; k++) {
1362+
Value* extractedVal = m_builder->CreateExtractElement(newInVecValue, m_builder->getInt32(newInQuotient * i + k));
1363+
newTruncVec = m_builder->CreateInsertElement(newTruncVec, extractedVal, m_builder->getInt32(truncQuotient * i + k));
1364+
}
1365+
}
1366+
if (dstType->getScalarSizeInBits() != newTruncVec->getType()->getScalarSizeInBits())
1367+
{
1368+
newTruncVec = m_builder->CreateBitCast(newTruncVec, dstType);
1369+
}
1370+
I.replaceAllUsesWith(newTruncVec);
1371+
I.eraseFromParent();
1372+
if (bitCastInst->use_empty())
1373+
{
1374+
bitCastInst->eraseFromParent();
1375+
}
1376+
Changed = true;
1377+
}
1378+
}

IGC/Compiler/Legalizer/PeepholeTypeLegalizer.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ namespace IGC {
5252
void cleanupZExtInst(Instruction& I);
5353
void cleanupTruncInst(Instruction& I);
5454
void cleanupBitCastInst(Instruction& I);
55+
void cleanupBitCastTruncInst(Instruction& I);
5556

5657
private:
5758
bool NonBitcastInstructionsLegalized;

0 commit comments

Comments
 (0)