Skip to content

Reapply "[DAGCombiner] Add support for scalarising extracts of a vector setcc (#117566)" #118823

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 27 additions & 16 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22751,16 +22751,22 @@ SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,

/// Transform a vector binary operation into a scalar binary operation by moving
/// the math/logic after an extract element of a vector.
static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG,
const SDLoc &DL, bool LegalOperations) {
static SDValue scalarizeExtractedBinOp(SDNode *ExtElt, SelectionDAG &DAG,
const SDLoc &DL, bool LegalTypes) {
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
SDValue Vec = ExtElt->getOperand(0);
SDValue Index = ExtElt->getOperand(1);
auto *IndexC = dyn_cast<ConstantSDNode>(Index);
if (!IndexC || !TLI.isBinOp(Vec.getOpcode()) || !Vec.hasOneUse() ||
unsigned Opc = Vec.getOpcode();
if (!IndexC || !Vec.hasOneUse() || (!TLI.isBinOp(Opc) && Opc != ISD::SETCC) ||
Vec->getNumValues() != 1)
return SDValue();

EVT ResVT = ExtElt->getValueType(0);
if (Opc == ISD::SETCC &&
(ResVT != Vec.getValueType().getVectorElementType() || LegalTypes))
return SDValue();

// Targets may want to avoid this to prevent an expensive register transfer.
if (!TLI.shouldScalarizeBinop(Vec))
return SDValue();
Expand All @@ -22771,19 +22777,24 @@ static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG,
SDValue Op0 = Vec.getOperand(0);
SDValue Op1 = Vec.getOperand(1);
APInt SplatVal;
if (isAnyConstantBuildVector(Op0, true) ||
ISD::isConstantSplatVector(Op0.getNode(), SplatVal) ||
isAnyConstantBuildVector(Op1, true) ||
ISD::isConstantSplatVector(Op1.getNode(), SplatVal)) {
// extractelt (binop X, C), IndexC --> binop (extractelt X, IndexC), C'
// extractelt (binop C, X), IndexC --> binop C', (extractelt X, IndexC)
EVT VT = ExtElt->getValueType(0);
SDValue Ext0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op0, Index);
SDValue Ext1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op1, Index);
return DAG.getNode(Vec.getOpcode(), DL, VT, Ext0, Ext1);
}
if (!isAnyConstantBuildVector(Op0, true) &&
!ISD::isConstantSplatVector(Op0.getNode(), SplatVal) &&
!isAnyConstantBuildVector(Op1, true) &&
!ISD::isConstantSplatVector(Op1.getNode(), SplatVal))
return SDValue();

return SDValue();
// extractelt (op X, C), IndexC --> op (extractelt X, IndexC), C'
// extractelt (op C, X), IndexC --> op C', (extractelt X, IndexC)
if (Opc == ISD::SETCC) {
EVT OpVT = Op0.getValueType().getVectorElementType();
Op0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Op0, Index);
Op1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Op1, Index);
return DAG.getSetCC(DL, ResVT, Op0, Op1,
cast<CondCodeSDNode>(Vec->getOperand(2))->get());
}
Op0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Op0, Index);
Op1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Op1, Index);
return DAG.getNode(Opc, DL, ResVT, Op0, Op1);
}

// Given an ISD::EXTRACT_VECTOR_ELT, which is a glorified bit sequence extract,
Expand Down Expand Up @@ -23016,7 +23027,7 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
}
}

if (SDValue BO = scalarizeExtractedBinop(N, DAG, DL, LegalOperations))
if (SDValue BO = scalarizeExtractedBinOp(N, DAG, DL, LegalTypes))
return BO;

if (VecVT.isScalableVector())
Expand Down
15 changes: 15 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2835,6 +2835,7 @@ void DAGTypeLegalizer::ExpandIntegerResult(SDNode *N, unsigned ResNo) {
case ISD::SELECT_CC: SplitRes_SELECT_CC(N, Lo, Hi); break;
case ISD::UNDEF: SplitRes_UNDEF(N, Lo, Hi); break;
case ISD::FREEZE: SplitRes_FREEZE(N, Lo, Hi); break;
case ISD::SETCC: ExpandIntRes_SETCC(N, Lo, Hi); break;

case ISD::BITCAST: ExpandRes_BITCAST(N, Lo, Hi); break;
case ISD::BUILD_PAIR: ExpandRes_BUILD_PAIR(N, Lo, Hi); break;
Expand Down Expand Up @@ -3316,6 +3317,20 @@ static std::pair<ISD::CondCode, ISD::NodeType> getExpandedMinMaxOps(int Op) {
}
}

/// Expand the integer result of a SETCC node \p N into two halves.
///
/// The comparison is re-emitted with the type that getSetCCResultType() gives
/// for the left-hand operand's type, that value is converted to N's own
/// result type with getBoolExtOrTrunc(), and the converted value is split.
///
/// \param N  The SETCC node; operands 0 and 1 are the compared values and
///           operand 2 is forwarded unchanged as the third SETCC operand.
/// \param Lo Receives the first output of SplitInteger().
/// \param Hi Receives the second output of SplitInteger().
void DAGTypeLegalizer::ExpandIntRes_SETCC(SDNode *N, SDValue &Lo, SDValue &Hi) {
  const SDLoc Loc(N);

  const SDValue CmpLHS = N->getOperand(0);
  const SDValue CmpRHS = N->getOperand(1);
  const SDValue CmpCond = N->getOperand(2);

  // Taking the same approach as ScalarizeVecRes_SETCC: build the compare with
  // the SETCC result type first, then adjust it to the type that was asked for.
  const EVT CmpResVT = getSetCCResultType(CmpLHS.getValueType());
  const SDValue Cmp =
      DAG.getNode(ISD::SETCC, Loc, CmpResVT, CmpLHS, CmpRHS, CmpCond);

  // NOTE(review): the final argument here is the SETCC *result* type, not the
  // type of the compared operands — confirm this is the type that should drive
  // the choice of boolean extension.
  const SDValue Converted =
      DAG.getBoolExtOrTrunc(Cmp, Loc, N->getValueType(0), CmpResVT);

  SplitInteger(Converted, Lo, Hi);
}

void DAGTypeLegalizer::ExpandIntRes_MINMAX(SDNode *N,
SDValue &Lo, SDValue &Hi) {
SDLoc DL(N);
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
void ExpandIntRes_MINMAX (SDNode *N, SDValue &Lo, SDValue &Hi);

void ExpandIntRes_CMP (SDNode *N, SDValue &Lo, SDValue &Hi);
void ExpandIntRes_SETCC (SDNode *N, SDValue &Lo, SDValue &Hi);

void ExpandIntRes_SADDSUBO (SDNode *N, SDValue &Lo, SDValue &Hi);
void ExpandIntRes_UADDSUBO (SDNode *N, SDValue &Lo, SDValue &Hi);
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -1348,6 +1348,10 @@ class AArch64TargetLowering : public TargetLowering {
unsigned getMinimumJumpTableEntries() const override;

bool softPromoteHalfType() const override { return true; }

/// Report whether \p VecOp should be scalarized.
/// Returns true only when VecOp's opcode is ISD::SETCC; every other opcode
/// yields false.
bool shouldScalarizeBinop(SDValue VecOp) const override {
  const unsigned Opcode = VecOp.getOpcode();
  return Opcode == ISD::SETCC;
}
};

namespace AArch64 {
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2093,7 +2093,7 @@ bool RISCVTargetLowering::shouldScalarizeBinop(SDValue VecOp) const {

// Assume target opcodes can't be scalarized.
// TODO - do we have any exceptions?
if (Opc >= ISD::BUILTIN_OP_END)
if (Opc >= ISD::BUILTIN_OP_END || !isBinOp(Opc))
return false;

// If the vector op is not supported, try to convert to scalar.
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ bool WebAssemblyTargetLowering::shouldScalarizeBinop(SDValue VecOp) const {

// Assume target opcodes can't be scalarized.
// TODO - do we have any exceptions?
if (Opc >= ISD::BUILTIN_OP_END)
if (Opc >= ISD::BUILTIN_OP_END || !isBinOp(Opc))
return false;

// If the vector op is not supported, try to convert to scalar.
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3306,7 +3306,7 @@ bool X86TargetLowering::shouldScalarizeBinop(SDValue VecOp) const {

// Assume target opcodes can't be scalarized.
// TODO - do we have any exceptions?
if (Opc >= ISD::BUILTIN_OP_END)
if (Opc >= ISD::BUILTIN_OP_END || !isBinOp(Opc))
return false;

// If the vector op is not supported, try to convert to scalar.
Expand Down
46 changes: 24 additions & 22 deletions llvm/test/CodeGen/AArch64/dag-combine-concat-vectors.ll
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

declare void @llvm.masked.scatter.nxv16i8.nxv16p0(<vscale x 16 x i8>, <vscale x 16 x ptr>, i32 immarg, <vscale x 16 x i1>)

define fastcc i8 @allocno_reload_assign() {
define fastcc i8 @allocno_reload_assign(ptr %p) {
; CHECK-LABEL: allocno_reload_assign:
; CHECK: // %bb.0:
; CHECK-NEXT: fmov d0, xzr
Expand All @@ -14,8 +14,8 @@ define fastcc i8 @allocno_reload_assign() {
; CHECK-NEXT: cmpeq p0.d, p0/z, z0.d, #0
; CHECK-NEXT: uzp1 p0.s, p0.s, p0.s
; CHECK-NEXT: uzp1 p0.h, p0.h, p0.h
; CHECK-NEXT: uzp1 p0.b, p0.b, p0.b
; CHECK-NEXT: mov z0.b, p0/z, #1 // =0x1
; CHECK-NEXT: uzp1 p8.b, p0.b, p0.b
; CHECK-NEXT: mov z0.b, p8/z, #1 // =0x1
; CHECK-NEXT: fmov w8, s0
; CHECK-NEXT: mov z0.b, #0 // =0x0
; CHECK-NEXT: uunpklo z1.h, z0.b
Expand All @@ -30,34 +30,35 @@ define fastcc i8 @allocno_reload_assign() {
; CHECK-NEXT: punpklo p1.h, p0.b
; CHECK-NEXT: punpkhi p0.h, p0.b
; CHECK-NEXT: punpklo p2.h, p1.b
; CHECK-NEXT: punpkhi p3.h, p1.b
; CHECK-NEXT: punpkhi p4.h, p1.b
; CHECK-NEXT: uunpklo z0.d, z2.s
; CHECK-NEXT: uunpkhi z1.d, z2.s
; CHECK-NEXT: punpklo p5.h, p0.b
; CHECK-NEXT: punpklo p6.h, p0.b
; CHECK-NEXT: uunpklo z2.d, z3.s
; CHECK-NEXT: uunpkhi z3.d, z3.s
; CHECK-NEXT: punpkhi p7.h, p0.b
; CHECK-NEXT: punpkhi p0.h, p0.b
; CHECK-NEXT: uunpklo z4.d, z5.s
; CHECK-NEXT: uunpkhi z5.d, z5.s
; CHECK-NEXT: uunpklo z6.d, z7.s
; CHECK-NEXT: uunpkhi z7.d, z7.s
; CHECK-NEXT: punpklo p0.h, p2.b
; CHECK-NEXT: punpkhi p1.h, p2.b
; CHECK-NEXT: punpklo p2.h, p3.b
; CHECK-NEXT: punpkhi p3.h, p3.b
; CHECK-NEXT: punpklo p4.h, p5.b
; CHECK-NEXT: punpkhi p5.h, p5.b
; CHECK-NEXT: punpklo p6.h, p7.b
; CHECK-NEXT: punpkhi p7.h, p7.b
; CHECK-NEXT: punpklo p1.h, p2.b
; CHECK-NEXT: punpkhi p2.h, p2.b
; CHECK-NEXT: punpklo p3.h, p4.b
; CHECK-NEXT: punpkhi p4.h, p4.b
; CHECK-NEXT: punpklo p5.h, p6.b
; CHECK-NEXT: punpkhi p6.h, p6.b
; CHECK-NEXT: punpklo p7.h, p0.b
; CHECK-NEXT: punpkhi p0.h, p0.b
; CHECK-NEXT: .LBB0_1: // =>This Inner Loop Header: Depth=1
; CHECK-NEXT: st1b { z0.d }, p0, [z16.d]
; CHECK-NEXT: st1b { z1.d }, p1, [z16.d]
; CHECK-NEXT: st1b { z2.d }, p2, [z16.d]
; CHECK-NEXT: st1b { z3.d }, p3, [z16.d]
; CHECK-NEXT: st1b { z4.d }, p4, [z16.d]
; CHECK-NEXT: st1b { z5.d }, p5, [z16.d]
; CHECK-NEXT: st1b { z6.d }, p6, [z16.d]
; CHECK-NEXT: st1b { z7.d }, p7, [z16.d]
; CHECK-NEXT: st1b { z0.d }, p1, [z16.d]
; CHECK-NEXT: st1b { z1.d }, p2, [z16.d]
; CHECK-NEXT: st1b { z2.d }, p3, [z16.d]
; CHECK-NEXT: st1b { z3.d }, p4, [z16.d]
; CHECK-NEXT: st1b { z4.d }, p5, [z16.d]
; CHECK-NEXT: st1b { z5.d }, p6, [z16.d]
; CHECK-NEXT: st1b { z6.d }, p7, [z16.d]
; CHECK-NEXT: st1b { z7.d }, p0, [z16.d]
; CHECK-NEXT: str p8, [x0]
; CHECK-NEXT: b .LBB0_1
br label %1

Expand All @@ -66,6 +67,7 @@ define fastcc i8 @allocno_reload_assign() {
%constexpr1 = shufflevector <vscale x 16 x i1> %constexpr, <vscale x 16 x i1> poison, <vscale x 16 x i32> zeroinitializer
%constexpr2 = xor <vscale x 16 x i1> %constexpr1, shufflevector (<vscale x 16 x i1> insertelement (<vscale x 16 x i1> poison, i1 true, i64 0), <vscale x 16 x i1> poison, <vscale x 16 x i32> zeroinitializer)
call void @llvm.masked.scatter.nxv16i8.nxv16p0(<vscale x 16 x i8> zeroinitializer, <vscale x 16 x ptr> zeroinitializer, i32 0, <vscale x 16 x i1> %constexpr2)
store <vscale x 16 x i1> %constexpr, ptr %p, align 16
br label %1
}

Expand Down
Loading
Loading