Skip to content

Commit 5ec0cca

Browse files
committed
Address review comments
1 parent d8f9f07 commit 5ec0cca

File tree

7 files changed

+88
-63
lines changed

7 files changed

+88
-63
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3342,10 +3342,6 @@ class TargetLoweringBase {
33423342
return false;
33433343
}
33443344

3345-
/// Try to convert an extract element of a vector setcc operation into an
3346-
/// extract element followed by a scalar operation.
3347-
virtual bool shouldScalarizeSetCC(SDValue VecOp) const { return false; }
3348-
33493345
/// Return true if extraction of a scalar element from the given vector type
33503346
/// at the given index is cheap. For example, if scalar operations occur on
33513347
/// the same register file as vector operations, then an extract element may

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 30 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -22746,68 +22746,49 @@ SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
2274622746

2274722747
/// Transform a vector binary operation into a scalar binary operation by moving
2274822748
/// the math/logic after an extract element of a vector.
22749-
static bool scalarizeExtractedBinOpCommon(SDNode *ExtElt, SelectionDAG &DAG,
22750-
const SDLoc &DL, bool IsSetCC,
22751-
SDValue &ScalarOp1,
22752-
SDValue &ScalarOp2) {
22749+
static SDValue scalarizeExtractedBinOp(SDNode *ExtElt, SelectionDAG &DAG,
22750+
const SDLoc &DL) {
22751+
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
2275322752
SDValue Vec = ExtElt->getOperand(0);
2275422753
SDValue Index = ExtElt->getOperand(1);
2275522754
auto *IndexC = dyn_cast<ConstantSDNode>(Index);
22756-
if (!IndexC || !Vec.hasOneUse() || Vec->getNumValues() != 1)
22757-
return false;
22755+
if (!IndexC ||
22756+
(!TLI.isBinOp(Vec.getOpcode()) && Vec.getOpcode() != ISD::SETCC) ||
22757+
!Vec.hasOneUse() || Vec->getNumValues() != 1)
22758+
return SDValue();
22759+
22760+
EVT ResVT = ExtElt->getValueType(0);
22761+
if (Vec.getOpcode() == ISD::SETCC &&
22762+
ResVT != Vec.getValueType().getVectorElementType())
22763+
return SDValue();
22764+
22765+
// Targets may want to avoid this to prevent an expensive register transfer.
22766+
if (!TLI.shouldScalarizeBinop(Vec))
22767+
return SDValue();
2275822768

2275922769
// Extracting an element of a vector constant is constant-folded, so this
2276022770
// transform is just replacing a vector op with a scalar op while moving the
2276122771
// extract.
2276222772
SDValue Op0 = Vec.getOperand(0);
2276322773
SDValue Op1 = Vec.getOperand(1);
2276422774
APInt SplatVal;
22765-
if (isAnyConstantBuildVector(Op0, true) ||
22766-
ISD::isConstantSplatVector(Op0.getNode(), SplatVal) ||
22767-
isAnyConstantBuildVector(Op1, true) ||
22768-
ISD::isConstantSplatVector(Op1.getNode(), SplatVal)) {
22769-
// extractelt (binop X, C), IndexC --> binop (extractelt X, IndexC), C'
22770-
// extractelt (binop C, X), IndexC --> binop C', (extractelt X, IndexC)
22771-
// extractelt (setcc X, C, op), IndexC -> setcc (extractelt X, IndexC)), C
22772-
// extractelt (setcc C, X, op), IndexC -> setcc (extractelt IndexC, X)), C
22773-
EVT VT = Op0->getValueType(0).getVectorElementType();
22774-
ScalarOp1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op0, Index);
22775-
ScalarOp2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op1, Index);
22776-
return true;
22777-
}
22778-
22779-
return false;
22780-
}
22781-
22782-
static SDValue scalarizeExtractedBinOp(SDNode *ExtElt, SelectionDAG &DAG,
22783-
const SDLoc &DL) {
22784-
SDValue Op1, Op2;
22785-
SDValue Vec = ExtElt->getOperand(0);
22786-
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
22787-
if (!TLI.isBinOp(Vec.getOpcode()) || !TLI.shouldScalarizeBinop(Vec))
22788-
return SDValue();
22789-
22790-
if (!scalarizeExtractedBinOpCommon(ExtElt, DAG, DL, false, Op1, Op2))
22775+
if (!isAnyConstantBuildVector(Op0, true) &&
22776+
!ISD::isConstantSplatVector(Op0.getNode(), SplatVal) &&
22777+
!isAnyConstantBuildVector(Op1, true) &&
22778+
!ISD::isConstantSplatVector(Op1.getNode(), SplatVal))
2279122779
return SDValue();
2279222780

22793-
EVT VT = ExtElt->getValueType(0);
22794-
return DAG.getNode(Vec.getOpcode(), DL, VT, Op1, Op2);
22795-
}
22781+
// extractelt (op X, C), IndexC --> op (extractelt X, IndexC), C'
22782+
// extractelt (op C, X), IndexC --> op C', (extractelt X, IndexC)
22783+
EVT OpVT = Op0->getValueType(0).getVectorElementType();
22784+
Op0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Op0, Index);
22785+
Op1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Op1, Index);
2279622786

22797-
static SDValue scalarizeExtractedSetCC(SDNode *ExtElt, SelectionDAG &DAG,
22798-
const SDLoc &DL) {
22799-
SDValue Op1, Op2;
22800-
SDValue Vec = ExtElt->getOperand(0);
22801-
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
22802-
if (Vec.getOpcode() != ISD::SETCC || !TLI.shouldScalarizeSetCC(Vec))
22803-
return SDValue();
22804-
22805-
if (!scalarizeExtractedBinOpCommon(ExtElt, DAG, DL, true, Op1, Op2))
22806-
return SDValue();
22807-
22808-
EVT VT = ExtElt->getValueType(0);
22809-
return DAG.getSetCC(DL, VT, Op1, Op2,
22810-
cast<CondCodeSDNode>(Vec->getOperand(2))->get());
22787+
if (Vec.getOpcode() == ISD::SETCC)
22788+
return DAG.getSetCC(DL, ResVT, Op0, Op1,
22789+
cast<CondCodeSDNode>(Vec->getOperand(2))->get());
22790+
else
22791+
return DAG.getNode(Vec.getOpcode(), DL, ResVT, Op0, Op1);
2281122792
}
2281222793

2281322794
// Given a ISD::EXTRACT_VECTOR_ELT, which is a glorified bit sequence extract,
@@ -23043,11 +23024,6 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
2304323024
if (SDValue BO = scalarizeExtractedBinOp(N, DAG, DL))
2304423025
return BO;
2304523026

23046-
// extract (setcc x, splat(y)), i -> setcc (extract x, i)), y
23047-
if (ScalarVT == VecVT.getVectorElementType())
23048-
if (SDValue SetCC = scalarizeExtractedSetCC(N, DAG, DL))
23049-
return SetCC;
23050-
2305123027
if (VecVT.isScalableVector())
2305223028
return SDValue();
2305323029

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1349,7 +1349,9 @@ class AArch64TargetLowering : public TargetLowering {
13491349

13501350
bool softPromoteHalfType() const override { return true; }
13511351

1352-
bool shouldScalarizeSetCC(SDValue VecOp) const override { return true; }
1352+
bool shouldScalarizeBinop(SDValue VecOp) const override {
1353+
return VecOp.getOpcode() == ISD::SETCC;
1354+
}
13531355
};
13541356

13551357
namespace AArch64 {

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2093,7 +2093,7 @@ bool RISCVTargetLowering::shouldScalarizeBinop(SDValue VecOp) const {
20932093

20942094
// Assume target opcodes can't be scalarized.
20952095
// TODO - do we have any exceptions?
2096-
if (Opc >= ISD::BUILTIN_OP_END)
2096+
if (Opc >= ISD::BUILTIN_OP_END || !isBinOp(Opc))
20972097
return false;
20982098

20992099
// If the vector op is not supported, try to convert to scalar.

llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ bool WebAssemblyTargetLowering::shouldScalarizeBinop(SDValue VecOp) const {
429429

430430
// Assume target opcodes can't be scalarized.
431431
// TODO - do we have any exceptions?
432-
if (Opc >= ISD::BUILTIN_OP_END)
432+
if (Opc >= ISD::BUILTIN_OP_END || !isBinOp(Opc))
433433
return false;
434434

435435
// If the vector op is not supported, try to convert to scalar.

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3300,7 +3300,7 @@ bool X86TargetLowering::shouldScalarizeBinop(SDValue VecOp) const {
33003300

33013301
// Assume target opcodes can't be scalarized.
33023302
// TODO - do we have any exceptions?
3303-
if (Opc >= ISD::BUILTIN_OP_END)
3303+
if (Opc >= ISD::BUILTIN_OP_END || !isBinOp(Opc))
33043304
return false;
33053305

33063306
// If the vector op is not supported, try to convert to scalar.

llvm/test/CodeGen/AArch64/dag-combine-concat-vectors.ll

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,60 @@
55

66
declare void @llvm.masked.scatter.nxv16i8.nxv16p0(<vscale x 16 x i8>, <vscale x 16 x ptr>, i32 immarg, <vscale x 16 x i1>)
77

8-
define fastcc i8 @allocno_reload_assign() {
8+
define fastcc i8 @allocno_reload_assign(ptr %p) {
99
; CHECK-LABEL: allocno_reload_assign:
1010
; CHECK: // %bb.0:
11+
; CHECK-NEXT: fmov d0, xzr
12+
; CHECK-NEXT: ptrue p0.d
13+
; CHECK-NEXT: mov z16.d, #0 // =0x0
14+
; CHECK-NEXT: cmpeq p0.d, p0/z, z0.d, #0
15+
; CHECK-NEXT: uzp1 p0.s, p0.s, p0.s
16+
; CHECK-NEXT: uzp1 p0.h, p0.h, p0.h
17+
; CHECK-NEXT: uzp1 p8.b, p0.b, p0.b
18+
; CHECK-NEXT: mov z0.b, p8/z, #1 // =0x1
19+
; CHECK-NEXT: fmov w8, s0
20+
; CHECK-NEXT: mov z0.b, #0 // =0x0
21+
; CHECK-NEXT: uunpklo z1.h, z0.b
22+
; CHECK-NEXT: uunpkhi z0.h, z0.b
23+
; CHECK-NEXT: mvn w8, w8
24+
; CHECK-NEXT: sbfx x8, x8, #0, #1
25+
; CHECK-NEXT: whilelo p0.b, xzr, x8
26+
; CHECK-NEXT: uunpklo z2.s, z1.h
27+
; CHECK-NEXT: uunpkhi z3.s, z1.h
28+
; CHECK-NEXT: uunpklo z5.s, z0.h
29+
; CHECK-NEXT: uunpkhi z7.s, z0.h
30+
; CHECK-NEXT: punpklo p1.h, p0.b
31+
; CHECK-NEXT: punpkhi p0.h, p0.b
32+
; CHECK-NEXT: punpklo p2.h, p1.b
33+
; CHECK-NEXT: punpkhi p4.h, p1.b
34+
; CHECK-NEXT: uunpklo z0.d, z2.s
35+
; CHECK-NEXT: uunpkhi z1.d, z2.s
36+
; CHECK-NEXT: punpklo p6.h, p0.b
37+
; CHECK-NEXT: uunpklo z2.d, z3.s
38+
; CHECK-NEXT: uunpkhi z3.d, z3.s
39+
; CHECK-NEXT: punpkhi p0.h, p0.b
40+
; CHECK-NEXT: uunpklo z4.d, z5.s
41+
; CHECK-NEXT: uunpkhi z5.d, z5.s
42+
; CHECK-NEXT: uunpklo z6.d, z7.s
43+
; CHECK-NEXT: uunpkhi z7.d, z7.s
44+
; CHECK-NEXT: punpklo p1.h, p2.b
45+
; CHECK-NEXT: punpkhi p2.h, p2.b
46+
; CHECK-NEXT: punpklo p3.h, p4.b
47+
; CHECK-NEXT: punpkhi p4.h, p4.b
48+
; CHECK-NEXT: punpklo p5.h, p6.b
49+
; CHECK-NEXT: punpkhi p6.h, p6.b
50+
; CHECK-NEXT: punpklo p7.h, p0.b
51+
; CHECK-NEXT: punpkhi p0.h, p0.b
1152
; CHECK-NEXT: .LBB0_1: // =>This Inner Loop Header: Depth=1
53+
; CHECK-NEXT: st1b { z0.d }, p1, [z16.d]
54+
; CHECK-NEXT: st1b { z1.d }, p2, [z16.d]
55+
; CHECK-NEXT: st1b { z2.d }, p3, [z16.d]
56+
; CHECK-NEXT: st1b { z3.d }, p4, [z16.d]
57+
; CHECK-NEXT: st1b { z4.d }, p5, [z16.d]
58+
; CHECK-NEXT: st1b { z5.d }, p6, [z16.d]
59+
; CHECK-NEXT: st1b { z6.d }, p7, [z16.d]
60+
; CHECK-NEXT: st1b { z7.d }, p0, [z16.d]
61+
; CHECK-NEXT: str p8, [x0]
1262
; CHECK-NEXT: b .LBB0_1
1363
br label %1
1464

@@ -17,6 +67,7 @@ define fastcc i8 @allocno_reload_assign() {
1767
%constexpr1 = shufflevector <vscale x 16 x i1> %constexpr, <vscale x 16 x i1> poison, <vscale x 16 x i32> zeroinitializer
1868
%constexpr2 = xor <vscale x 16 x i1> %constexpr1, shufflevector (<vscale x 16 x i1> insertelement (<vscale x 16 x i1> poison, i1 true, i64 0), <vscale x 16 x i1> poison, <vscale x 16 x i32> zeroinitializer)
1969
call void @llvm.masked.scatter.nxv16i8.nxv16p0(<vscale x 16 x i8> zeroinitializer, <vscale x 16 x ptr> zeroinitializer, i32 0, <vscale x 16 x i1> %constexpr2)
70+
store <vscale x 16 x i1> %constexpr, ptr %p, align 16
2071
br label %1
2172
}
2273

0 commit comments

Comments
 (0)