Skip to content

Commit 6d6314b

Browse files
committed
[DAGCombiner] Extend combineFMulOrFDivWithIntPow2 to work for non-splat float vecs
Do so by extending `matchUnaryPredicate` to also work for `ConstantFPSDNode` types, then encapsulating the constant checks in a lambda and passing it to the new `matchUnaryFpPredicate` hook. Reviewed By: RKSimon. Differential Revision: https://reviews.llvm.org/D154868
1 parent 47c642f commit 6d6314b

File tree

4 files changed

+72
-86
lines changed

4 files changed

+72
-86
lines changed

llvm/include/llvm/CodeGen/SelectionDAGNodes.h

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3128,9 +3128,25 @@ namespace ISD {
31283128
/// Attempt to match a unary predicate against a scalar/splat constant or
31293129
/// every element of a constant BUILD_VECTOR.
31303130
/// If AllowUndefs is true, then UNDEF elements will pass nullptr to Match.
3131-
bool matchUnaryPredicate(SDValue Op,
3132-
std::function<bool(ConstantSDNode *)> Match,
3133-
bool AllowUndefs = false);
3131+
template <typename ConstNodeType>
3132+
bool matchUnaryPredicateImpl(SDValue Op,
3133+
std::function<bool(ConstNodeType *)> Match,
3134+
bool AllowUndefs = false);
3135+
3136+
/// Hook for matching ConstantSDNode predicate
3137+
inline bool matchUnaryPredicate(SDValue Op,
3138+
std::function<bool(ConstantSDNode *)> Match,
3139+
bool AllowUndefs = false) {
3140+
return matchUnaryPredicateImpl<ConstantSDNode>(Op, Match, AllowUndefs);
3141+
}
3142+
3143+
/// Hook for matching ConstantFPSDNode predicate
3144+
inline bool
3145+
matchUnaryFpPredicate(SDValue Op,
3146+
std::function<bool(ConstantFPSDNode *)> Match,
3147+
bool AllowUndefs = false) {
3148+
return matchUnaryPredicateImpl<ConstantFPSDNode>(Op, Match, AllowUndefs);
3149+
}
31343150

31353151
/// Attempt to match a binary predicate against a pair of scalar/splat
31363152
/// constants or every element of a pair of constant BUILD_VECTORs.

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -16352,7 +16352,7 @@ SDValue DAGCombiner::combineFMulOrFDivWithIntPow2(SDNode *N) {
1635216352
EVT VT = N->getValueType(0);
1635316353
SDValue ConstOp, Pow2Op;
1635416354

16355-
int Mantissa = -1;
16355+
std::optional<int> Mantissa;
1635616356
auto GetConstAndPow2Ops = [&](unsigned ConstOpIdx) {
1635716357
if (ConstOpIdx == 1 && N->getOpcode() == ISD::FDIV)
1635816358
return false;
@@ -16366,36 +16366,43 @@ SDValue DAGCombiner::combineFMulOrFDivWithIntPow2(SDNode *N) {
1636616366

1636716367
Pow2Op = Pow2Op.getOperand(0);
1636816368

16369-
// TODO(1): We may be able to include undefs.
16370-
// TODO(2): We could also handle non-splat vector types.
16371-
ConstantFPSDNode *CFP =
16372-
isConstOrConstSplatFP(ConstOp, /*AllowUndefs*/ false);
16373-
if (CFP == nullptr)
16374-
return false;
16375-
const APFloat &APF = CFP->getValueAPF();
16376-
16377-
// Make sure we have normal/ieee constant.
16378-
if (!APF.isNormal() || !APF.isIEEE())
16379-
return false;
16380-
1638116369
// `Log2(Pow2Op) < Pow2Op.getScalarSizeInBits()`.
1638216370
// TODO: We could use knownbits to make this bound more precise.
1638316371
int MaxExpChange = Pow2Op.getValueType().getScalarSizeInBits();
1638416372

16385-
// Make sure the floats exponent is within the bounds that this transform
16386-
// produces bitwise equals value.
16387-
int CurExp = ilogb(APF);
16388-
// FMul by pow2 will only increase exponent.
16389-
int MinExp = N->getOpcode() == ISD::FMUL ? CurExp : (CurExp - MaxExpChange);
16390-
// FDiv by pow2 will only decrease exponent.
16391-
int MaxExp = N->getOpcode() == ISD::FDIV ? CurExp : (CurExp + MaxExpChange);
16392-
if (MinExp <= APFloat::semanticsMinExponent(APF.getSemantics()) ||
16393-
MaxExp >= APFloat::semanticsMaxExponent(APF.getSemantics()))
16394-
return false;
16373+
auto IsFPConstValid = [N, MaxExpChange, &Mantissa](ConstantFPSDNode *CFP) {
16374+
if (CFP == nullptr)
16375+
return false;
16376+
16377+
const APFloat &APF = CFP->getValueAPF();
16378+
16379+
// Make sure we have normal/ieee constant.
16380+
if (!APF.isNormal() || !APF.isIEEE())
16381+
return false;
16382+
16383+
// Make sure the float's exponent is within the bounds for which this
16384+
// transform produces a bitwise-equal value.
16385+
int CurExp = ilogb(APF);
16386+
// FMul by pow2 will only increase exponent.
16387+
int MinExp =
16388+
N->getOpcode() == ISD::FMUL ? CurExp : (CurExp - MaxExpChange);
16389+
// FDiv by pow2 will only decrease exponent.
16390+
int MaxExp =
16391+
N->getOpcode() == ISD::FDIV ? CurExp : (CurExp + MaxExpChange);
16392+
if (MinExp <= APFloat::semanticsMinExponent(APF.getSemantics()) ||
16393+
MaxExp >= APFloat::semanticsMaxExponent(APF.getSemantics()))
16394+
return false;
16395+
16396+
// Finally make sure we actually know the mantissa for the float type.
16397+
int ThisMantissa = APFloat::semanticsPrecision(APF.getSemantics()) - 1;
16398+
if (!Mantissa)
16399+
Mantissa = ThisMantissa;
16400+
16401+
return *Mantissa == ThisMantissa && ThisMantissa > 0;
16402+
};
1639516403

16396-
// Finally make sure we actually know the mantissa for the float type.
16397-
Mantissa = APFloat::semanticsPrecision(APF.getSemantics()) - 1;
16398-
return Mantissa > 0;
16404+
// TODO: We may be able to include undefs.
16405+
return ISD::matchUnaryFpPredicate(ConstOp, IsFPConstValid);
1639916406
};
1640016407

1640116408
if (!GetConstAndPow2Ops(0) && !GetConstAndPow2Ops(1))
@@ -16420,7 +16427,7 @@ SDValue DAGCombiner::combineFMulOrFDivWithIntPow2(SDNode *N) {
1642016427

1642116428
// Perform actual transform.
1642216429
SDValue MantissaShiftCnt =
16423-
DAG.getConstant(Mantissa, DL, getShiftAmountTy(NewIntVT));
16430+
DAG.getConstant(*Mantissa, DL, getShiftAmountTy(NewIntVT));
1642416431
// TODO: Sometimes Log2 is of form `(X + C)`. `(X + C) << C1` should fold to
1642516432
// `(X << C1) + (C << C1)`, but that isn't always the case because of the
1642616433
// cast. We could implement that here by handling the casts.

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -344,12 +344,13 @@ bool ISD::isFreezeUndef(const SDNode *N) {
344344
return N->getOpcode() == ISD::FREEZE && N->getOperand(0).isUndef();
345345
}
346346

347-
bool ISD::matchUnaryPredicate(SDValue Op,
348-
std::function<bool(ConstantSDNode *)> Match,
349-
bool AllowUndefs) {
347+
template <typename ConstNodeType>
348+
bool ISD::matchUnaryPredicateImpl(SDValue Op,
349+
std::function<bool(ConstNodeType *)> Match,
350+
bool AllowUndefs) {
350351
// FIXME: Add support for scalar UNDEF cases?
351-
if (auto *Cst = dyn_cast<ConstantSDNode>(Op))
352-
return Match(Cst);
352+
if (auto *C = dyn_cast<ConstNodeType>(Op))
353+
return Match(C);
353354

354355
// FIXME: Add support for vector UNDEF cases?
355356
if (ISD::BUILD_VECTOR != Op.getOpcode() &&
@@ -364,12 +365,17 @@ bool ISD::matchUnaryPredicate(SDValue Op,
364365
continue;
365366
}
366367

367-
auto *Cst = dyn_cast<ConstantSDNode>(Op.getOperand(i));
368+
auto *Cst = dyn_cast<ConstNodeType>(Op.getOperand(i));
368369
if (!Cst || Cst->getValueType(0) != SVT || !Match(Cst))
369370
return false;
370371
}
371372
return true;
372373
}
374+
// Explicitly instantiate the template for the node types that are used.
375+
template bool ISD::matchUnaryPredicateImpl<ConstantSDNode>(
376+
SDValue, std::function<bool(ConstantSDNode *)>, bool);
377+
template bool ISD::matchUnaryPredicateImpl<ConstantFPSDNode>(
378+
SDValue, std::function<bool(ConstantFPSDNode *)>, bool);
373379

374380
bool ISD::matchBinaryPredicate(
375381
SDValue LHS, SDValue RHS,

llvm/test/CodeGen/X86/fold-int-pow2-with-fmul-or-fdiv.ll

Lines changed: 7 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,58 +1104,15 @@ define <4 x float> @fmul_pow_shl_cnt_vec_preserve_fma(<4 x i32> %cnt, <4 x float
11041104
define <2 x double> @fmul_pow_shl_cnt_vec_non_splat_todo(<2 x i64> %cnt) nounwind {
11051105
; CHECK-SSE-LABEL: fmul_pow_shl_cnt_vec_non_splat_todo:
11061106
; CHECK-SSE: # %bb.0:
1107-
; CHECK-SSE-NEXT: movdqa {{.*#+}} xmm1 = [2,2]
1108-
; CHECK-SSE-NEXT: movdqa %xmm1, %xmm2
1109-
; CHECK-SSE-NEXT: psllq %xmm0, %xmm2
1110-
; CHECK-SSE-NEXT: pshufd {{.*#+}} xmm0 = xmm0[2,3,2,3]
1111-
; CHECK-SSE-NEXT: psllq %xmm0, %xmm1
1112-
; CHECK-SSE-NEXT: movsd {{.*#+}} xmm1 = xmm2[0],xmm1[1]
1113-
; CHECK-SSE-NEXT: movapd {{.*#+}} xmm0 = [4294967295,4294967295]
1114-
; CHECK-SSE-NEXT: andpd %xmm1, %xmm0
1115-
; CHECK-SSE-NEXT: orpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
1116-
; CHECK-SSE-NEXT: psrlq $32, %xmm1
1117-
; CHECK-SSE-NEXT: por {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1
1118-
; CHECK-SSE-NEXT: subpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1
1119-
; CHECK-SSE-NEXT: addpd %xmm0, %xmm1
1120-
; CHECK-SSE-NEXT: mulpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1
1121-
; CHECK-SSE-NEXT: movapd %xmm1, %xmm0
1107+
; CHECK-SSE-NEXT: psllq $52, %xmm0
1108+
; CHECK-SSE-NEXT: paddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
11221109
; CHECK-SSE-NEXT: retq
11231110
;
1124-
; CHECK-AVX2-LABEL: fmul_pow_shl_cnt_vec_non_splat_todo:
1125-
; CHECK-AVX2: # %bb.0:
1126-
; CHECK-AVX2-NEXT: vpbroadcastq {{.*#+}} xmm1 = [2,2]
1127-
; CHECK-AVX2-NEXT: vpsllvq %xmm0, %xmm1, %xmm0
1128-
; CHECK-AVX2-NEXT: vpxor %xmm1, %xmm1, %xmm1
1129-
; CHECK-AVX2-NEXT: vpblendd {{.*#+}} xmm1 = xmm0[0],xmm1[1],xmm0[2],xmm1[3]
1130-
; CHECK-AVX2-NEXT: vpor {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
1131-
; CHECK-AVX2-NEXT: vpsrlq $32, %xmm0, %xmm0
1132-
; CHECK-AVX2-NEXT: vpor {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
1133-
; CHECK-AVX2-NEXT: vsubpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
1134-
; CHECK-AVX2-NEXT: vaddpd %xmm0, %xmm1, %xmm0
1135-
; CHECK-AVX2-NEXT: vmulpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
1136-
; CHECK-AVX2-NEXT: retq
1137-
;
1138-
; CHECK-NO-FASTFMA-LABEL: fmul_pow_shl_cnt_vec_non_splat_todo:
1139-
; CHECK-NO-FASTFMA: # %bb.0:
1140-
; CHECK-NO-FASTFMA-NEXT: vpbroadcastq {{.*#+}} xmm1 = [2,2]
1141-
; CHECK-NO-FASTFMA-NEXT: vpsllvq %xmm0, %xmm1, %xmm0
1142-
; CHECK-NO-FASTFMA-NEXT: vpxor %xmm1, %xmm1, %xmm1
1143-
; CHECK-NO-FASTFMA-NEXT: vpblendd {{.*#+}} xmm1 = xmm0[0],xmm1[1],xmm0[2],xmm1[3]
1144-
; CHECK-NO-FASTFMA-NEXT: vpor {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
1145-
; CHECK-NO-FASTFMA-NEXT: vpsrlq $32, %xmm0, %xmm0
1146-
; CHECK-NO-FASTFMA-NEXT: vpor {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
1147-
; CHECK-NO-FASTFMA-NEXT: vsubpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
1148-
; CHECK-NO-FASTFMA-NEXT: vaddpd %xmm0, %xmm1, %xmm0
1149-
; CHECK-NO-FASTFMA-NEXT: vmulpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
1150-
; CHECK-NO-FASTFMA-NEXT: retq
1151-
;
1152-
; CHECK-FMA-LABEL: fmul_pow_shl_cnt_vec_non_splat_todo:
1153-
; CHECK-FMA: # %bb.0:
1154-
; CHECK-FMA-NEXT: vpbroadcastq {{.*#+}} xmm1 = [2,2]
1155-
; CHECK-FMA-NEXT: vpsllvq %xmm0, %xmm1, %xmm0
1156-
; CHECK-FMA-NEXT: vcvtuqq2pd %xmm0, %xmm0
1157-
; CHECK-FMA-NEXT: vmulpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
1158-
; CHECK-FMA-NEXT: retq
1111+
; CHECK-AVX-LABEL: fmul_pow_shl_cnt_vec_non_splat_todo:
1112+
; CHECK-AVX: # %bb.0:
1113+
; CHECK-AVX-NEXT: vpsllq $52, %xmm0, %xmm0
1114+
; CHECK-AVX-NEXT: vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
1115+
; CHECK-AVX-NEXT: retq
11591116
%shl = shl nsw nuw <2 x i64> <i64 2, i64 2>, %cnt
11601117
%conv = uitofp <2 x i64> %shl to <2 x double>
11611118
%mul = fmul <2 x double> <double 15.000000e+00, double 14.000000e+00>, %conv

0 commit comments

Comments
 (0)