@@ -956,10 +956,20 @@ bool CustomSafeOptPass::isEmulatedAdd(BinaryOperator& I)
956
956
// We can remove the extra casts in this case.
957
957
// This becomes:
958
958
// %41 = fadd fast float %34, %33
959
+ // Can also do matches with fadd/fmul that will later become an mad instruction.
960
+ // mad example:
961
+ // %.prec70.i = fptrunc float %273 to half
962
+ // %.prec78.i = fptrunc float %276 to half
963
+ // %279 = fmul fast half %233, %.prec70.i
964
+ // %282 = fadd fast half %279, %.prec78.i
965
+ // %.prec84.i = fpext half %282 to float
966
+ // This becomes:
967
+ // %279 = fpext half %233 to float
968
+ // %280 = fmul fast float %273, %279
969
+ // %281 = fadd fast float %280, %276
959
970
void CustomSafeOptPass::removeHftoFCast (Instruction& I)
960
971
{
961
972
// Skip if mix mode is supported
962
-
963
973
CodeGenContext* Ctx = getAnalysis<CodeGenContextWrapper>().getCodeGenContext ();
964
974
if (Ctx->platform .supportMixMode ())
965
975
return ;
@@ -971,6 +981,7 @@ void CustomSafeOptPass::removeHftoFCast(Instruction& I)
971
981
if (!I.hasOneUse ())
972
982
return ;
973
983
984
+ // Check if this instruction is used in a single FPExtInst
974
985
FPExtInst* castInst = NULL ;
975
986
User* U = *I.user_begin ();
976
987
if (FPExtInst* inst = dyn_cast<FPExtInst>(U))
@@ -983,6 +994,79 @@ void CustomSafeOptPass::removeHftoFCast(Instruction& I)
983
994
if (!castInst)
984
995
return ;
985
996
997
+ // Check for fmad pattern
998
+ if (I.getOpcode () == Instruction::FAdd)
999
+ {
1000
+ Value* src0 = nullptr , * src1 = nullptr , * src2 = nullptr ;
1001
+
1002
+ // CodeGenPatternMatch::MatchMad matches the first fmul.
1003
+ Instruction* fmulInst = nullptr ;
1004
+ for (uint i = 0 ; i < 2 ; i++)
1005
+ {
1006
+ fmulInst = dyn_cast<Instruction>(I.getOperand (i));
1007
+ if (fmulInst && fmulInst->getNumOperands () == 2 )
1008
+ {
1009
+ src0 = fmulInst->getOperand (0 );
1010
+ src1 = fmulInst->getOperand (1 );
1011
+ src2 = I.getOperand (1 - i);
1012
+ break ;
1013
+ }
1014
+ }
1015
+ if (fmulInst)
1016
+ {
1017
+ // Used to get the new float operands for the new instructions
1018
+ auto getFloatValue = [](Value* operand, Instruction* I, Type* type)
1019
+ {
1020
+ if (FPTruncInst* inst = dyn_cast<FPTruncInst>(operand))
1021
+ {
1022
+ // Use the float input of the FPTrunc
1023
+ if (inst->getOperand (0 )->getType ()->isFloatTy ())
1024
+ {
1025
+ return inst->getOperand (0 );
1026
+ }
1027
+ else
1028
+ {
1029
+ return (Value*)NULL ;
1030
+ }
1031
+ }
1032
+ else if (Instruction* inst = dyn_cast<Instruction>(operand))
1033
+ {
1034
+ // Cast the result of this operand to a float
1035
+ return dyn_cast<Value>(new FPExtInst (inst, type, " " , I));
1036
+ }
1037
+ return (Value*)NULL ;
1038
+ };
1039
+
1040
+ int convertCount = 0 ;
1041
+ if (dyn_cast<FPTruncInst>(src0))
1042
+ convertCount++;
1043
+ if (dyn_cast<FPTruncInst>(src1))
1044
+ convertCount++;
1045
+ if (dyn_cast<FPTruncInst>(src2))
1046
+ convertCount++;
1047
+ if (convertCount >= 2 )
1048
+ {
1049
+ // Conversion for the hf values
1050
+ auto floatTy = castInst->getType ();
1051
+ src0 = getFloatValue (src0, fmulInst, floatTy);
1052
+ src1 = getFloatValue (src1, fmulInst, floatTy);
1053
+ src2 = getFloatValue (src2, &I, floatTy);
1054
+
1055
+ // Create new float fmul and fadd instructions
1056
+ Value* newFmul = BinaryOperator::Create (Instruction::FMul, src0, src1, " " , &I);
1057
+ Value* newFadd = BinaryOperator::Create (Instruction::FAdd, newFmul, src2, " " , &I);
1058
+
1059
+ // Copy fast math flags
1060
+ Instruction* fmulInst = dyn_cast<Instruction>(newFmul);
1061
+ Instruction* faddInst = dyn_cast<Instruction>(newFadd);
1062
+ fmulInst->copyFastMathFlags (fmulInst);
1063
+ faddInst->copyFastMathFlags (&I);
1064
+ castInst->replaceAllUsesWith (faddInst);
1065
+ return ;
1066
+ }
1067
+ }
1068
+ }
1069
+
986
1070
// Check if operands come from a Float to HF Cast
987
1071
Value *S1 = NULL , *S2 = NULL ;
988
1072
if (FPTruncInst* inst = dyn_cast<FPTruncInst>(I.getOperand (0 )))
0 commit comments