@@ -1227,6 +1227,159 @@ void CustomSafeOptPass::matchDp4a(BinaryOperator &I) {
1227
1227
I.replaceAllUsesWith (Res);
1228
1228
}
1229
1229
1230
+ // Optimize mix operation if detected.
1231
+ // Mix is computed as x*(1 - a) + y*a
1232
+ // Replace it with a*(y - x) + x to save one instruction ('add' ISA, 'sub' in IR).
1233
+ // This pattern also optimizes a similar operation:
1234
+ // x*(a - 1) + y*a which can be replaced with a(x + y) - x
1235
+ void CustomSafeOptPass::matchMixOperation (BinaryOperator& I)
1236
+ {
1237
+ // Pattern Mix check step 1: find a FSub instruction with a constant value of 1.
1238
+ if (I.getOpcode () == BinaryOperator::FSub)
1239
+ {
1240
+ unsigned int fSubOpIdx = 0 ;
1241
+ while (fSubOpIdx < 2 && !llvm::isa<llvm::ConstantFP>(I.getOperand (fSubOpIdx )))
1242
+ {
1243
+ fSubOpIdx ++;
1244
+ }
1245
+ if ((fSubOpIdx == 1 ) ||
1246
+ ((fSubOpIdx == 0 ) && !llvm::isa<llvm::ConstantFP>(I.getOperand (1 ))))
1247
+ {
1248
+ llvm::ConstantFP* fSubOpConst = llvm::dyn_cast<llvm::ConstantFP>(I.getOperand (fSubOpIdx ));
1249
+ const APFloat& APF = fSubOpConst ->getValueAPF ();
1250
+ bool isInf = APF.isInfinity ();
1251
+ bool isNaN = APF.isNaN ();
1252
+ double val = 0.0 ;
1253
+ if (!isInf && !isNaN)
1254
+ {
1255
+ if (&APF.getSemantics () == &APFloat::IEEEdouble ())
1256
+ {
1257
+ val = APF.convertToDouble ();
1258
+ }
1259
+ else if (&APF.getSemantics () == &APFloat::IEEEsingle ())
1260
+ {
1261
+ val = (double )APF.convertToFloat ();
1262
+ }
1263
+ }
1264
+ if (val == 1.0 )
1265
+ {
1266
+ bool doNotOptimize = false ;
1267
+ bool matchFound = false ;
1268
+ SmallVector<std::pair<Instruction*, Instruction*>, 3 > fMulInsts ;
1269
+
1270
+ // Pattern Mix check step 2: there should be only FMul users of this FSub instruction
1271
+ for (User* U : I.users ())
1272
+ {
1273
+ matchFound = false ;
1274
+ Instruction* fMul = dyn_cast_or_null<Instruction>(U);
1275
+ if (fMul && fMul ->getOpcode () == BinaryOperator::FMul)
1276
+ {
1277
+ // Pattern Mix check step 3: there should be only one fAdd user for such an FMul instruction
1278
+ if ((cast<Value>(fMul ))->hasOneUse ())
1279
+ {
1280
+ Instruction* fAdd = dyn_cast_or_null<Instruction>(*fMul ->users ().begin ());
1281
+
1282
+ // Pattern Mix check step 4: fAdd should be a user of two FMul instructions
1283
+ if (fAdd && fAdd ->getOpcode () == BinaryOperator::FAdd)
1284
+ {
1285
+ unsigned int opIdx = 0 ;
1286
+ while (opIdx < 2 && fMul != fAdd ->getOperand (opIdx))
1287
+ {
1288
+ opIdx++;
1289
+ }
1290
+
1291
+ if (opIdx < 2 )
1292
+ {
1293
+ opIdx = 1 - opIdx; // 0 -> 1 or 1 -> 0
1294
+ Instruction* fMul2nd = dyn_cast_or_null<Instruction>(fAdd ->getOperand (opIdx));
1295
+
1296
+ // Pattern Mix check step 5: Second fMul should be a user of the same,
1297
+ // other than a value of 1.0, operand as fSub instruction
1298
+ if (fMul2nd && fMul2nd ->getOpcode () == BinaryOperator::FMul)
1299
+ {
1300
+ unsigned int fSubNon1OpIdx = 1 - fSubOpIdx ; // 0 -> 1 or 1 -> 0
1301
+ while (opIdx < 2 && fMul2nd ->getOperand (opIdx) != I.getOperand (fSubNon1OpIdx ))
1302
+ {
1303
+ opIdx++;
1304
+ }
1305
+
1306
+ if (opIdx < 2 )
1307
+ {
1308
+ fMulInsts .push_back (std::make_pair (fMul , fMul2nd ));
1309
+ matchFound = true ; // Pattern Mix (partially) detected.
1310
+ }
1311
+ }
1312
+ }
1313
+ }
1314
+ }
1315
+ }
1316
+
1317
+ if (!matchFound)
1318
+ {
1319
+ doNotOptimize = true ; // To optimize both FMul instructions and FAdd must be found
1320
+ }
1321
+ }
1322
+
1323
+ if (!doNotOptimize && !fMulInsts .empty () && I.users ().begin () != I.users ().end ())
1324
+ {
1325
+ // Pattern Mix fully detected. Replace sequence of detected instructions with new ones.
1326
+ IGC_ASSERT_MESSAGE (
1327
+ fMulInsts .size () == (int )std::distance (I.users ().begin (), I.users ().end ()),
1328
+ " Incorrect pattern match data" );
1329
+ // If Pattern Mix with 1-a in the first instruction was detected then create
1330
+ // this sequence of new instructions: FSub, FMul, FAdd.
1331
+ // But if Pattern Mix with a-1 in the first instruction was detected then create
1332
+ // this sequence of new instructions: FAdd, FMul, FSub.
1333
+ Instruction::BinaryOps newFirstInstType = (fSubOpIdx == 0 ) ? Instruction::FSub : Instruction::FAdd;
1334
+ Instruction::BinaryOps newLastInstType = (fSubOpIdx == 0 ) ? Instruction::FAdd : Instruction::FSub;
1335
+
1336
+ fSubOpIdx = 1 - fSubOpIdx ; // 0 -> 1 or 1 -> 0, i.e. get another FSub operand
1337
+ Value* r = I.getOperand (fSubOpIdx );
1338
+
1339
+ for (std::pair<Instruction*, Instruction*> fMulPair : fMulInsts )
1340
+ {
1341
+ Instruction* fAdd = cast<Instruction>(*fMulPair .first ->users ().begin ());
1342
+
1343
+ unsigned int fMul2OpToFirstInstIdx = (r == fMulPair .second ->getOperand (0 )) ? 1 : 0 ;
1344
+ Value* newFirstInstOp = fMulPair .second ->getOperand (fMul2OpToFirstInstIdx );
1345
+ Value* fSubVal = cast<Value>(&I);
1346
+ unsigned int fMul1OpToTakeIdx = (fSubVal == fMulPair .first ->getOperand (0 )) ? 1 : 0 ;
1347
+
1348
+ Instruction* newFirstInst = BinaryOperator::Create (
1349
+ newFirstInstType, newFirstInstOp, fMulPair .first ->getOperand (fMul1OpToTakeIdx ), " " , fAdd );
1350
+ newFirstInst->copyFastMathFlags (fMulPair .first );
1351
+ DILocation* DL1st = I.getDebugLoc ();
1352
+ if (DL1st)
1353
+ {
1354
+ newFirstInst->setDebugLoc (DL1st);
1355
+ }
1356
+
1357
+ Instruction* newFMul = BinaryOperator::CreateFMul (
1358
+ fMulPair .second ->getOperand ((fMul2OpToFirstInstIdx + 1 ) % 2 ), newFirstInst, " " , fAdd );
1359
+ newFMul->copyFastMathFlags (fMulPair .second );
1360
+ DILocation* DL2nd = fMulPair .second ->getDebugLoc ();
1361
+ if (DL2nd)
1362
+ {
1363
+ newFMul->setDebugLoc (DL2nd);
1364
+ }
1365
+
1366
+ Instruction* newLastInst = BinaryOperator::Create (
1367
+ newLastInstType, newFMul, fMulPair .first ->getOperand (fMul1OpToTakeIdx ), " " , fAdd );
1368
+ newLastInst->copyFastMathFlags (fAdd );
1369
+ DILocation* DL3rd = fAdd ->getDebugLoc ();
1370
+ if (DL3rd)
1371
+ {
1372
+ newLastInst->setDebugLoc (DL3rd);
1373
+ }
1374
+
1375
+ fAdd ->replaceAllUsesWith (newLastInst);
1376
+ }
1377
+ }
1378
+ }
1379
+ }
1380
+ }
1381
+ }
1382
+
1230
1383
void CustomSafeOptPass::hoistDp3 (BinaryOperator& I)
1231
1384
{
1232
1385
if (I.getOpcode () != Instruction::BinaryOps::FAdd)
@@ -1602,6 +1755,15 @@ void CustomSafeOptPass::visitBinaryOperator(BinaryOperator& I)
1602
1755
{
1603
1756
matchDp4a (I);
1604
1757
1758
+ CodeGenContext* pContext = getAnalysis<CodeGenContextWrapper>().getCodeGenContext ();
1759
+
1760
+ if (!pContext->platform .supportLRPInstruction ())
1761
+ {
1762
+ // Optimize mix operation if detected.
1763
+ // Mix is computed as x*(1 - a) + y*a
1764
+ matchMixOperation (I);
1765
+ }
1766
+
1605
1767
// move immediate value in consecutive integer adds to the last added value.
1606
1768
// this can allow more chance of doing CSE and memopt.
1607
1769
// a = b + 8
@@ -1610,8 +1772,6 @@ void CustomSafeOptPass::visitBinaryOperator(BinaryOperator& I)
1610
1772
// a = b + c
1611
1773
// d = a + 8
1612
1774
1613
- CodeGenContext* pContext = getAnalysis<CodeGenContextWrapper>().getCodeGenContext ();
1614
-
1615
1775
// Before WA if() as it's validated behavior.
1616
1776
if (I.getType ()->isIntegerTy () && I.getOpcode () == Instruction::Or)
1617
1777
{
0 commit comments