Skip to content

Commit 8630a7b

Browse files
Reapply "[DAGCombiner] Add support for scalarising extracts of a vector setcc (#117566)" (#118823)
[Reverts d57892a]

For IR like this:

  %icmp = icmp ult <4 x i32> %a, splat (i32 5)
  %res = extractelement <4 x i1> %icmp, i32 1

where there is only one use of %icmp, we can take a similar approach to what we already do for binary ops such as add, sub, etc. and convert this into:

  %ext = extractelement <4 x i32> %a, i32 1
  %res = icmp ult i32 %ext, 5

For AArch64 targets at least, the scalar boolean result will almost certainly need to be in a GPR anyway, since it will probably be used by branches for control flow. I've tried to reuse the existing code in scalarizeExtractedBinOp so that it also works for setcc.

NOTE: The optimisations don't apply for tests such as extract_icmp_v4i32_splat_rhs in the file CodeGen/AArch64/extract-vector-cmp.ll because scalarizeExtractedBinOp only works if one of the input operands is a constant.

---------

Co-authored-by: Paul Walker <[email protected]>
1 parent 6a52a51 commit 8630a7b

File tree

10 files changed

+304
-41
lines changed

10 files changed

+304
-41
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22755,16 +22755,22 @@ SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
2275522755

2275622756
/// Transform a vector binary operation into a scalar binary operation by moving
2275722757
/// the math/logic after an extract element of a vector.
22758-
static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG,
22759-
const SDLoc &DL, bool LegalOperations) {
22758+
static SDValue scalarizeExtractedBinOp(SDNode *ExtElt, SelectionDAG &DAG,
22759+
const SDLoc &DL, bool LegalTypes) {
2276022760
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
2276122761
SDValue Vec = ExtElt->getOperand(0);
2276222762
SDValue Index = ExtElt->getOperand(1);
2276322763
auto *IndexC = dyn_cast<ConstantSDNode>(Index);
22764-
if (!IndexC || !TLI.isBinOp(Vec.getOpcode()) || !Vec.hasOneUse() ||
22764+
unsigned Opc = Vec.getOpcode();
22765+
if (!IndexC || !Vec.hasOneUse() || (!TLI.isBinOp(Opc) && Opc != ISD::SETCC) ||
2276522766
Vec->getNumValues() != 1)
2276622767
return SDValue();
2276722768

22769+
EVT ResVT = ExtElt->getValueType(0);
22770+
if (Opc == ISD::SETCC &&
22771+
(ResVT != Vec.getValueType().getVectorElementType() || LegalTypes))
22772+
return SDValue();
22773+
2276822774
// Targets may want to avoid this to prevent an expensive register transfer.
2276922775
if (!TLI.shouldScalarizeBinop(Vec))
2277022776
return SDValue();
@@ -22775,19 +22781,24 @@ static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG,
2277522781
SDValue Op0 = Vec.getOperand(0);
2277622782
SDValue Op1 = Vec.getOperand(1);
2277722783
APInt SplatVal;
22778-
if (isAnyConstantBuildVector(Op0, true) ||
22779-
ISD::isConstantSplatVector(Op0.getNode(), SplatVal) ||
22780-
isAnyConstantBuildVector(Op1, true) ||
22781-
ISD::isConstantSplatVector(Op1.getNode(), SplatVal)) {
22782-
// extractelt (binop X, C), IndexC --> binop (extractelt X, IndexC), C'
22783-
// extractelt (binop C, X), IndexC --> binop C', (extractelt X, IndexC)
22784-
EVT VT = ExtElt->getValueType(0);
22785-
SDValue Ext0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op0, Index);
22786-
SDValue Ext1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op1, Index);
22787-
return DAG.getNode(Vec.getOpcode(), DL, VT, Ext0, Ext1);
22788-
}
22784+
if (!isAnyConstantBuildVector(Op0, true) &&
22785+
!ISD::isConstantSplatVector(Op0.getNode(), SplatVal) &&
22786+
!isAnyConstantBuildVector(Op1, true) &&
22787+
!ISD::isConstantSplatVector(Op1.getNode(), SplatVal))
22788+
return SDValue();
2278922789

22790-
return SDValue();
22790+
// extractelt (op X, C), IndexC --> op (extractelt X, IndexC), C'
22791+
// extractelt (op C, X), IndexC --> op C', (extractelt X, IndexC)
22792+
if (Opc == ISD::SETCC) {
22793+
EVT OpVT = Op0.getValueType().getVectorElementType();
22794+
Op0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Op0, Index);
22795+
Op1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Op1, Index);
22796+
return DAG.getSetCC(DL, ResVT, Op0, Op1,
22797+
cast<CondCodeSDNode>(Vec->getOperand(2))->get());
22798+
}
22799+
Op0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Op0, Index);
22800+
Op1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Op1, Index);
22801+
return DAG.getNode(Opc, DL, ResVT, Op0, Op1);
2279122802
}
2279222803

2279322804
// Given a ISD::EXTRACT_VECTOR_ELT, which is a glorified bit sequence extract,
@@ -23020,7 +23031,7 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
2302023031
}
2302123032
}
2302223033

23023-
if (SDValue BO = scalarizeExtractedBinop(N, DAG, DL, LegalOperations))
23034+
if (SDValue BO = scalarizeExtractedBinOp(N, DAG, DL, LegalTypes))
2302423035
return BO;
2302523036

2302623037
if (VecVT.isScalableVector())

llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2835,6 +2835,7 @@ void DAGTypeLegalizer::ExpandIntegerResult(SDNode *N, unsigned ResNo) {
28352835
case ISD::SELECT_CC: SplitRes_SELECT_CC(N, Lo, Hi); break;
28362836
case ISD::UNDEF: SplitRes_UNDEF(N, Lo, Hi); break;
28372837
case ISD::FREEZE: SplitRes_FREEZE(N, Lo, Hi); break;
2838+
case ISD::SETCC: ExpandIntRes_SETCC(N, Lo, Hi); break;
28382839

28392840
case ISD::BITCAST: ExpandRes_BITCAST(N, Lo, Hi); break;
28402841
case ISD::BUILD_PAIR: ExpandRes_BUILD_PAIR(N, Lo, Hi); break;
@@ -3316,6 +3317,20 @@ static std::pair<ISD::CondCode, ISD::NodeType> getExpandedMinMaxOps(int Op) {
33163317
}
33173318
}
33183319

3320+
void DAGTypeLegalizer::ExpandIntRes_SETCC(SDNode *N, SDValue &Lo, SDValue &Hi) {
3321+
SDLoc DL(N);
3322+
3323+
SDValue LHS = N->getOperand(0);
3324+
SDValue RHS = N->getOperand(1);
3325+
EVT NewVT = getSetCCResultType(LHS.getValueType());
3326+
3327+
// Taking the same approach as ScalarizeVecRes_SETCC
3328+
SDValue Res = DAG.getNode(ISD::SETCC, DL, NewVT, LHS, RHS, N->getOperand(2));
3329+
3330+
Res = DAG.getBoolExtOrTrunc(Res, DL, N->getValueType(0), NewVT);
3331+
SplitInteger(Res, Lo, Hi);
3332+
}
3333+
33193334
void DAGTypeLegalizer::ExpandIntRes_MINMAX(SDNode *N,
33203335
SDValue &Lo, SDValue &Hi) {
33213336
SDLoc DL(N);

llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
487487
void ExpandIntRes_MINMAX (SDNode *N, SDValue &Lo, SDValue &Hi);
488488

489489
void ExpandIntRes_CMP (SDNode *N, SDValue &Lo, SDValue &Hi);
490+
void ExpandIntRes_SETCC (SDNode *N, SDValue &Lo, SDValue &Hi);
490491

491492
void ExpandIntRes_SADDSUBO (SDNode *N, SDValue &Lo, SDValue &Hi);
492493
void ExpandIntRes_UADDSUBO (SDNode *N, SDValue &Lo, SDValue &Hi);

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1348,6 +1348,10 @@ class AArch64TargetLowering : public TargetLowering {
13481348
unsigned getMinimumJumpTableEntries() const override;
13491349

13501350
bool softPromoteHalfType() const override { return true; }
1351+
1352+
bool shouldScalarizeBinop(SDValue VecOp) const override {
1353+
return VecOp.getOpcode() == ISD::SETCC;
1354+
}
13511355
};
13521356

13531357
namespace AArch64 {

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2107,7 +2107,7 @@ bool RISCVTargetLowering::shouldScalarizeBinop(SDValue VecOp) const {
21072107

21082108
// Assume target opcodes can't be scalarized.
21092109
// TODO - do we have any exceptions?
2110-
if (Opc >= ISD::BUILTIN_OP_END)
2110+
if (Opc >= ISD::BUILTIN_OP_END || !isBinOp(Opc))
21112111
return false;
21122112

21132113
// If the vector op is not supported, try to convert to scalar.

llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ bool WebAssemblyTargetLowering::shouldScalarizeBinop(SDValue VecOp) const {
429429

430430
// Assume target opcodes can't be scalarized.
431431
// TODO - do we have any exceptions?
432-
if (Opc >= ISD::BUILTIN_OP_END)
432+
if (Opc >= ISD::BUILTIN_OP_END || !isBinOp(Opc))
433433
return false;
434434

435435
// If the vector op is not supported, try to convert to scalar.

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3306,7 +3306,7 @@ bool X86TargetLowering::shouldScalarizeBinop(SDValue VecOp) const {
33063306

33073307
// Assume target opcodes can't be scalarized.
33083308
// TODO - do we have any exceptions?
3309-
if (Opc >= ISD::BUILTIN_OP_END)
3309+
if (Opc >= ISD::BUILTIN_OP_END || !isBinOp(Opc))
33103310
return false;
33113311

33123312
// If the vector op is not supported, try to convert to scalar.

llvm/test/CodeGen/AArch64/dag-combine-concat-vectors.ll

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
declare void @llvm.masked.scatter.nxv16i8.nxv16p0(<vscale x 16 x i8>, <vscale x 16 x ptr>, i32 immarg, <vscale x 16 x i1>)
77

8-
define fastcc i8 @allocno_reload_assign() {
8+
define fastcc i8 @allocno_reload_assign(ptr %p) {
99
; CHECK-LABEL: allocno_reload_assign:
1010
; CHECK: // %bb.0:
1111
; CHECK-NEXT: fmov d0, xzr
@@ -14,8 +14,8 @@ define fastcc i8 @allocno_reload_assign() {
1414
; CHECK-NEXT: cmpeq p0.d, p0/z, z0.d, #0
1515
; CHECK-NEXT: uzp1 p0.s, p0.s, p0.s
1616
; CHECK-NEXT: uzp1 p0.h, p0.h, p0.h
17-
; CHECK-NEXT: uzp1 p0.b, p0.b, p0.b
18-
; CHECK-NEXT: mov z0.b, p0/z, #1 // =0x1
17+
; CHECK-NEXT: uzp1 p8.b, p0.b, p0.b
18+
; CHECK-NEXT: mov z0.b, p8/z, #1 // =0x1
1919
; CHECK-NEXT: fmov w8, s0
2020
; CHECK-NEXT: mov z0.b, #0 // =0x0
2121
; CHECK-NEXT: uunpklo z1.h, z0.b
@@ -30,34 +30,35 @@ define fastcc i8 @allocno_reload_assign() {
3030
; CHECK-NEXT: punpklo p1.h, p0.b
3131
; CHECK-NEXT: punpkhi p0.h, p0.b
3232
; CHECK-NEXT: punpklo p2.h, p1.b
33-
; CHECK-NEXT: punpkhi p3.h, p1.b
33+
; CHECK-NEXT: punpkhi p4.h, p1.b
3434
; CHECK-NEXT: uunpklo z0.d, z2.s
3535
; CHECK-NEXT: uunpkhi z1.d, z2.s
36-
; CHECK-NEXT: punpklo p5.h, p0.b
36+
; CHECK-NEXT: punpklo p6.h, p0.b
3737
; CHECK-NEXT: uunpklo z2.d, z3.s
3838
; CHECK-NEXT: uunpkhi z3.d, z3.s
39-
; CHECK-NEXT: punpkhi p7.h, p0.b
39+
; CHECK-NEXT: punpkhi p0.h, p0.b
4040
; CHECK-NEXT: uunpklo z4.d, z5.s
4141
; CHECK-NEXT: uunpkhi z5.d, z5.s
4242
; CHECK-NEXT: uunpklo z6.d, z7.s
4343
; CHECK-NEXT: uunpkhi z7.d, z7.s
44-
; CHECK-NEXT: punpklo p0.h, p2.b
45-
; CHECK-NEXT: punpkhi p1.h, p2.b
46-
; CHECK-NEXT: punpklo p2.h, p3.b
47-
; CHECK-NEXT: punpkhi p3.h, p3.b
48-
; CHECK-NEXT: punpklo p4.h, p5.b
49-
; CHECK-NEXT: punpkhi p5.h, p5.b
50-
; CHECK-NEXT: punpklo p6.h, p7.b
51-
; CHECK-NEXT: punpkhi p7.h, p7.b
44+
; CHECK-NEXT: punpklo p1.h, p2.b
45+
; CHECK-NEXT: punpkhi p2.h, p2.b
46+
; CHECK-NEXT: punpklo p3.h, p4.b
47+
; CHECK-NEXT: punpkhi p4.h, p4.b
48+
; CHECK-NEXT: punpklo p5.h, p6.b
49+
; CHECK-NEXT: punpkhi p6.h, p6.b
50+
; CHECK-NEXT: punpklo p7.h, p0.b
51+
; CHECK-NEXT: punpkhi p0.h, p0.b
5252
; CHECK-NEXT: .LBB0_1: // =>This Inner Loop Header: Depth=1
53-
; CHECK-NEXT: st1b { z0.d }, p0, [z16.d]
54-
; CHECK-NEXT: st1b { z1.d }, p1, [z16.d]
55-
; CHECK-NEXT: st1b { z2.d }, p2, [z16.d]
56-
; CHECK-NEXT: st1b { z3.d }, p3, [z16.d]
57-
; CHECK-NEXT: st1b { z4.d }, p4, [z16.d]
58-
; CHECK-NEXT: st1b { z5.d }, p5, [z16.d]
59-
; CHECK-NEXT: st1b { z6.d }, p6, [z16.d]
60-
; CHECK-NEXT: st1b { z7.d }, p7, [z16.d]
53+
; CHECK-NEXT: st1b { z0.d }, p1, [z16.d]
54+
; CHECK-NEXT: st1b { z1.d }, p2, [z16.d]
55+
; CHECK-NEXT: st1b { z2.d }, p3, [z16.d]
56+
; CHECK-NEXT: st1b { z3.d }, p4, [z16.d]
57+
; CHECK-NEXT: st1b { z4.d }, p5, [z16.d]
58+
; CHECK-NEXT: st1b { z5.d }, p6, [z16.d]
59+
; CHECK-NEXT: st1b { z6.d }, p7, [z16.d]
60+
; CHECK-NEXT: st1b { z7.d }, p0, [z16.d]
61+
; CHECK-NEXT: str p8, [x0]
6162
; CHECK-NEXT: b .LBB0_1
6263
br label %1
6364

@@ -66,6 +67,7 @@ define fastcc i8 @allocno_reload_assign() {
6667
%constexpr1 = shufflevector <vscale x 16 x i1> %constexpr, <vscale x 16 x i1> poison, <vscale x 16 x i32> zeroinitializer
6768
%constexpr2 = xor <vscale x 16 x i1> %constexpr1, shufflevector (<vscale x 16 x i1> insertelement (<vscale x 16 x i1> poison, i1 true, i64 0), <vscale x 16 x i1> poison, <vscale x 16 x i32> zeroinitializer)
6869
call void @llvm.masked.scatter.nxv16i8.nxv16p0(<vscale x 16 x i8> zeroinitializer, <vscale x 16 x ptr> zeroinitializer, i32 0, <vscale x 16 x i1> %constexpr2)
70+
store <vscale x 16 x i1> %constexpr, ptr %p, align 16
6971
br label %1
7072
}
7173

0 commit comments

Comments
 (0)