Skip to content

Commit fa056a0

Browse files
mtargowsigcbot
authored and committed
Mix function optimization (3rd try).
Mix function optimization, found in GLSL shaders.
1 parent 140dada commit fa056a0

File tree

2 files changed

+163
-2
lines changed

2 files changed

+163
-2
lines changed

IGC/Compiler/CustomSafeOptPass.cpp

Lines changed: 162 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,6 +1227,159 @@ void CustomSafeOptPass::matchDp4a(BinaryOperator &I) {
12271227
I.replaceAllUsesWith(Res);
12281228
}
12291229

1230+
// Optimize mix operation if detected.
1231+
// Mix is computed as x*(1 - a) + y*a
1232+
// Replace it with a*(y - x) + x to save one instruction ('add' ISA, 'sub' in IR).
1233+
// This pattern also optimizes a similar operation:
1234+
// x*(a - 1) + y*a which can be replaced with a(x + y) - x
1235+
void CustomSafeOptPass::matchMixOperation(BinaryOperator& I)
1236+
{
1237+
// Pattern Mix check step 1: find a FSub instruction with a constant value of 1.
1238+
if (I.getOpcode() == BinaryOperator::FSub)
1239+
{
1240+
unsigned int fSubOpIdx = 0;
1241+
while (fSubOpIdx < 2 && !llvm::isa<llvm::ConstantFP>(I.getOperand(fSubOpIdx)))
1242+
{
1243+
fSubOpIdx++;
1244+
}
1245+
if ((fSubOpIdx == 1) ||
1246+
((fSubOpIdx == 0) && !llvm::isa<llvm::ConstantFP>(I.getOperand(1))))
1247+
{
1248+
llvm::ConstantFP* fSubOpConst = llvm::dyn_cast<llvm::ConstantFP>(I.getOperand(fSubOpIdx));
1249+
const APFloat& APF = fSubOpConst->getValueAPF();
1250+
bool isInf = APF.isInfinity();
1251+
bool isNaN = APF.isNaN();
1252+
double val = 0.0;
1253+
if (!isInf && !isNaN)
1254+
{
1255+
if (&APF.getSemantics() == &APFloat::IEEEdouble())
1256+
{
1257+
val = APF.convertToDouble();
1258+
}
1259+
else if (&APF.getSemantics() == &APFloat::IEEEsingle())
1260+
{
1261+
val = (double)APF.convertToFloat();
1262+
}
1263+
}
1264+
if (val == 1.0)
1265+
{
1266+
bool doNotOptimize = false;
1267+
bool matchFound = false;
1268+
SmallVector<std::pair<Instruction*, Instruction*>, 3> fMulInsts;
1269+
1270+
// Pattern Mix check step 2: there should be only FMul users of this FSub instruction
1271+
for (User* U : I.users())
1272+
{
1273+
matchFound = false;
1274+
Instruction* fMul = dyn_cast_or_null<Instruction>(U);
1275+
if (fMul && fMul->getOpcode() == BinaryOperator::FMul)
1276+
{
1277+
// Pattern Mix check step 3: there should be only one fAdd user for such an FMul instruction
1278+
if ((cast<Value>(fMul))->hasOneUse())
1279+
{
1280+
Instruction* fAdd = dyn_cast_or_null<Instruction>(*fMul->users().begin());
1281+
1282+
// Pattern Mix check step 4: fAdd should be a user of two FMul instructions
1283+
if (fAdd && fAdd->getOpcode() == BinaryOperator::FAdd)
1284+
{
1285+
unsigned int opIdx = 0;
1286+
while (opIdx < 2 && fMul != fAdd->getOperand(opIdx))
1287+
{
1288+
opIdx++;
1289+
}
1290+
1291+
if (opIdx < 2)
1292+
{
1293+
opIdx = 1 - opIdx; // 0 -> 1 or 1 -> 0
1294+
Instruction* fMul2nd = dyn_cast_or_null<Instruction>(fAdd->getOperand(opIdx));
1295+
1296+
// Pattern Mix check step 5: Second fMul should be a user of the same,
1297+
// other than a value of 1.0, operand as fSub instruction
1298+
if (fMul2nd && fMul2nd->getOpcode() == BinaryOperator::FMul)
1299+
{
1300+
unsigned int fSubNon1OpIdx = 1 - fSubOpIdx; // 0 -> 1 or 1 -> 0
1301+
while (opIdx < 2 && fMul2nd->getOperand(opIdx) != I.getOperand(fSubNon1OpIdx))
1302+
{
1303+
opIdx++;
1304+
}
1305+
1306+
if (opIdx < 2)
1307+
{
1308+
fMulInsts.push_back(std::make_pair(fMul, fMul2nd));
1309+
matchFound = true; // Pattern Mix (partially) detected.
1310+
}
1311+
}
1312+
}
1313+
}
1314+
}
1315+
}
1316+
1317+
if (!matchFound)
1318+
{
1319+
doNotOptimize = true; // To optimize both FMul instructions and FAdd must be found
1320+
}
1321+
}
1322+
1323+
if (!doNotOptimize && !fMulInsts.empty() && I.users().begin() != I.users().end())
1324+
{
1325+
// Pattern Mix fully detected. Replace sequence of detected instructions with new ones.
1326+
IGC_ASSERT_MESSAGE(
1327+
fMulInsts.size() == (int)std::distance(I.users().begin(), I.users().end()),
1328+
"Incorrect pattern match data");
1329+
// If Pattern Mix with 1-a in the first instruction was detected then create
1330+
// this sequence of new instructions: FSub, FMul, FAdd.
1331+
// But if Pattern Mix with a-1 in the first instruction was detected then create
1332+
// this sequence of new instructions: FAdd, FMul, FSub.
1333+
Instruction::BinaryOps newFirstInstType = (fSubOpIdx == 0) ? Instruction::FSub : Instruction::FAdd;
1334+
Instruction::BinaryOps newLastInstType = (fSubOpIdx == 0) ? Instruction::FAdd : Instruction::FSub;
1335+
1336+
fSubOpIdx = 1 - fSubOpIdx; // 0 -> 1 or 1 -> 0, i.e. get another FSub operand
1337+
Value* r = I.getOperand(fSubOpIdx);
1338+
1339+
for (std::pair<Instruction*, Instruction*> fMulPair : fMulInsts)
1340+
{
1341+
Instruction* fAdd = cast<Instruction>(*fMulPair.first->users().begin());
1342+
1343+
unsigned int fMul2OpToFirstInstIdx = (r == fMulPair.second->getOperand(0)) ? 1 : 0;
1344+
Value* newFirstInstOp = fMulPair.second->getOperand(fMul2OpToFirstInstIdx);
1345+
Value* fSubVal = cast<Value>(&I);
1346+
unsigned int fMul1OpToTakeIdx = (fSubVal == fMulPair.first->getOperand(0)) ? 1 : 0;
1347+
1348+
Instruction* newFirstInst = BinaryOperator::Create(
1349+
newFirstInstType, newFirstInstOp, fMulPair.first->getOperand(fMul1OpToTakeIdx), "", fAdd);
1350+
newFirstInst->copyFastMathFlags(fMulPair.first);
1351+
DILocation* DL1st = I.getDebugLoc();
1352+
if (DL1st)
1353+
{
1354+
newFirstInst->setDebugLoc(DL1st);
1355+
}
1356+
1357+
Instruction* newFMul = BinaryOperator::CreateFMul(
1358+
fMulPair.second->getOperand((fMul2OpToFirstInstIdx + 1) % 2), newFirstInst, "", fAdd);
1359+
newFMul->copyFastMathFlags(fMulPair.second);
1360+
DILocation* DL2nd = fMulPair.second->getDebugLoc();
1361+
if (DL2nd)
1362+
{
1363+
newFMul->setDebugLoc(DL2nd);
1364+
}
1365+
1366+
Instruction* newLastInst = BinaryOperator::Create(
1367+
newLastInstType, newFMul, fMulPair.first->getOperand(fMul1OpToTakeIdx), "", fAdd);
1368+
newLastInst->copyFastMathFlags(fAdd);
1369+
DILocation* DL3rd = fAdd->getDebugLoc();
1370+
if (DL3rd)
1371+
{
1372+
newLastInst->setDebugLoc(DL3rd);
1373+
}
1374+
1375+
fAdd->replaceAllUsesWith(newLastInst);
1376+
}
1377+
}
1378+
}
1379+
}
1380+
}
1381+
}
1382+
12301383
void CustomSafeOptPass::hoistDp3(BinaryOperator& I)
12311384
{
12321385
if (I.getOpcode() != Instruction::BinaryOps::FAdd)
@@ -1602,6 +1755,15 @@ void CustomSafeOptPass::visitBinaryOperator(BinaryOperator& I)
16021755
{
16031756
matchDp4a(I);
16041757

1758+
CodeGenContext* pContext = getAnalysis<CodeGenContextWrapper>().getCodeGenContext();
1759+
1760+
if (!pContext->platform.supportLRPInstruction())
1761+
{
1762+
// Optimize mix operation if detected.
1763+
// Mix is computed as x*(1 - a) + y*a
1764+
matchMixOperation(I);
1765+
}
1766+
16051767
// move immediate value in consecutive integer adds to the last added value.
16061768
// this can allow more chance of doing CSE and memopt.
16071769
// a = b + 8
@@ -1610,8 +1772,6 @@ void CustomSafeOptPass::visitBinaryOperator(BinaryOperator& I)
16101772
// a = b + c
16111773
// d = a + 8
16121774

1613-
CodeGenContext* pContext = getAnalysis<CodeGenContextWrapper>().getCodeGenContext();
1614-
16151775
// Before WA if() as it's validated behavior.
16161776
if (I.getType()->isIntegerTy() && I.getOpcode() == Instruction::Or)
16171777
{

IGC/Compiler/CustomSafeOptPass.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ namespace IGC
9191
void visitBitCast(llvm::BitCastInst& BC);
9292

9393
void matchDp4a(llvm::BinaryOperator& I);
94+
void matchMixOperation(llvm::BinaryOperator& I);
9495
void hoistDp3(llvm::BinaryOperator& I);
9596

9697
template <typename MaskType> void matchReverse(llvm::BinaryOperator& I);

0 commit comments

Comments
 (0)