Skip to content

Commit d1908ed

Browse files
Henry Estela authored and sys_zuul committed
Remove extra casts from half to float when we find a fadd that has fmul
as one of its arguments. If the arguments of the fadd or fmul are fptrunc, we will use the input to the fptrunc so that we can drop the fptrunc later. All other operands will be cast from half to float so that we can create new float copies of the fadd and fmul. This will allow for less packing of the half input for mad later on. Change-Id: I767a47d4e07dd5442885d1d22242bd0236271e25
1 parent a95b1d4 commit d1908ed

File tree

1 file changed

+85
-1
lines changed

1 file changed

+85
-1
lines changed

IGC/Compiler/CustomSafeOptPass.cpp

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -956,10 +956,20 @@ bool CustomSafeOptPass::isEmulatedAdd(BinaryOperator& I)
956956
// We can remove the extra casts in this case.
957957
// This becomes:
958958
// %41 = fadd fast float %34, %33
959+
// Can also do matches with fadd/fmul that will later become an mad instruction.
960+
// mad example:
961+
// %.prec70.i = fptrunc float %273 to half
962+
// %.prec78.i = fptrunc float %276 to half
963+
// %279 = fmul fast half %233, %.prec70.i
964+
// %282 = fadd fast half %279, %.prec78.i
965+
// %.prec84.i = fpext half %282 to float
966+
// This becomes:
967+
// %279 = fpext half %233 to float
968+
// %280 = fmul fast float %273, %279
969+
// %281 = fadd fast float %280, %276
959970
void CustomSafeOptPass::removeHftoFCast(Instruction& I)
960971
{
961972
// Skip if mix mode is supported
962-
963973
CodeGenContext* Ctx = getAnalysis<CodeGenContextWrapper>().getCodeGenContext();
964974
if (Ctx->platform.supportMixMode())
965975
return;
@@ -971,6 +981,7 @@ void CustomSafeOptPass::removeHftoFCast(Instruction& I)
971981
if (!I.hasOneUse())
972982
return;
973983

984+
// Check if this instruction is used in a single FPExtInst
974985
FPExtInst* castInst = NULL;
975986
User* U = *I.user_begin();
976987
if (FPExtInst* inst = dyn_cast<FPExtInst>(U))
@@ -983,6 +994,79 @@ void CustomSafeOptPass::removeHftoFCast(Instruction& I)
983994
if (!castInst)
984995
return;
985996

997+
// Check for fmad pattern
998+
if (I.getOpcode() == Instruction::FAdd)
999+
{
1000+
Value* src0 = nullptr, * src1 = nullptr, * src2 = nullptr;
1001+
1002+
// CodeGenPatternMatch::MatchMad matches the first fmul.
1003+
Instruction* fmulInst = nullptr;
1004+
for (uint i = 0; i < 2; i++)
1005+
{
1006+
fmulInst = dyn_cast<Instruction>(I.getOperand(i));
1007+
if (fmulInst && fmulInst->getNumOperands() == 2)
1008+
{
1009+
src0 = fmulInst->getOperand(0);
1010+
src1 = fmulInst->getOperand(1);
1011+
src2 = I.getOperand(1 - i);
1012+
break;
1013+
}
1014+
}
1015+
if (fmulInst)
1016+
{
1017+
// Used to get the new float operands for the new instructions
1018+
auto getFloatValue = [](Value* operand, Instruction* I, Type* type)
1019+
{
1020+
if (FPTruncInst* inst = dyn_cast<FPTruncInst>(operand))
1021+
{
1022+
// Use the float input of the FPTrunc
1023+
if (inst->getOperand(0)->getType()->isFloatTy())
1024+
{
1025+
return inst->getOperand(0);
1026+
}
1027+
else
1028+
{
1029+
return (Value*)NULL;
1030+
}
1031+
}
1032+
else if (Instruction* inst = dyn_cast<Instruction>(operand))
1033+
{
1034+
// Cast the result of this operand to a float
1035+
return dyn_cast<Value>(new FPExtInst(inst, type, "", I));
1036+
}
1037+
return (Value*)NULL;
1038+
};
1039+
1040+
int convertCount = 0;
1041+
if (dyn_cast<FPTruncInst>(src0))
1042+
convertCount++;
1043+
if (dyn_cast<FPTruncInst>(src1))
1044+
convertCount++;
1045+
if (dyn_cast<FPTruncInst>(src2))
1046+
convertCount++;
1047+
if (convertCount >= 2)
1048+
{
1049+
// Conversion for the hf values
1050+
auto floatTy = castInst->getType();
1051+
src0 = getFloatValue(src0, fmulInst, floatTy);
1052+
src1 = getFloatValue(src1, fmulInst, floatTy);
1053+
src2 = getFloatValue(src2, &I, floatTy);
1054+
1055+
// Create new float fmul and fadd instructions
1056+
Value* newFmul = BinaryOperator::Create(Instruction::FMul, src0, src1, "", &I);
1057+
Value* newFadd = BinaryOperator::Create(Instruction::FAdd, newFmul, src2, "", &I);
1058+
1059+
// Copy fast math flags
1060+
Instruction* fmulInst = dyn_cast<Instruction>(newFmul);
1061+
Instruction* faddInst = dyn_cast<Instruction>(newFadd);
1062+
fmulInst->copyFastMathFlags(fmulInst);
1063+
faddInst->copyFastMathFlags(&I);
1064+
castInst->replaceAllUsesWith(faddInst);
1065+
return;
1066+
}
1067+
}
1068+
}
1069+
9861070
// Check if operands come from a Float to HF Cast
9871071
Value *S1 = NULL, *S2 = NULL;
9881072
if (FPTruncInst* inst = dyn_cast<FPTruncInst>(I.getOperand(0)))

0 commit comments

Comments
 (0)