Skip to content

Commit 6592bce

Browse files
committed
[x86] invert a vector select IR canonicalization with a binop identity constant
This is an intentionally limited/different form of D90113. That patch bravely tries to generalize folds where we pull a binop into the arms of a select: N0 + (Cond ? 0 : FVal) --> Cond ? N0 : (N0 + FVal) ...but it is not universally profitable. This is the inverse of IR canonicalization as discussed in D113442. We know that this transform is not entirely profitable even within x86, so we only handle x86 vector fadd/fsub as a 1st step. The intent is to prevent AVX512 regressions as mentioned in D113442. The plan is to port this to DAGCombiner (so it will eventually look more like D90113) and add more types/cases in pieces with many more tests to verify that we are seeing improvements. Differential Revision: https://reviews.llvm.org/D118644
1 parent ccf02cd commit 6592bce

File tree

3 files changed

+102
-26
lines changed

3 files changed

+102
-26
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48942,6 +48942,83 @@ static SDValue combineFaddCFmul(SDNode *N, SelectionDAG &DAG,
4894248942
return DAG.getBitcast(VT, CFmul);
4894348943
}
4894448944

48945+
/// This inverts a canonicalization in IR that replaces a variable select arm
48946+
/// with an identity constant. Codegen improves if we re-use the variable
48947+
/// operand rather than load a constant. This can also be converted into a
48948+
/// masked vector operation if the target supports it.
48949+
static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
48950+
bool ShouldCommuteOperands) {
48951+
// Match a select as operand 1. The identity constant that we are looking for
48952+
// is only valid as operand 1 of a non-commutative binop.
48953+
SDValue N0 = N->getOperand(0);
48954+
SDValue N1 = N->getOperand(1);
48955+
if (ShouldCommuteOperands)
48956+
std::swap(N0, N1);
48957+
48958+
// TODO: Should this apply to scalar select too?
48959+
if (!N1.hasOneUse() || N1.getOpcode() != ISD::VSELECT)
48960+
return SDValue();
48961+
48962+
unsigned Opcode = N->getOpcode();
48963+
EVT VT = N->getValueType(0);
48964+
SDValue Cond = N1.getOperand(0);
48965+
SDValue TVal = N1.getOperand(1);
48966+
SDValue FVal = N1.getOperand(2);
48967+
48968+
// TODO: This (and possibly the entire function) belongs in a
48969+
// target-independent location with target hooks.
48970+
// TODO: The cases should match with IR's ConstantExpr::getBinOpIdentity().
48971+
// TODO: With fast-math (NSZ), allow the opposite-sign form of zero?
48972+
auto isIdentityConstantForOpcode = [](unsigned Opcode, SDValue V) {
48973+
if (ConstantFPSDNode *C = isConstOrConstSplatFP(V)) {
48974+
switch (Opcode) {
48975+
case ISD::FADD: // X + -0.0 --> X
48976+
return C->isZero() && C->isNegative();
48977+
case ISD::FSUB: // X - 0.0 --> X
48978+
return C->isZero() && !C->isNegative();
48979+
}
48980+
}
48981+
return false;
48982+
};
48983+
48984+
// This transform increases uses of N0, so freeze it to be safe.
48985+
// binop N0, (vselect Cond, IDC, FVal) --> vselect Cond, N0, (binop N0, FVal)
48986+
if (isIdentityConstantForOpcode(Opcode, TVal)) {
48987+
SDValue F0 = DAG.getFreeze(N0);
48988+
SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, FVal, N->getFlags());
48989+
return DAG.getSelect(SDLoc(N), VT, Cond, F0, NewBO);
48990+
}
48991+
// binop N0, (vselect Cond, TVal, IDC) --> vselect Cond, (binop N0, TVal), N0
48992+
if (isIdentityConstantForOpcode(Opcode, FVal)) {
48993+
SDValue F0 = DAG.getFreeze(N0);
48994+
SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, TVal, N->getFlags());
48995+
return DAG.getSelect(SDLoc(N), VT, Cond, NewBO, F0);
48996+
}
48997+
48998+
return SDValue();
48999+
}
49000+
49001+
/// Try to fold a binop that has a vector select with an identity-constant arm
/// as one of its operands (see foldSelectWithIdentityConstant). The fold is
/// attempted only on AVX512 subtargets and, when VLX is unavailable, only for
/// 512-bit vector result types.
/// \returns the replacement value, or an empty SDValue if no fold applied.
static SDValue combineBinopWithSelect(SDNode *N, SelectionDAG &DAG,
                                      const X86Subtarget &Subtarget) {
  // TODO: This is too general. There are cases where pre-AVX512 codegen would
  //       benefit. The transform may also be profitable for scalar code.
  if (!Subtarget.hasAVX512() ||
      !(Subtarget.hasVLX() || N->getValueType(0).is512BitVector()))
    return SDValue();

  // First look for the select as operand 1; for commutative binops also retry
  // with the operands exchanged.
  SDValue Folded =
      foldSelectWithIdentityConstant(N, DAG, /*ShouldCommuteOperands=*/false);
  if (!Folded &&
      DAG.getTargetLoweringInfo().isCommutativeBinOp(N->getOpcode()))
    Folded =
        foldSelectWithIdentityConstant(N, DAG, /*ShouldCommuteOperands=*/true);

  return Folded;
}
49021+
4894549022
/// Do target-specific dag combines on floating-point adds/subs.
4894649023
static SDValue combineFaddFsub(SDNode *N, SelectionDAG &DAG,
4894749024
const X86Subtarget &Subtarget) {
@@ -48951,6 +49028,9 @@ static SDValue combineFaddFsub(SDNode *N, SelectionDAG &DAG,
4895149028
if (SDValue COp = combineFaddCFmul(N, DAG, Subtarget))
4895249029
return COp;
4895349030

49031+
if (SDValue Sel = combineBinopWithSelect(N, DAG, Subtarget))
49032+
return Sel;
49033+
4895449034
return SDValue();
4895549035
}
4895649036

llvm/test/CodeGen/X86/avx512fp16-arith-intrinsics.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ define <32 x half> @test_int_x86_avx512fp16_maskz_sub_ph_512(<32 x half> %src, <
8383
; CHECK: # %bb.0:
8484
; CHECK-NEXT: kmovd %edi, %k1
8585
; CHECK-NEXT: vsubph %zmm2, %zmm1, %zmm0 {%k1} {z}
86-
; CHECK-NEXT: vsubph (%rsi), %zmm1, %zmm1 {%k1} {z}
87-
; CHECK-NEXT: vsubph %zmm1, %zmm0, %zmm0
86+
; CHECK-NEXT: vsubph (%rsi), %zmm1, %zmm1
87+
; CHECK-NEXT: vsubph %zmm1, %zmm0, %zmm0 {%k1}
8888
; CHECK-NEXT: retq
8989
%mask = bitcast i32 %msk to <32 x i1>
9090
%val = load <32 x half>, <32 x half>* %ptr

llvm/test/CodeGen/X86/vector-bo-select.ll

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,8 @@ define <4 x float> @fadd_v4f32(<4 x i1> %b, <4 x float> noundef %x, <4 x float>
2727
; AVX512VL: # %bb.0:
2828
; AVX512VL-NEXT: vpslld $31, %xmm0, %xmm0
2929
; AVX512VL-NEXT: vptestmd %xmm0, %xmm0, %k1
30-
; AVX512VL-NEXT: vbroadcastss {{.*#+}} xmm0 = [-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0]
31-
; AVX512VL-NEXT: vmovaps %xmm2, %xmm0 {%k1}
32-
; AVX512VL-NEXT: vaddps %xmm0, %xmm1, %xmm0
30+
; AVX512VL-NEXT: vaddps %xmm2, %xmm1, %xmm1 {%k1}
31+
; AVX512VL-NEXT: vmovaps %xmm1, %xmm0
3332
; AVX512VL-NEXT: retq
3433
%s = select <4 x i1> %b, <4 x float> %y, <4 x float> <float -0.0, float -0.0, float -0.0, float -0.0>
3534
%r = fadd <4 x float> %x, %s
@@ -62,9 +61,8 @@ define <8 x float> @fadd_v8f32_commute(<8 x i1> %b, <8 x float> noundef %x, <8 x
6261
; AVX512VL-NEXT: vpmovsxwd %xmm0, %ymm0
6362
; AVX512VL-NEXT: vpslld $31, %ymm0, %ymm0
6463
; AVX512VL-NEXT: vptestmd %ymm0, %ymm0, %k1
65-
; AVX512VL-NEXT: vbroadcastss {{.*#+}} ymm0 = [-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0]
66-
; AVX512VL-NEXT: vmovaps %ymm2, %ymm0 {%k1}
67-
; AVX512VL-NEXT: vaddps %ymm1, %ymm0, %ymm0
64+
; AVX512VL-NEXT: vaddps %ymm2, %ymm1, %ymm1 {%k1}
65+
; AVX512VL-NEXT: vmovaps %ymm1, %ymm0
6866
; AVX512VL-NEXT: retq
6967
%s = select <8 x i1> %b, <8 x float> %y, <8 x float> <float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0>
7068
%r = fadd <8 x float> %s, %x
@@ -92,8 +90,8 @@ define <16 x float> @fadd_v16f32_swap(<16 x i1> %b, <16 x float> noundef %x, <16
9290
; AVX512-NEXT: vpmovsxbd %xmm0, %zmm0
9391
; AVX512-NEXT: vpslld $31, %zmm0, %zmm0
9492
; AVX512-NEXT: vptestmd %zmm0, %zmm0, %k1
95-
; AVX512-NEXT: vbroadcastss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2 {%k1}
9693
; AVX512-NEXT: vaddps %zmm2, %zmm1, %zmm0
94+
; AVX512-NEXT: vmovaps %zmm1, %zmm0 {%k1}
9795
; AVX512-NEXT: retq
9896
%s = select <16 x i1> %b, <16 x float> <float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0>, <16 x float> %y
9997
%r = fadd <16 x float> %x, %s
@@ -121,8 +119,8 @@ define <16 x float> @fadd_v16f32_commute_swap(<16 x i1> %b, <16 x float> noundef
121119
; AVX512-NEXT: vpmovsxbd %xmm0, %zmm0
122120
; AVX512-NEXT: vpslld $31, %zmm0, %zmm0
123121
; AVX512-NEXT: vptestmd %zmm0, %zmm0, %k1
124-
; AVX512-NEXT: vbroadcastss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2 {%k1}
125-
; AVX512-NEXT: vaddps %zmm1, %zmm2, %zmm0
122+
; AVX512-NEXT: vaddps %zmm2, %zmm1, %zmm0
123+
; AVX512-NEXT: vmovaps %zmm1, %zmm0 {%k1}
126124
; AVX512-NEXT: retq
127125
%s = select <16 x i1> %b, <16 x float> <float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0>, <16 x float> %y
128126
%r = fadd <16 x float> %s, %x
@@ -152,14 +150,16 @@ define <4 x float> @fsub_v4f32(<4 x i1> %b, <4 x float> noundef %x, <4 x float>
152150
; AVX512VL: # %bb.0:
153151
; AVX512VL-NEXT: vpslld $31, %xmm0, %xmm0
154152
; AVX512VL-NEXT: vptestmd %xmm0, %xmm0, %k1
155-
; AVX512VL-NEXT: vmovaps %xmm2, %xmm0 {%k1} {z}
156-
; AVX512VL-NEXT: vsubps %xmm0, %xmm1, %xmm0
153+
; AVX512VL-NEXT: vsubps %xmm2, %xmm1, %xmm1 {%k1}
154+
; AVX512VL-NEXT: vmovaps %xmm1, %xmm0
157155
; AVX512VL-NEXT: retq
158156
%s = select <4 x i1> %b, <4 x float> %y, <4 x float> zeroinitializer
159157
%r = fsub <4 x float> %x, %s
160158
ret <4 x float> %r
161159
}
162160

161+
; negative test - fsub is not commutative; there is no identity constant for operand 0
162+
163163
define <8 x float> @fsub_v8f32_commute(<8 x i1> %b, <8 x float> noundef %x, <8 x float> noundef %y) {
164164
; AVX2-LABEL: fsub_v8f32_commute:
165165
; AVX2: # %bb.0:
@@ -214,15 +214,17 @@ define <16 x float> @fsub_v16f32_swap(<16 x i1> %b, <16 x float> noundef %x, <16
214214
; AVX512: # %bb.0:
215215
; AVX512-NEXT: vpmovsxbd %xmm0, %zmm0
216216
; AVX512-NEXT: vpslld $31, %zmm0, %zmm0
217-
; AVX512-NEXT: vptestnmd %zmm0, %zmm0, %k1
218-
; AVX512-NEXT: vmovaps %zmm2, %zmm0 {%k1} {z}
219-
; AVX512-NEXT: vsubps %zmm0, %zmm1, %zmm0
217+
; AVX512-NEXT: vptestmd %zmm0, %zmm0, %k1
218+
; AVX512-NEXT: vsubps %zmm2, %zmm1, %zmm0
219+
; AVX512-NEXT: vmovaps %zmm1, %zmm0 {%k1}
220220
; AVX512-NEXT: retq
221221
%s = select <16 x i1> %b, <16 x float> zeroinitializer, <16 x float> %y
222222
%r = fsub <16 x float> %x, %s
223223
ret <16 x float> %r
224224
}
225225

226+
; negative test - fsub is not commutative; there is no identity constant for operand 0
227+
226228
define <16 x float> @fsub_v16f32_commute_swap(<16 x i1> %b, <16 x float> noundef %x, <16 x float> noundef %y) {
227229
; AVX2-LABEL: fsub_v16f32_commute_swap:
228230
; AVX2: # %bb.0:
@@ -570,9 +572,7 @@ define <8 x float> @fadd_v8f32_cast_cond(i8 noundef zeroext %pb, <8 x float> nou
570572
; AVX512VL-LABEL: fadd_v8f32_cast_cond:
571573
; AVX512VL: # %bb.0:
572574
; AVX512VL-NEXT: kmovw %edi, %k1
573-
; AVX512VL-NEXT: vbroadcastss {{.*#+}} ymm2 = [-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0]
574-
; AVX512VL-NEXT: vmovaps %ymm1, %ymm2 {%k1}
575-
; AVX512VL-NEXT: vaddps %ymm2, %ymm0, %ymm0
575+
; AVX512VL-NEXT: vaddps %ymm1, %ymm0, %ymm0 {%k1}
576576
; AVX512VL-NEXT: retq
577577
%b = bitcast i8 %pb to <8 x i1>
578578
%s = select <8 x i1> %b, <8 x float> %y, <8 x float> <float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0>
@@ -636,9 +636,7 @@ define <8 x double> @fadd_v8f64_cast_cond(i8 noundef zeroext %pb, <8 x double> n
636636
; AVX512-LABEL: fadd_v8f64_cast_cond:
637637
; AVX512: # %bb.0:
638638
; AVX512-NEXT: kmovw %edi, %k1
639-
; AVX512-NEXT: vbroadcastsd {{.*#+}} zmm2 = [-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0]
640-
; AVX512-NEXT: vmovapd %zmm1, %zmm2 {%k1}
641-
; AVX512-NEXT: vaddpd %zmm2, %zmm0, %zmm0
639+
; AVX512-NEXT: vaddpd %zmm1, %zmm0, %zmm0 {%k1}
642640
; AVX512-NEXT: retq
643641
%b = bitcast i8 %pb to <8 x i1>
644642
%s = select <8 x i1> %b, <8 x double> %y, <8 x double> <double -0.0, double -0.0, double -0.0, double -0.0, double -0.0, double -0.0, double -0.0, double -0.0>
@@ -709,8 +707,7 @@ define <8 x float> @fsub_v8f32_cast_cond(i8 noundef zeroext %pb, <8 x float> nou
709707
; AVX512VL-LABEL: fsub_v8f32_cast_cond:
710708
; AVX512VL: # %bb.0:
711709
; AVX512VL-NEXT: kmovw %edi, %k1
712-
; AVX512VL-NEXT: vmovaps %ymm1, %ymm1 {%k1} {z}
713-
; AVX512VL-NEXT: vsubps %ymm1, %ymm0, %ymm0
710+
; AVX512VL-NEXT: vsubps %ymm1, %ymm0, %ymm0 {%k1}
714711
; AVX512VL-NEXT: retq
715712
%b = bitcast i8 %pb to <8 x i1>
716713
%s = select <8 x i1> %b, <8 x float> %y, <8 x float> zeroinitializer
@@ -775,8 +772,7 @@ define <8 x double> @fsub_v8f64_cast_cond(i8 noundef zeroext %pb, <8 x double> n
775772
; AVX512-LABEL: fsub_v8f64_cast_cond:
776773
; AVX512: # %bb.0:
777774
; AVX512-NEXT: kmovw %edi, %k1
778-
; AVX512-NEXT: vmovapd %zmm1, %zmm1 {%k1} {z}
779-
; AVX512-NEXT: vsubpd %zmm1, %zmm0, %zmm0
775+
; AVX512-NEXT: vsubpd %zmm1, %zmm0, %zmm0 {%k1}
780776
; AVX512-NEXT: retq
781777
%b = bitcast i8 %pb to <8 x i1>
782778
%s = select <8 x i1> %b, <8 x double> %y, <8 x double> zeroinitializer

0 commit comments

Comments
 (0)