
Commit dc72ec8

[RISCV] Custom legalize vp.merge for mask vectors. (#120479)
The default legalization uses vmslt with a vector of XLen-sized elements to compute a mask. This doesn't work if that vector type isn't legal: for fixed vectors it will scalarize, and for scalable vectors it crashes the compiler. This patch uses an alternate strategy that promotes the i1 vector to an i8 vector and does the merge there. I don't claim this to be the best lowering; I wrote it quickly almost 3 years ago when a crash was reported in our downstream. Fixes #120405.
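To illustrate the kind of scalable-vector input this affects, here is a minimal IR sketch (the function name and element count are illustrative, not the exact reproducer from #120405). The mask type is wide enough that a matching vector of XLen-sized elements is not legal, so this previously crashed the backend and now goes through the custom lowering added in this patch:

declare <vscale x 64 x i1> @llvm.vp.merge.nxv64i1(<vscale x 64 x i1>, <vscale x 64 x i1>, <vscale x 64 x i1>, i32)

define <vscale x 64 x i1> @vpmerge_vv_nxv64i1(<vscale x 64 x i1> %va, <vscale x 64 x i1> %vb, <vscale x 64 x i1> %m, i32 zeroext %evl) {
  ; Lanes below %evl take %va where %m is set and %vb otherwise; lanes at or past %evl take %vb.
  %v = call <vscale x 64 x i1> @llvm.vp.merge.nxv64i1(<vscale x 64 x i1> %m, <vscale x 64 x i1> %va, <vscale x 64 x i1> %vb, i32 %evl)
  ret <vscale x 64 x i1> %v
}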
1 parent b56d1ec commit dc72ec8

4 files changed: +454, -14 lines changed


llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 68 additions & 4 deletions
@@ -758,9 +758,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
                          Custom);
 
       setOperationAction(ISD::SELECT, VT, Custom);
-      setOperationAction(
-          {ISD::SELECT_CC, ISD::VSELECT, ISD::VP_MERGE, ISD::VP_SELECT}, VT,
-          Expand);
+      setOperationAction({ISD::SELECT_CC, ISD::VSELECT, ISD::VP_SELECT}, VT,
+                         Expand);
+      setOperationAction(ISD::VP_MERGE, VT, Custom);
 
       setOperationAction({ISD::VP_CTTZ_ELTS, ISD::VP_CTTZ_ELTS_ZERO_UNDEF}, VT,
                          Custom);
@@ -1237,6 +1237,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
                           ISD::VP_SETCC, ISD::VP_TRUNCATE},
                          VT, Custom);
 
+      setOperationAction(ISD::VP_MERGE, VT, Custom);
+
       setOperationAction(ISD::EXPERIMENTAL_VP_SPLICE, VT, Custom);
       setOperationAction(ISD::EXPERIMENTAL_VP_REVERSE, VT, Custom);
       continue;
@@ -7492,8 +7494,11 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
     return lowerSET_ROUNDING(Op, DAG);
   case ISD::EH_DWARF_CFA:
     return lowerEH_DWARF_CFA(Op, DAG);
-  case ISD::VP_SELECT:
   case ISD::VP_MERGE:
+    if (Op.getSimpleValueType().getVectorElementType() == MVT::i1)
+      return lowerVPMergeMask(Op, DAG);
+    [[fallthrough]];
+  case ISD::VP_SELECT:
   case ISD::VP_ADD:
   case ISD::VP_SUB:
   case ISD::VP_MUL:
@@ -12078,6 +12083,65 @@ SDValue RISCVTargetLowering::lowerVPFPIntConvOp(SDValue Op,
   return convertFromScalableVector(VT, Result, DAG, Subtarget);
 }
 
+SDValue RISCVTargetLowering::lowerVPMergeMask(SDValue Op,
+                                              SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  MVT VT = Op.getSimpleValueType();
+  MVT XLenVT = Subtarget.getXLenVT();
+
+  SDValue Mask = Op.getOperand(0);
+  SDValue TrueVal = Op.getOperand(1);
+  SDValue FalseVal = Op.getOperand(2);
+  SDValue VL = Op.getOperand(3);
+
+  // Use default legalization if a vector of EVL type would be legal.
+  EVT EVLVecVT = EVT::getVectorVT(*DAG.getContext(), VL.getValueType(),
+                                  VT.getVectorElementCount());
+  if (isTypeLegal(EVLVecVT))
+    return SDValue();
+
+  MVT ContainerVT = VT;
+  if (VT.isFixedLengthVector()) {
+    ContainerVT = getContainerForFixedLengthVector(VT);
+    Mask = convertToScalableVector(ContainerVT, Mask, DAG, Subtarget);
+    TrueVal = convertToScalableVector(ContainerVT, TrueVal, DAG, Subtarget);
+    FalseVal = convertToScalableVector(ContainerVT, FalseVal, DAG, Subtarget);
+  }
+
+  // Promote to a vector of i8.
+  MVT PromotedVT = ContainerVT.changeVectorElementType(MVT::i8);
+
+  // Promote TrueVal and FalseVal using VLMax.
+  // FIXME: Is there a better way to do this?
+  SDValue VLMax = DAG.getRegister(RISCV::X0, XLenVT);
+  SDValue SplatOne = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, PromotedVT,
+                                 DAG.getUNDEF(PromotedVT),
+                                 DAG.getConstant(1, DL, XLenVT), VLMax);
+  SDValue SplatZero = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, PromotedVT,
+                                  DAG.getUNDEF(PromotedVT),
+                                  DAG.getConstant(0, DL, XLenVT), VLMax);
+  TrueVal = DAG.getNode(RISCVISD::VMERGE_VL, DL, PromotedVT, TrueVal, SplatOne,
+                        SplatZero, DAG.getUNDEF(PromotedVT), VL);
+  // Any element past VL uses FalseVal, so use VLMax
+  FalseVal = DAG.getNode(RISCVISD::VMERGE_VL, DL, PromotedVT, FalseVal,
+                         SplatOne, SplatZero, DAG.getUNDEF(PromotedVT), VLMax);
+
+  // VP_MERGE the two promoted values.
+  SDValue VPMerge = DAG.getNode(RISCVISD::VMERGE_VL, DL, PromotedVT, Mask,
+                                TrueVal, FalseVal, FalseVal, VL);
+
+  // Convert back to mask.
+  SDValue TrueMask = DAG.getNode(RISCVISD::VMSET_VL, DL, ContainerVT, VL);
+  SDValue Result = DAG.getNode(
+      RISCVISD::SETCC_VL, DL, ContainerVT,
+      {VPMerge, DAG.getConstant(0, DL, PromotedVT), DAG.getCondCode(ISD::SETNE),
+       DAG.getUNDEF(getMaskTypeFor(ContainerVT)), TrueMask, VLMax});
+
+  if (VT.isFixedLengthVector())
+    Result = convertFromScalableVector(VT, Result, DAG, Subtarget);
+  return Result;
+}
+
 SDValue
 RISCVTargetLowering::lowerVPSpliceExperimental(SDValue Op,
                                                SelectionDAG &DAG) const {
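Conceptually, the new lowering roughly amounts to the following generic-IR rewrite. This is only an illustrative sketch (types chosen arbitrarily): the code above builds the target-specific VMV_V_X_VL/VMERGE_VL/VMSET_VL/SETCC_VL nodes directly, and it promotes FalseVal and performs the final compare under VLMAX so that lanes past VL keep the false operand.

  ; Promote the i1 operands to i8 (0/1), merge under the vp mask, then compare back to i1.
  %t = zext <vscale x 64 x i1> %va to <vscale x 64 x i8>
  %f = zext <vscale x 64 x i1> %vb to <vscale x 64 x i8>
  %merged = call <vscale x 64 x i8> @llvm.vp.merge.nxv64i8(<vscale x 64 x i1> %m, <vscale x 64 x i8> %t, <vscale x 64 x i8> %f, i32 %evl)
  %res = icmp ne <vscale x 64 x i8> %merged, zeroinitializer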

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 1 addition & 0 deletions
@@ -996,6 +996,7 @@ class RISCVTargetLowering : public TargetLowering {
   SDValue lowerLogicVPOp(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerVPExtMaskOp(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerVPSetCCMaskOp(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerVPMergeMask(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerVPSplatExperimental(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerVPSpliceExperimental(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerVPReverseExperimental(SDValue Op, SelectionDAG &DAG) const;

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vpmerge.ll

Lines changed: 180 additions & 4 deletions
@@ -58,6 +58,182 @@ define <4 x i1> @vpmerge_vv_v4i1(<4 x i1> %va, <4 x i1> %vb, <4 x i1> %m, i32 ze
   ret <4 x i1> %v
 }
 
+define <8 x i1> @vpmerge_vv_v8i1(<8 x i1> %va, <8 x i1> %vb, <8 x i1> %m, i32 zeroext %evl) {
+; RV32-LABEL: vpmerge_vv_v8i1:
+; RV32: # %bb.0:
+; RV32-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; RV32-NEXT: vid.v v10
+; RV32-NEXT: vmsltu.vx v12, v10, a0
+; RV32-NEXT: vmand.mm v9, v9, v12
+; RV32-NEXT: vmandn.mm v8, v8, v9
+; RV32-NEXT: vmand.mm v9, v0, v9
+; RV32-NEXT: vmor.mm v0, v9, v8
+; RV32-NEXT: ret
+;
+; RV64-LABEL: vpmerge_vv_v8i1:
+; RV64: # %bb.0:
+; RV64-NEXT: vsetivli zero, 8, e64, m4, ta, ma
+; RV64-NEXT: vid.v v12
+; RV64-NEXT: vmsltu.vx v10, v12, a0
+; RV64-NEXT: vmand.mm v9, v9, v10
+; RV64-NEXT: vmandn.mm v8, v8, v9
+; RV64-NEXT: vmand.mm v9, v0, v9
+; RV64-NEXT: vmor.mm v0, v9, v8
+; RV64-NEXT: ret
+;
+; RV32ZVFHMIN-LABEL: vpmerge_vv_v8i1:
+; RV32ZVFHMIN: # %bb.0:
+; RV32ZVFHMIN-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; RV32ZVFHMIN-NEXT: vid.v v10
+; RV32ZVFHMIN-NEXT: vmsltu.vx v12, v10, a0
+; RV32ZVFHMIN-NEXT: vmand.mm v9, v9, v12
+; RV32ZVFHMIN-NEXT: vmandn.mm v8, v8, v9
+; RV32ZVFHMIN-NEXT: vmand.mm v9, v0, v9
+; RV32ZVFHMIN-NEXT: vmor.mm v0, v9, v8
+; RV32ZVFHMIN-NEXT: ret
+;
+; RV64ZVFHMIN-LABEL: vpmerge_vv_v8i1:
+; RV64ZVFHMIN: # %bb.0:
+; RV64ZVFHMIN-NEXT: vsetivli zero, 8, e64, m4, ta, ma
+; RV64ZVFHMIN-NEXT: vid.v v12
+; RV64ZVFHMIN-NEXT: vmsltu.vx v10, v12, a0
+; RV64ZVFHMIN-NEXT: vmand.mm v9, v9, v10
+; RV64ZVFHMIN-NEXT: vmandn.mm v8, v8, v9
+; RV64ZVFHMIN-NEXT: vmand.mm v9, v0, v9
+; RV64ZVFHMIN-NEXT: vmor.mm v0, v9, v8
+; RV64ZVFHMIN-NEXT: ret
+  %v = call <8 x i1> @llvm.vp.merge.v8i1(<8 x i1> %m, <8 x i1> %va, <8 x i1> %vb, i32 %evl)
+  ret <8 x i1> %v
+}
+
+define <16 x i1> @vpmerge_vv_v16i1(<16 x i1> %va, <16 x i1> %vb, <16 x i1> %m, i32 zeroext %evl) {
+; RV32-LABEL: vpmerge_vv_v16i1:
+; RV32: # %bb.0:
+; RV32-NEXT: vsetivli zero, 16, e32, m4, ta, ma
+; RV32-NEXT: vid.v v12
+; RV32-NEXT: vmsltu.vx v10, v12, a0
+; RV32-NEXT: vmand.mm v9, v9, v10
+; RV32-NEXT: vmandn.mm v8, v8, v9
+; RV32-NEXT: vmand.mm v9, v0, v9
+; RV32-NEXT: vmor.mm v0, v9, v8
+; RV32-NEXT: ret
+;
+; RV64-LABEL: vpmerge_vv_v16i1:
+; RV64: # %bb.0:
+; RV64-NEXT: vsetivli zero, 16, e64, m8, ta, ma
+; RV64-NEXT: vid.v v16
+; RV64-NEXT: vmsltu.vx v10, v16, a0
+; RV64-NEXT: vmand.mm v9, v9, v10
+; RV64-NEXT: vmandn.mm v8, v8, v9
+; RV64-NEXT: vmand.mm v9, v0, v9
+; RV64-NEXT: vmor.mm v0, v9, v8
+; RV64-NEXT: ret
+;
+; RV32ZVFHMIN-LABEL: vpmerge_vv_v16i1:
+; RV32ZVFHMIN: # %bb.0:
+; RV32ZVFHMIN-NEXT: vsetivli zero, 16, e32, m4, ta, ma
+; RV32ZVFHMIN-NEXT: vid.v v12
+; RV32ZVFHMIN-NEXT: vmsltu.vx v10, v12, a0
+; RV32ZVFHMIN-NEXT: vmand.mm v9, v9, v10
+; RV32ZVFHMIN-NEXT: vmandn.mm v8, v8, v9
+; RV32ZVFHMIN-NEXT: vmand.mm v9, v0, v9
+; RV32ZVFHMIN-NEXT: vmor.mm v0, v9, v8
+; RV32ZVFHMIN-NEXT: ret
+;
+; RV64ZVFHMIN-LABEL: vpmerge_vv_v16i1:
+; RV64ZVFHMIN: # %bb.0:
+; RV64ZVFHMIN-NEXT: vsetivli zero, 16, e64, m8, ta, ma
+; RV64ZVFHMIN-NEXT: vid.v v16
+; RV64ZVFHMIN-NEXT: vmsltu.vx v10, v16, a0
+; RV64ZVFHMIN-NEXT: vmand.mm v9, v9, v10
+; RV64ZVFHMIN-NEXT: vmandn.mm v8, v8, v9
+; RV64ZVFHMIN-NEXT: vmand.mm v9, v0, v9
+; RV64ZVFHMIN-NEXT: vmor.mm v0, v9, v8
+; RV64ZVFHMIN-NEXT: ret
+  %v = call <16 x i1> @llvm.vp.merge.v16i1(<16 x i1> %m, <16 x i1> %va, <16 x i1> %vb, i32 %evl)
+  ret <16 x i1> %v
+}
+
+define <32 x i1> @vpmerge_vv_v32i1(<32 x i1> %va, <32 x i1> %vb, <32 x i1> %m, i32 zeroext %evl) {
+; RV32-LABEL: vpmerge_vv_v32i1:
+; RV32: # %bb.0:
+; RV32-NEXT: li a1, 32
+; RV32-NEXT: vsetvli zero, a1, e32, m8, ta, ma
+; RV32-NEXT: vid.v v16
+; RV32-NEXT: vmsltu.vx v10, v16, a0
+; RV32-NEXT: vmand.mm v9, v9, v10
+; RV32-NEXT: vmandn.mm v8, v8, v9
+; RV32-NEXT: vmand.mm v9, v0, v9
+; RV32-NEXT: vmor.mm v0, v9, v8
+; RV32-NEXT: ret
+;
+; RV64-LABEL: vpmerge_vv_v32i1:
+; RV64: # %bb.0:
+; RV64-NEXT: vsetvli a1, zero, e8, m2, ta, ma
+; RV64-NEXT: vmv.v.i v10, 0
+; RV64-NEXT: vsetvli zero, a0, e8, m2, ta, ma
+; RV64-NEXT: vmerge.vim v12, v10, 1, v0
+; RV64-NEXT: vmv1r.v v0, v8
+; RV64-NEXT: vsetvli a1, zero, e8, m2, ta, ma
+; RV64-NEXT: vmerge.vim v10, v10, 1, v0
+; RV64-NEXT: vmv1r.v v0, v9
+; RV64-NEXT: vsetvli zero, a0, e8, m2, tu, ma
+; RV64-NEXT: vmerge.vvm v10, v10, v12, v0
+; RV64-NEXT: vsetvli a0, zero, e8, m2, ta, ma
+; RV64-NEXT: vmsne.vi v0, v10, 0
+; RV64-NEXT: ret
+;
+; RV32ZVFHMIN-LABEL: vpmerge_vv_v32i1:
+; RV32ZVFHMIN: # %bb.0:
+; RV32ZVFHMIN-NEXT: li a1, 32
+; RV32ZVFHMIN-NEXT: vsetvli zero, a1, e32, m8, ta, ma
+; RV32ZVFHMIN-NEXT: vid.v v16
+; RV32ZVFHMIN-NEXT: vmsltu.vx v10, v16, a0
+; RV32ZVFHMIN-NEXT: vmand.mm v9, v9, v10
+; RV32ZVFHMIN-NEXT: vmandn.mm v8, v8, v9
+; RV32ZVFHMIN-NEXT: vmand.mm v9, v0, v9
+; RV32ZVFHMIN-NEXT: vmor.mm v0, v9, v8
+; RV32ZVFHMIN-NEXT: ret
+;
+; RV64ZVFHMIN-LABEL: vpmerge_vv_v32i1:
+; RV64ZVFHMIN: # %bb.0:
+; RV64ZVFHMIN-NEXT: vsetvli a1, zero, e8, m2, ta, ma
+; RV64ZVFHMIN-NEXT: vmv.v.i v10, 0
+; RV64ZVFHMIN-NEXT: vsetvli zero, a0, e8, m2, ta, ma
+; RV64ZVFHMIN-NEXT: vmerge.vim v12, v10, 1, v0
+; RV64ZVFHMIN-NEXT: vmv1r.v v0, v8
+; RV64ZVFHMIN-NEXT: vsetvli a1, zero, e8, m2, ta, ma
+; RV64ZVFHMIN-NEXT: vmerge.vim v10, v10, 1, v0
+; RV64ZVFHMIN-NEXT: vmv1r.v v0, v9
+; RV64ZVFHMIN-NEXT: vsetvli zero, a0, e8, m2, tu, ma
+; RV64ZVFHMIN-NEXT: vmerge.vvm v10, v10, v12, v0
+; RV64ZVFHMIN-NEXT: vsetvli a0, zero, e8, m2, ta, ma
+; RV64ZVFHMIN-NEXT: vmsne.vi v0, v10, 0
+; RV64ZVFHMIN-NEXT: ret
+  %v = call <32 x i1> @llvm.vp.merge.v32i1(<32 x i1> %m, <32 x i1> %va, <32 x i1> %vb, i32 %evl)
+  ret <32 x i1> %v
+}
+
+define <64 x i1> @vpmerge_vv_v64i1(<64 x i1> %va, <64 x i1> %vb, <64 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: vpmerge_vv_v64i1:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetvli a1, zero, e8, m4, ta, ma
+; CHECK-NEXT: vmv.v.i v12, 0
+; CHECK-NEXT: vsetvli zero, a0, e8, m4, ta, ma
+; CHECK-NEXT: vmerge.vim v16, v12, 1, v0
+; CHECK-NEXT: vmv1r.v v0, v8
+; CHECK-NEXT: vsetvli a1, zero, e8, m4, ta, ma
+; CHECK-NEXT: vmerge.vim v12, v12, 1, v0
+; CHECK-NEXT: vmv1r.v v0, v9
+; CHECK-NEXT: vsetvli zero, a0, e8, m4, tu, ma
+; CHECK-NEXT: vmerge.vvm v12, v12, v16, v0
+; CHECK-NEXT: vsetvli a0, zero, e8, m4, ta, ma
+; CHECK-NEXT: vmsne.vi v0, v12, 0
+; CHECK-NEXT: ret
+  %v = call <64 x i1> @llvm.vp.merge.v64i1(<64 x i1> %m, <64 x i1> %va, <64 x i1> %vb, i32 %evl)
+  ret <64 x i1> %v
+}
+
 declare <2 x i8> @llvm.vp.merge.v2i8(<2 x i1>, <2 x i8>, <2 x i8>, i32)
 
 define <2 x i8> @vpmerge_vv_v2i8(<2 x i8> %va, <2 x i8> %vb, <2 x i1> %m, i32 zeroext %evl) {
@@ -1188,10 +1364,10 @@ define <32 x double> @vpmerge_vv_v32f64(<32 x double> %va, <32 x double> %vb, <3
 ; CHECK-NEXT: vle64.v v8, (a0)
 ; CHECK-NEXT: li a1, 16
 ; CHECK-NEXT: mv a0, a2
-; CHECK-NEXT: bltu a2, a1, .LBB79_2
+; CHECK-NEXT: bltu a2, a1, .LBB83_2
 ; CHECK-NEXT: # %bb.1:
 ; CHECK-NEXT: li a0, 16
-; CHECK-NEXT: .LBB79_2:
+; CHECK-NEXT: .LBB83_2:
 ; CHECK-NEXT: vsetvli zero, a0, e64, m8, tu, ma
 ; CHECK-NEXT: vmerge.vvm v8, v8, v16, v0
 ; CHECK-NEXT: addi a0, a2, -16
@@ -1221,10 +1397,10 @@ define <32 x double> @vpmerge_vf_v32f64(double %a, <32 x double> %vb, <32 x i1>
 ; CHECK: # %bb.0:
 ; CHECK-NEXT: li a2, 16
 ; CHECK-NEXT: mv a1, a0
-; CHECK-NEXT: bltu a0, a2, .LBB80_2
+; CHECK-NEXT: bltu a0, a2, .LBB84_2
 ; CHECK-NEXT: # %bb.1:
 ; CHECK-NEXT: li a1, 16
-; CHECK-NEXT: .LBB80_2:
+; CHECK-NEXT: .LBB84_2:
 ; CHECK-NEXT: vsetvli zero, a1, e64, m8, tu, ma
 ; CHECK-NEXT: vfmerge.vfm v8, v8, fa0, v0
 ; CHECK-NEXT: addi a1, a0, -16
