@@ -68,20 +68,23 @@ bool PeepholeTypeLegalizer::runOnFunction(Function& F) {
68
68
ctx->platform .WaDisableD64ScratchMessage ()) &&
69
69
ctx->getModuleMetaData ()->compOpt .UseScratchSpacePrivateMemory ;
70
70
71
+ NonBitcastInstructionsLegalized = false ;
72
+ CastInst_ZExtWithIntermediateIllegalsEliminated = false ;
73
+ CastInst_TruncWithIntermediateIllegalsEliminated = false ;
74
+ Bitcast_BitcastWithIntermediateIllegalsEliminated = false ;
75
+
71
76
IGCLLVM::IRBuilder<> builder (F.getContext ());
72
77
m_builder = &builder;
73
78
74
79
Changed = false ;
75
80
visit (F);
76
- if (Changed) {
77
- NonBitcastInstructionsLegalized = true ;
78
- visit (F);
79
- CastInst_ZExtWithIntermediateIllegalsEliminated = true ;
80
- visit (F);
81
- CastInst_TruncWithIntermediateIllegalsEliminated = true ;
82
- visit (F);
83
- Bitcast_BitcastWithIntermediateIllegalsEliminated = true ;
84
- }
81
+ NonBitcastInstructionsLegalized = true ;
82
+ visit (F);
83
+ CastInst_ZExtWithIntermediateIllegalsEliminated = true ;
84
+ visit (F);
85
+ CastInst_TruncWithIntermediateIllegalsEliminated = true ;
86
+ visit (F);
87
+ Bitcast_BitcastWithIntermediateIllegalsEliminated = true ;
85
88
return Changed;
86
89
}
87
90
@@ -110,37 +113,39 @@ void PeepholeTypeLegalizer::visitInstruction(Instruction& I) {
110
113
return ;
111
114
112
115
if (!I.getOperand (0 )->getType ()->isIntOrIntVectorTy () &&
113
- !dyn_cast <ExtractElementInst>(&I))
116
+ !isa <ExtractElementInst>(&I))
114
117
return ; // Legalization for int types only or for extractelements
115
118
116
119
m_builder->SetInsertPoint (&I);
117
120
118
121
// Depending on the phase of legalization pass, call appropriate function
119
122
if (!NonBitcastInstructionsLegalized) { // LEGALIZE ALUs first
120
- if (dyn_cast <PHINode>(&I)) {
123
+ if (isa <PHINode>(&I)) {
121
124
legalizePhiInstruction (I); // phi nodes and all incoming values
122
125
}
123
- else if (dyn_cast <UnaryInstruction>(&I)) {
126
+ else if (isa <UnaryInstruction>(&I)) {
124
127
legalizeUnaryInstruction (I); // pointercast &/or load
125
128
}
126
- else if (dyn_cast <ICmpInst>(&I) || dyn_cast <BinaryOperator>(&I) || dyn_cast <SelectInst>(&I)) {
129
+ else if (isa <ICmpInst>(&I) || isa <BinaryOperator>(&I) || isa <SelectInst>(&I)) {
127
130
legalizeBinaryOperator (I); // Bitwise and Arithmetic Operations
128
131
}
129
- else if (dyn_cast <ExtractElementInst>(&I)) {
132
+ else if (isa <ExtractElementInst>(&I)) {
130
133
legalizeExtractElement (I);
131
134
}
132
135
}
133
136
else if (!CastInst_ZExtWithIntermediateIllegalsEliminated) { // Eliminate intermediate ILLEGAL operands in bitcast-zext or trunc-zext pairs
134
- if (dyn_cast <ZExtInst>(&I))
137
+ if (isa <ZExtInst>(&I))
135
138
cleanupZExtInst (I);
136
139
}
137
140
else if (!CastInst_TruncWithIntermediateIllegalsEliminated) { // Eliminate intermediate ILLEGAL operands in bitcast-zext or trunc-zext pairs
138
- if (dyn_cast <TruncInst>(&I))
141
+ if (isa <TruncInst>(&I))
139
142
cleanupTruncInst (I);
140
143
}
141
144
else if (!Bitcast_BitcastWithIntermediateIllegalsEliminated) { // Eliminate redundant bitcast-bitcast pairs and eliminate intermediate ILLEGAL operands in bitcast-bitcast pairs with src == dest OR src != dest
142
- if (dyn_cast <BitCastInst>(&I))
145
+ if (isa <BitCastInst>(&I))
143
146
cleanupBitCastInst (I);
147
+ if (isa<TruncInst>(&I))
148
+ cleanupBitCastTruncInst (I);
144
149
}
145
150
}
146
151
@@ -779,6 +784,7 @@ void PeepholeTypeLegalizer::legalizeUnaryInstruction(Instruction& I) {
779
784
// %4 = extractelement %1, 1
780
785
// %5 = insertelement %3, %4, 1
781
786
// %6 = bitcast <2 x i64> %5 to i128
787
+
782
788
unsigned dstSize = I.getType ()->getScalarSizeInBits ();
783
789
unsigned srcSize = I.getOperand (0 )->getType ()->getScalarSizeInBits ();
784
790
@@ -801,18 +807,20 @@ void PeepholeTypeLegalizer::legalizeUnaryInstruction(Instruction& I) {
801
807
return ;
802
808
}
803
809
804
- unsigned numSrcElements = srcSize / promotedInt;
805
- unsigned numDstElements = dstSize / promotedInt;
810
+ unsigned numSrcElements = static_cast < unsigned >(I. getOperand ( 0 )-> getType ()-> getPrimitiveSizeInBits () / promotedInt) ;
811
+ unsigned numDstElements = static_cast < unsigned >(I. getType ()-> getPrimitiveSizeInBits () / promotedInt) ;
806
812
Type* srcVecTy = IGCLLVM::FixedVectorType::get (Type::getIntNTy (I.getContext (), promotedInt), numSrcElements);
807
813
Type* dstVecTy = IGCLLVM::FixedVectorType::get (Type::getIntNTy (I.getContext (), promotedInt), numDstElements);
808
814
809
815
// Bitcast the illegal src type to a legal vector
810
816
Value* srcVec = m_builder->CreateBitCast (I.getOperand (0 ), srcVecTy);
811
817
Value* dstVec = UndefValue::get (dstVecTy);
818
+ unsigned numElements = I.getType ()->isVectorTy () ? (unsigned )cast<IGCLLVM::FixedVectorType>(I.getType ())->getNumElements () : 1 ;
812
819
813
820
for (unsigned i = 0 ; i < numDstElements; i++)
814
821
{
815
- Value* v = m_builder->CreateExtractElement (srcVec, m_builder->getInt32 (i));
822
+ Value* v = m_builder->CreateExtractElement (srcVec, m_builder->getInt32 ((i / (numDstElements / numElements)) *
823
+ (numSrcElements / numElements) + (i % (numDstElements / numElements))));
816
824
dstVec = m_builder->CreateInsertElement (dstVec, v, m_builder->getInt32 (i));
817
825
}
818
826
// Cast back to original dst type
@@ -907,6 +915,7 @@ void PeepholeTypeLegalizer::cleanupZExtInst(Instruction& I) {
907
915
}
908
916
else {
909
917
// this is a place holder, but DO NOT expect to need an implementation for this case.
918
+ IGC_ASSERT_MESSAGE (0 , " Not yet implemented" );
910
919
}
911
920
}
912
921
else {
@@ -942,6 +951,7 @@ void PeepholeTypeLegalizer::cleanupZExtInst(Instruction& I) {
942
951
}
943
952
else { // (promoteToInt*quotient != Src1width) case
944
953
// No support yet
954
+ IGC_ASSERT_MESSAGE (0 , " Not yet implemented" );
945
955
}
946
956
}
947
957
break ;
@@ -1135,23 +1145,26 @@ void PeepholeTypeLegalizer::cleanupTruncInst(Instruction& I) {
1135
1145
}
1136
1146
else
1137
1147
{
1138
- for (Value::use_iterator UI = I. use_begin (), UE = I. use_end (); UI != UE; ++UI) {
1139
- if (TruncInst* useTrunc = dyn_cast<TruncInst>(UI-> getUser ()))
1140
- {
1141
- IGC_ASSERT (I. getType ()-> getScalarSizeInBits () > useTrunc-> getType ()-> getScalarSizeInBits ());
1142
- auto newTrunc = dyn_cast<TruncInst>(m_builder-> CreateTrunc (I. getOperand ( 0 ), useTrunc-> getType ()));
1143
- useTrunc-> replaceAllUsesWith (newTrunc);
1144
- // Commented out for now because it breaks use_iterator
1145
- // useTrunc->eraseFromParent( );
1146
- cleanupTruncInst (*newTrunc );
1147
- Changed = true ;
1148
- }
1148
+ if (TruncInst* prevTrunc = dyn_cast<TruncInst>(I. getOperand ( 0 )))
1149
+ {
1150
+ // Example:
1151
+ // %1 = trunc i96 %in to i65
1152
+ // %out = trunc i65 %1 to i64
1153
+ // =>
1154
+ // % out = trunc i96 %in to i64
1155
+ IGC_ASSERT (prevTrunc-> getType ()-> getScalarSizeInBits () > I. getType ()-> getScalarSizeInBits () );
1156
+ auto * newTrunc = cast<TruncInst>(m_builder-> CreateTrunc (prevTrunc-> getOperand ( 0 ), I. getType ()) );
1157
+ I. replaceAllUsesWith (newTrunc) ;
1158
+ Changed = true ;
1149
1159
}
1150
1160
}
1151
1161
1152
1162
if (I.use_empty ())
1153
1163
{
1164
+ Instruction* prevInst = dyn_cast<Instruction>(I.getOperand (0 ));
1154
1165
I.eraseFromParent ();
1166
+ if (prevInst && prevInst->use_empty ())
1167
+ prevInst->eraseFromParent ();
1155
1168
Changed = true ;
1156
1169
}
1157
1170
@@ -1160,23 +1173,22 @@ void PeepholeTypeLegalizer::cleanupTruncInst(Instruction& I) {
1160
1173
1161
1174
void PeepholeTypeLegalizer::cleanupBitCastInst (Instruction& I) {
1162
1175
1163
- /*
1164
- Need to handle:
1165
- 1. bitcast
1166
- 2. bitcast addrspace*
1167
-
1168
- a. bitcast iSrc , iILLEGAL
1169
- bitcast iILLEGAL, iSrc
1170
- b. bitcast iSrc, iILLEGAL
1171
- bitcast iILLEGAL, iLEGAL
1172
- */
1173
-
1174
1176
Instruction* prevInst = dyn_cast<Instruction>(I.getOperand (0 ));
1175
1177
if (!prevInst)
1176
1178
return ;
1177
1179
switch (prevInst->getOpcode ()) {
1178
1180
case Instruction::BitCast:
1179
1181
{
1182
+ /*
1183
+ Need to handle:
1184
+ 1. bitcast
1185
+ 2. bitcast addrspace*
1186
+
1187
+ a. bitcast iSrc , iILLEGAL
1188
+ bitcast iILLEGAL, iSrc
1189
+ b. bitcast iSrc, iILLEGAL
1190
+ bitcast iILLEGAL, iLEGAL
1191
+ */
1180
1192
Type* srcType = prevInst->getOperand (0 )->getType ();
1181
1193
Type* dstType = I.getType ();
1182
1194
if (srcType == dstType)
@@ -1229,7 +1241,138 @@ void PeepholeTypeLegalizer::cleanupBitCastInst(Instruction& I) {
1229
1241
}
1230
1242
break ;
1231
1243
}
1244
+ case Instruction::ZExt:
1245
+ {
1246
+ Type* srcType = prevInst->getOperand (0 )->getType ();
1247
+ Type* midType = prevInst->getType ();
1248
+ Type* dstType = I.getType ();
1249
+ if (isLegalInteger (srcType->getScalarSizeInBits ())
1250
+ && !isLegalInteger (midType->getScalarSizeInBits ())
1251
+ && isLegalInteger (dstType->getScalarSizeInBits ()))
1252
+ {
1253
+ m_builder->SetInsertPoint (&I);
1254
+
1255
+ IGC_ASSERT_MESSAGE (midType->getScalarSizeInBits () % 8 == 0 , " Unexpected type" );
1256
+ int interimTypeBitWidth = DL->getLargestLegalIntTypeSizeInBits ();
1257
+ for (; interimTypeBitWidth >= 8 ; interimTypeBitWidth /= 2 )
1258
+ {
1259
+ if (srcType->getScalarSizeInBits () % interimTypeBitWidth == 0
1260
+ && midType->getScalarSizeInBits () % interimTypeBitWidth == 0 )
1261
+ break ;
1262
+ }
1263
+ Value* newInVecValue = prevInst->getOperand (0 );
1264
+ if (srcType->getScalarSizeInBits () != interimTypeBitWidth)
1265
+ {
1266
+ Type* newInVecType = IGCLLVM::FixedVectorType::get (Type::getIntNTy (I.getContext (), interimTypeBitWidth),
1267
+ static_cast <unsigned >(srcType->getPrimitiveSizeInBits () / interimTypeBitWidth));
1268
+ newInVecValue = m_builder->CreateBitCast (newInVecValue, newInVecType);
1269
+ }
1270
+ Value* newExtVec = UndefValue::get (IGCLLVM::FixedVectorType::get (Type::getIntNTy (I.getContext (),
1271
+ interimTypeBitWidth), static_cast <unsigned >(dstType->getPrimitiveSizeInBits () / interimTypeBitWidth)));
1272
+ unsigned numElements = srcType->isVectorTy () ? (unsigned )cast<IGCLLVM::FixedVectorType>(srcType)->getNumElements () : 1 ;
1273
+ unsigned newInQuotient = srcType->getScalarSizeInBits () / interimTypeBitWidth;
1274
+ unsigned extQuotient = static_cast <unsigned >(dstType->getPrimitiveSizeInBits () / numElements / interimTypeBitWidth);
1275
+ auto zero = ConstantInt::get (IntegerType::get (I.getContext (), interimTypeBitWidth), 0 , false );
1276
+ for (unsigned i = 0 ; i < numElements; i++) {
1277
+ for (unsigned k = 0 ; k < newInQuotient; k++) {
1278
+ Value* extractedVal = m_builder->CreateExtractElement (newInVecValue, m_builder->getInt32 (newInQuotient * i + k));
1279
+ newExtVec = m_builder->CreateInsertElement (newExtVec, extractedVal, m_builder->getInt32 (extQuotient * i + k));
1280
+ }
1281
+ for (unsigned k = newInQuotient; k < extQuotient; k++) {
1282
+ newExtVec = m_builder->CreateInsertElement (newExtVec, zero, m_builder->getInt32 (extQuotient * i + k));
1283
+ }
1284
+ }
1285
+ if (dstType->getScalarSizeInBits () != newExtVec->getType ()->getScalarSizeInBits ())
1286
+ {
1287
+ newExtVec = m_builder->CreateBitCast (newExtVec, dstType);
1288
+ }
1289
+ I.replaceAllUsesWith (newExtVec);
1290
+ I.eraseFromParent ();
1291
+ if (prevInst->use_empty ())
1292
+ {
1293
+ prevInst->eraseFromParent ();
1294
+ }
1295
+ Changed = true ;
1296
+ }
1297
+ else
1298
+ {
1299
+ IGC_ASSERT_MESSAGE (isLegalInteger (srcType->getScalarSizeInBits ())
1300
+ && isLegalInteger (midType->getScalarSizeInBits ()),
1301
+ " Unexpected illegal type width" );
1302
+ }
1303
+ break ;
1304
+ }
1232
1305
default :
1306
+ IGC_ASSERT_MESSAGE (isLegalInteger (I.getOperand (0 )->getType ()->getScalarSizeInBits ()),
1307
+ " Unexpected illegal type width" );
1233
1308
break ;
1234
1309
}
1235
1310
}
1311
+
1312
+ void PeepholeTypeLegalizer::cleanupBitCastTruncInst (Instruction& I) {
1313
+
1314
+ BitCastInst* bitCastInst = dyn_cast<BitCastInst>(I.getOperand (0 ));
1315
+ if (!bitCastInst)
1316
+ return ;
1317
+
1318
+ if (isLegalInteger (bitCastInst->getOperand (0 )->getType ()->getScalarSizeInBits ())
1319
+ && !isLegalInteger (bitCastInst->getType ()->getScalarSizeInBits ()) &&
1320
+ isLegalInteger (I.getType ()->getScalarSizeInBits ()))
1321
+ {
1322
+ /*
1323
+ Example:
1324
+ %2 = bitcast <3 x i32> %0 to <2 x i48>
1325
+ %3 = trunc <2 x i48> %2 to <2 x i16>
1326
+ =>
1327
+ %2 = bitcast <3 x i32> %0 to <6 x i16>
1328
+ %3 = extractelement <6 x i16> %2, i32 0
1329
+ %4 = insertelement <2 x i16> undef, i16 %3, i32 0
1330
+ %5 = extractelement <6 x i16> %2, i32 3
1331
+ %6 = insertelement <2 x i16> %4, i16 %5, i32 1
1332
+ */
1333
+
1334
+ m_builder->SetInsertPoint (&I);
1335
+
1336
+ Type* srcType = bitCastInst->getOperand (0 )->getType ();
1337
+ Type* midType = bitCastInst->getType ();
1338
+ Type* dstType = I.getType ();
1339
+
1340
+ IGC_ASSERT_MESSAGE (midType->getScalarSizeInBits () % 8 == 0 , " Unexpected type" );
1341
+ int interimTypeBitWidth = DL->getLargestLegalIntTypeSizeInBits ();
1342
+ for (; interimTypeBitWidth >= 8 ; interimTypeBitWidth /= 2 )
1343
+ {
1344
+ if (midType->getScalarSizeInBits () % interimTypeBitWidth == 0
1345
+ && dstType->getScalarSizeInBits () % interimTypeBitWidth == 0 )
1346
+ break ;
1347
+ }
1348
+ Value* newInVecValue = bitCastInst->getOperand (0 );
1349
+ if (srcType->getScalarSizeInBits () != interimTypeBitWidth)
1350
+ {
1351
+ Type* newInVecType = IGCLLVM::FixedVectorType::get (Type::getIntNTy (I.getContext (),
1352
+ interimTypeBitWidth), static_cast <unsigned >(midType->getPrimitiveSizeInBits () / interimTypeBitWidth));
1353
+ newInVecValue = m_builder->CreateBitCast (newInVecValue, newInVecType);
1354
+ }
1355
+ Value* newTruncVec = UndefValue::get (IGCLLVM::FixedVectorType::get (Type::getIntNTy (I.getContext (),
1356
+ interimTypeBitWidth), static_cast <unsigned >(dstType->getPrimitiveSizeInBits () / interimTypeBitWidth)));
1357
+ unsigned numElements = dstType->isVectorTy () ? (unsigned )cast<IGCLLVM::FixedVectorType>(dstType)->getNumElements () : 1 ;
1358
+ unsigned newInQuotient = midType->getScalarSizeInBits () / interimTypeBitWidth;
1359
+ unsigned truncQuotient = static_cast <unsigned >(dstType->getPrimitiveSizeInBits () / numElements / interimTypeBitWidth);
1360
+ for (unsigned i = 0 ; i < numElements; i++) {
1361
+ for (unsigned k = 0 ; k < truncQuotient; k++) {
1362
+ Value* extractedVal = m_builder->CreateExtractElement (newInVecValue, m_builder->getInt32 (newInQuotient * i + k));
1363
+ newTruncVec = m_builder->CreateInsertElement (newTruncVec, extractedVal, m_builder->getInt32 (truncQuotient * i + k));
1364
+ }
1365
+ }
1366
+ if (dstType->getScalarSizeInBits () != newTruncVec->getType ()->getScalarSizeInBits ())
1367
+ {
1368
+ newTruncVec = m_builder->CreateBitCast (newTruncVec, dstType);
1369
+ }
1370
+ I.replaceAllUsesWith (newTruncVec);
1371
+ I.eraseFromParent ();
1372
+ if (bitCastInst->use_empty ())
1373
+ {
1374
+ bitCastInst->eraseFromParent ();
1375
+ }
1376
+ Changed = true ;
1377
+ }
1378
+ }
0 commit comments