Skip to content

Commit 25445b9

Browse files
jacquesguan
authored and committed
[RISCV] Add rvv codegen support for vp.fptrunc.
This patch adds RVV codegen support for vp.fptrunc. The lowering of fp_round and vp.fptrunc share most code, so a common lowering function handles both, similar to vp.trunc.

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D123841
1 parent 1881d6f commit 25445b9

File tree

5 files changed

+230
-43
lines changed

5 files changed

+230
-43
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 68 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
502502
ISD::VP_REDUCE_FMIN, ISD::VP_REDUCE_FMAX,
503503
ISD::VP_MERGE, ISD::VP_SELECT,
504504
ISD::VP_SITOFP, ISD::VP_UITOFP,
505-
ISD::VP_SETCC};
505+
ISD::VP_SETCC, ISD::VP_FP_ROUND};
506506

507507
if (!Subtarget.is64Bit()) {
508508
// We must custom-lower certain vXi64 operations on RV32 due to the vector
@@ -3280,48 +3280,10 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
32803280
return convertFromScalableVector(VT, Extend, DAG, Subtarget);
32813281
return Extend;
32823282
}
3283-
case ISD::FP_ROUND: {
3284-
// RVV can only do fp_round to types half the size as the source. We
3285-
// custom-lower f64->f16 rounds via RVV's round-to-odd float
3286-
// conversion instruction.
3287-
SDLoc DL(Op);
3288-
MVT VT = Op.getSimpleValueType();
3289-
SDValue Src = Op.getOperand(0);
3290-
MVT SrcVT = Src.getSimpleValueType();
3291-
3292-
// Prepare any fixed-length vector operands.
3293-
MVT ContainerVT = VT;
3294-
if (VT.isFixedLengthVector()) {
3295-
MVT SrcContainerVT = getContainerForFixedLengthVector(SrcVT);
3296-
ContainerVT =
3297-
SrcContainerVT.changeVectorElementType(VT.getVectorElementType());
3298-
Src = convertToScalableVector(SrcContainerVT, Src, DAG, Subtarget);
3299-
}
3300-
3301-
if (!VT.isVector() || VT.getVectorElementType() != MVT::f16 ||
3302-
SrcVT.getVectorElementType() != MVT::f64) {
3303-
// For scalable vectors, we only need to close the gap between
3304-
// vXf64<->vXf16.
3305-
if (!VT.isFixedLengthVector())
3306-
return Op;
3307-
// For fixed-length vectors, lower the FP_ROUND to a custom "VL" version.
3308-
Src = getRVVFPExtendOrRound(Src, VT, ContainerVT, DL, DAG, Subtarget);
3309-
return convertFromScalableVector(VT, Src, DAG, Subtarget);
3310-
}
3311-
3312-
SDValue Mask, VL;
3313-
std::tie(Mask, VL) = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
3314-
3315-
MVT InterVT = ContainerVT.changeVectorElementType(MVT::f32);
3316-
SDValue IntermediateRound =
3317-
DAG.getNode(RISCVISD::VFNCVT_ROD_VL, DL, InterVT, Src, Mask, VL);
3318-
SDValue Round = getRVVFPExtendOrRound(IntermediateRound, VT, ContainerVT,
3319-
DL, DAG, Subtarget);
3320-
3321-
if (VT.isFixedLengthVector())
3322-
return convertFromScalableVector(VT, Round, DAG, Subtarget);
3323-
return Round;
3324-
}
3283+
case ISD::FP_ROUND:
3284+
if (!Op.getValueType().isVector())
3285+
return Op;
3286+
return lowerVectorFPRoundLike(Op, DAG);
33253287
case ISD::FP_TO_SINT:
33263288
case ISD::FP_TO_UINT:
33273289
case ISD::SINT_TO_FP:
@@ -3664,6 +3626,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
36643626
: RISCVISD::VZEXT_VL);
36653627
case ISD::VP_TRUNC:
36663628
return lowerVectorTruncLike(Op, DAG);
3629+
case ISD::VP_FP_ROUND:
3630+
return lowerVectorFPRoundLike(Op, DAG);
36673631
case ISD::VP_FPTOSI:
36683632
return lowerVPFPIntConvOp(Op, DAG, RISCVISD::FP_TO_SINT_VL);
36693633
case ISD::VP_FPTOUI:
@@ -4430,6 +4394,67 @@ SDValue RISCVTargetLowering::lowerVectorTruncLike(SDValue Op,
44304394
return Result;
44314395
}
44324396

4397+
// Common lowering for vector ISD::FP_ROUND and ISD::VP_FP_ROUND.
// RVV can only do truncate fp to types half the size as the source. We
// custom-lower f64->f16 rounds via RVV's round-to-odd float
// conversion instruction (the round-to-odd intermediate step keeps the
// final f32->f16 rounding from introducing a double-rounding error).
SDValue RISCVTargetLowering::lowerVectorFPRoundLike(SDValue Op,
                                                    SelectionDAG &DAG) const {
  // The VP form carries explicit mask and VL operands (operands 1 and 2);
  // the plain FP_ROUND form gets defaults computed below.
  bool IsVPFPTrunc = Op.getOpcode() == ISD::VP_FP_ROUND;

  SDLoc DL(Op);
  MVT VT = Op.getSimpleValueType();

  assert(VT.isVector() && "Unexpected type for vector truncate lowering");

  SDValue Src = Op.getOperand(0);
  MVT SrcVT = Src.getSimpleValueType();

  // A single halving conversion suffices for everything except f64->f16,
  // which needs the two-step round-to-odd sequence below.
  bool IsDirectConv = VT.getVectorElementType() != MVT::f16 ||
                      SrcVT.getVectorElementType() != MVT::f64;

  // For FP_ROUND of scalable vectors, leave it to the pattern.
  if (!VT.isFixedLengthVector() && !IsVPFPTrunc && IsDirectConv)
    return Op;

  // Prepare any fixed-length vector operands.
  MVT ContainerVT = VT;
  SDValue Mask, VL;
  if (IsVPFPTrunc) {
    Mask = Op.getOperand(1);
    VL = Op.getOperand(2);
  }
  if (VT.isFixedLengthVector()) {
    // Convert the source (and, for VP, the mask) into equivalently-sized
    // scalable container types; the result is converted back at the end.
    MVT SrcContainerVT = getContainerForFixedLengthVector(SrcVT);
    ContainerVT =
        SrcContainerVT.changeVectorElementType(VT.getVectorElementType());
    Src = convertToScalableVector(SrcContainerVT, Src, DAG, Subtarget);
    if (IsVPFPTrunc) {
      MVT MaskVT =
          MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
      Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
    }
  }

  // Non-VP rounds use an all-ones mask and the default VL for the type.
  if (!IsVPFPTrunc)
    std::tie(Mask, VL) =
        getDefaultVLOps(SrcVT, ContainerVT, DL, DAG, Subtarget);

  if (IsDirectConv) {
    // One halving step: f64->f32 or f32->f16.
    Src = DAG.getNode(RISCVISD::FP_ROUND_VL, DL, ContainerVT, Src, Mask, VL);
    if (VT.isFixedLengthVector())
      Src = convertFromScalableVector(VT, Src, DAG, Subtarget);
    return Src;
  }

  // f64->f16: round-to-odd to f32 first, then a normal round to f16.
  MVT InterVT = ContainerVT.changeVectorElementType(MVT::f32);
  SDValue IntermediateRound =
      DAG.getNode(RISCVISD::VFNCVT_ROD_VL, DL, InterVT, Src, Mask, VL);
  SDValue Round = DAG.getNode(RISCVISD::FP_ROUND_VL, DL, ContainerVT,
                              IntermediateRound, Mask, VL);
  if (VT.isFixedLengthVector())
    return convertFromScalableVector(VT, Round, DAG, Subtarget);
  return Round;
}
4457+
44334458
// Custom-legalize INSERT_VECTOR_ELT so that the value is inserted into the
44344459
// first position of a vector, and that vector is slid up to the insert index.
44354460
// By limiting the active vector length to index+1 and merging with the

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,7 @@ class RISCVTargetLowering : public TargetLowering {
614614
int64_t ExtTrueVal) const;
615615
SDValue lowerVectorMaskTruncLike(SDValue Op, SelectionDAG &DAG) const;
616616
SDValue lowerVectorTruncLike(SDValue Op, SelectionDAG &DAG) const;
617+
SDValue lowerVectorFPRoundLike(SDValue Op, SelectionDAG &DAG) const;
617618
SDValue lowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
618619
SDValue lowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
619620
SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, SelectionDAG &DAG) const;

llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1591,6 +1591,13 @@ foreach fvti = AllFloatVectors in {
15911591
VLOpFrag)),
15921592
(!cast<Instruction>("PseudoVFNCVT_ROD_F_F_W_"#fvti.LMul.MX)
15931593
fwti.RegClass:$rs1, GPR:$vl, fvti.Log2SEW)>;
1594+
1595+
  // Masked round-to-odd narrowing convert (fw -> f), selecting the _MASK
  // pseudo. Used by the VP_FP_ROUND lowering when the mask is not all-ones.
  def : Pat<(fvti.Vector (riscv_fncvt_rod_vl (fwti.Vector fwti.RegClass:$rs1),
                                             (fwti.Mask V0),
                                             VLOpFrag)),
            (!cast<Instruction>("PseudoVFNCVT_ROD_F_F_W_"#fvti.LMul.MX#"_MASK")
                (fvti.Vector (IMPLICIT_DEF)), fwti.RegClass:$rs1,
                (fwti.Mask V0), GPR:$vl, fvti.Log2SEW, TAIL_AGNOSTIC)>;
15941601
}
15951602
}
15961603

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; Tests for lowering llvm.vp.fptrunc on fixed-length vectors:
; f32->f16 and f64->f32 use a single vfncvt.f.f.w; f64->f16 uses a
; two-step sequence via vfncvt.rod.f.f.w.
; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+experimental-zvfh,+v -riscv-v-vector-bits-min=128 -verify-machineinstrs < %s | FileCheck %s
; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+experimental-zvfh,+v -riscv-v-vector-bits-min=128 -verify-machineinstrs < %s | FileCheck %s

declare <2 x half> @llvm.vp.fptrunc.v2f16.v2f32(<2 x float>, <2 x i1>, i32)

define <2 x half> @vfptrunc_v2f16_v2f32(<2 x float> %a, <2 x i1> %m, i32 zeroext %vl) {
; CHECK-LABEL: vfptrunc_v2f16_v2f32:
; CHECK:       # %bb.0:
; CHECK-NEXT:    vsetvli zero, a0, e16, mf4, ta, mu
; CHECK-NEXT:    vfncvt.f.f.w v9, v8, v0.t
; CHECK-NEXT:    vmv1r.v v8, v9
; CHECK-NEXT:    ret
  %v = call <2 x half> @llvm.vp.fptrunc.v2f16.v2f32(<2 x float> %a, <2 x i1> %m, i32 %vl)
  ret <2 x half> %v
}

; Unmasked form: the all-ones mask folds away to the unmasked pseudo.
define <2 x half> @vfptrunc_v2f16_v2f32_unmasked(<2 x float> %a, i32 zeroext %vl) {
; CHECK-LABEL: vfptrunc_v2f16_v2f32_unmasked:
; CHECK:       # %bb.0:
; CHECK-NEXT:    vsetvli zero, a0, e16, mf4, ta, mu
; CHECK-NEXT:    vfncvt.f.f.w v9, v8
; CHECK-NEXT:    vmv1r.v v8, v9
; CHECK-NEXT:    ret
  %v = call <2 x half> @llvm.vp.fptrunc.v2f16.v2f32(<2 x float> %a, <2 x i1> shufflevector (<2 x i1> insertelement (<2 x i1> undef, i1 true, i32 0), <2 x i1> undef, <2 x i32> zeroinitializer), i32 %vl)
  ret <2 x half> %v
}

declare <2 x half> @llvm.vp.fptrunc.v2f16.v2f64(<2 x double>, <2 x i1>, i32)

; f64->f16 requires the round-to-odd intermediate (vfncvt.rod.f.f.w).
define <2 x half> @vfptrunc_v2f16_v2f64(<2 x double> %a, <2 x i1> %m, i32 zeroext %vl) {
; CHECK-LABEL: vfptrunc_v2f16_v2f64:
; CHECK:       # %bb.0:
; CHECK-NEXT:    vsetvli zero, a0, e32, mf2, ta, mu
; CHECK-NEXT:    vfncvt.rod.f.f.w v9, v8, v0.t
; CHECK-NEXT:    vsetvli zero, zero, e16, mf4, ta, mu
; CHECK-NEXT:    vfncvt.f.f.w v8, v9, v0.t
; CHECK-NEXT:    ret
  %v = call <2 x half> @llvm.vp.fptrunc.v2f16.v2f64(<2 x double> %a, <2 x i1> %m, i32 %vl)
  ret <2 x half> %v
}

define <2 x half> @vfptrunc_v2f16_v2f64_unmasked(<2 x double> %a, i32 zeroext %vl) {
; CHECK-LABEL: vfptrunc_v2f16_v2f64_unmasked:
; CHECK:       # %bb.0:
; CHECK-NEXT:    vsetvli zero, a0, e32, mf2, ta, mu
; CHECK-NEXT:    vfncvt.rod.f.f.w v9, v8
; CHECK-NEXT:    vsetvli zero, zero, e16, mf4, ta, mu
; CHECK-NEXT:    vfncvt.f.f.w v8, v9
; CHECK-NEXT:    ret
  %v = call <2 x half> @llvm.vp.fptrunc.v2f16.v2f64(<2 x double> %a, <2 x i1> shufflevector (<2 x i1> insertelement (<2 x i1> undef, i1 true, i32 0), <2 x i1> undef, <2 x i32> zeroinitializer), i32 %vl)
  ret <2 x half> %v
}

; NOTE(review): the mangled suffix below (.v2f64.v2f32) is in arg-then-return
; order rather than the usual return-then-arg order; AutoUpgrade remangles
; intrinsic names, so this is cosmetic — confirm against upstream convention.
declare <2 x float> @llvm.vp.fptrunc.v2f64.v2f32(<2 x double>, <2 x i1>, i32)

define <2 x float> @vfptrunc_v2f32_v2f64(<2 x double> %a, <2 x i1> %m, i32 zeroext %vl) {
; CHECK-LABEL: vfptrunc_v2f32_v2f64:
; CHECK:       # %bb.0:
; CHECK-NEXT:    vsetvli zero, a0, e32, mf2, ta, mu
; CHECK-NEXT:    vfncvt.f.f.w v9, v8, v0.t
; CHECK-NEXT:    vmv1r.v v8, v9
; CHECK-NEXT:    ret
  %v = call <2 x float> @llvm.vp.fptrunc.v2f64.v2f32(<2 x double> %a, <2 x i1> %m, i32 %vl)
  ret <2 x float> %v
}

define <2 x float> @vfptrunc_v2f32_v2f64_unmasked(<2 x double> %a, i32 zeroext %vl) {
; CHECK-LABEL: vfptrunc_v2f32_v2f64_unmasked:
; CHECK:       # %bb.0:
; CHECK-NEXT:    vsetvli zero, a0, e32, mf2, ta, mu
; CHECK-NEXT:    vfncvt.f.f.w v9, v8
; CHECK-NEXT:    vmv1r.v v8, v9
; CHECK-NEXT:    ret
  %v = call <2 x float> @llvm.vp.fptrunc.v2f64.v2f32(<2 x double> %a, <2 x i1> shufflevector (<2 x i1> insertelement (<2 x i1> undef, i1 true, i32 0), <2 x i1> undef, <2 x i32> zeroinitializer), i32 %vl)
  ret <2 x float> %v
}
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; Tests for lowering llvm.vp.fptrunc on scalable vectors:
; f32->f16 and f64->f32 use a single vfncvt.f.f.w; f64->f16 uses a
; two-step sequence via vfncvt.rod.f.f.w.
; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+experimental-zvfh,+v -verify-machineinstrs < %s | FileCheck %s
; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+experimental-zvfh,+v -verify-machineinstrs < %s | FileCheck %s

declare <vscale x 2 x half> @llvm.vp.fptrunc.nxv2f16.nxv2f32(<vscale x 2 x float>, <vscale x 2 x i1>, i32)

define <vscale x 2 x half> @vfptrunc_nxv2f16_nxv2f32(<vscale x 2 x float> %a, <vscale x 2 x i1> %m, i32 zeroext %vl) {
; CHECK-LABEL: vfptrunc_nxv2f16_nxv2f32:
; CHECK:       # %bb.0:
; CHECK-NEXT:    vsetvli zero, a0, e16, mf2, ta, mu
; CHECK-NEXT:    vfncvt.f.f.w v9, v8, v0.t
; CHECK-NEXT:    vmv1r.v v8, v9
; CHECK-NEXT:    ret
  %v = call <vscale x 2 x half> @llvm.vp.fptrunc.nxv2f16.nxv2f32(<vscale x 2 x float> %a, <vscale x 2 x i1> %m, i32 %vl)
  ret <vscale x 2 x half> %v
}

; Unmasked form: the all-ones mask folds away to the unmasked pseudo.
define <vscale x 2 x half> @vfptrunc_nxv2f16_nxv2f32_unmasked(<vscale x 2 x float> %a, i32 zeroext %vl) {
; CHECK-LABEL: vfptrunc_nxv2f16_nxv2f32_unmasked:
; CHECK:       # %bb.0:
; CHECK-NEXT:    vsetvli zero, a0, e16, mf2, ta, mu
; CHECK-NEXT:    vfncvt.f.f.w v9, v8
; CHECK-NEXT:    vmv1r.v v8, v9
; CHECK-NEXT:    ret
  %v = call <vscale x 2 x half> @llvm.vp.fptrunc.nxv2f16.nxv2f32(<vscale x 2 x float> %a, <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> undef, i1 true, i32 0), <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer), i32 %vl)
  ret <vscale x 2 x half> %v
}

declare <vscale x 2 x half> @llvm.vp.fptrunc.nxv2f16.nxv2f64(<vscale x 2 x double>, <vscale x 2 x i1>, i32)

; f64->f16 requires the round-to-odd intermediate (vfncvt.rod.f.f.w).
define <vscale x 2 x half> @vfptrunc_nxv2f16_nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x i1> %m, i32 zeroext %vl) {
; CHECK-LABEL: vfptrunc_nxv2f16_nxv2f64:
; CHECK:       # %bb.0:
; CHECK-NEXT:    vsetvli zero, a0, e32, m1, ta, mu
; CHECK-NEXT:    vfncvt.rod.f.f.w v10, v8, v0.t
; CHECK-NEXT:    vsetvli zero, zero, e16, mf2, ta, mu
; CHECK-NEXT:    vfncvt.f.f.w v8, v10, v0.t
; CHECK-NEXT:    ret
  %v = call <vscale x 2 x half> @llvm.vp.fptrunc.nxv2f16.nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x i1> %m, i32 %vl)
  ret <vscale x 2 x half> %v
}

define <vscale x 2 x half> @vfptrunc_nxv2f16_nxv2f64_unmasked(<vscale x 2 x double> %a, i32 zeroext %vl) {
; CHECK-LABEL: vfptrunc_nxv2f16_nxv2f64_unmasked:
; CHECK:       # %bb.0:
; CHECK-NEXT:    vsetvli zero, a0, e32, m1, ta, mu
; CHECK-NEXT:    vfncvt.rod.f.f.w v10, v8
; CHECK-NEXT:    vsetvli zero, zero, e16, mf2, ta, mu
; CHECK-NEXT:    vfncvt.f.f.w v8, v10
; CHECK-NEXT:    ret
  %v = call <vscale x 2 x half> @llvm.vp.fptrunc.nxv2f16.nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> undef, i1 true, i32 0), <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer), i32 %vl)
  ret <vscale x 2 x half> %v
}

; NOTE(review): the mangled suffix below (.nxv2f64.nxv2f32) is in
; arg-then-return order rather than the usual return-then-arg order;
; AutoUpgrade remangles intrinsic names, so this is cosmetic — confirm
; against upstream convention.
declare <vscale x 2 x float> @llvm.vp.fptrunc.nxv2f64.nxv2f32(<vscale x 2 x double>, <vscale x 2 x i1>, i32)

define <vscale x 2 x float> @vfptrunc_nxv2f32_nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x i1> %m, i32 zeroext %vl) {
; CHECK-LABEL: vfptrunc_nxv2f32_nxv2f64:
; CHECK:       # %bb.0:
; CHECK-NEXT:    vsetvli zero, a0, e32, m1, ta, mu
; CHECK-NEXT:    vfncvt.f.f.w v10, v8, v0.t
; CHECK-NEXT:    vmv.v.v v8, v10
; CHECK-NEXT:    ret
  %v = call <vscale x 2 x float> @llvm.vp.fptrunc.nxv2f64.nxv2f32(<vscale x 2 x double> %a, <vscale x 2 x i1> %m, i32 %vl)
  ret <vscale x 2 x float> %v
}

define <vscale x 2 x float> @vfptrunc_nxv2f32_nxv2f64_unmasked(<vscale x 2 x double> %a, i32 zeroext %vl) {
; CHECK-LABEL: vfptrunc_nxv2f32_nxv2f64_unmasked:
; CHECK:       # %bb.0:
; CHECK-NEXT:    vsetvli zero, a0, e32, m1, ta, mu
; CHECK-NEXT:    vfncvt.f.f.w v10, v8
; CHECK-NEXT:    vmv.v.v v8, v10
; CHECK-NEXT:    ret
  %v = call <vscale x 2 x float> @llvm.vp.fptrunc.nxv2f64.nxv2f32(<vscale x 2 x double> %a, <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> undef, i1 true, i32 0), <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer), i32 %vl)
  ret <vscale x 2 x float> %v
}

0 commit comments

Comments
 (0)