Skip to content

Commit 3665837

Browse files
committed
[RISCV] Add support for fixed vector sqrt.
1 parent 8bd8534 commit 3665837

File tree

4 files changed

+84
-24
lines changed

4 files changed

+84
-24
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
577577
setOperationAction(ISD::FMUL, VT, Custom);
578578
setOperationAction(ISD::FDIV, VT, Custom);
579579
setOperationAction(ISD::FNEG, VT, Custom);
580+
setOperationAction(ISD::FSQRT, VT, Custom);
580581
setOperationAction(ISD::FMA, VT, Custom);
581582
}
582583
}
@@ -1209,6 +1210,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
12091210
return lowerToScalableOp(Op, DAG, RISCVISD::FDIV_VL);
12101211
case ISD::FNEG:
12111212
return lowerToScalableOp(Op, DAG, RISCVISD::FNEG_VL);
1213+
case ISD::FSQRT:
1214+
return lowerToScalableOp(Op, DAG, RISCVISD::FSQRT_VL);
12121215
case ISD::FMA:
12131216
return lowerToScalableOp(Op, DAG, RISCVISD::FMA_VL);
12141217
case ISD::SMIN:
@@ -4739,6 +4742,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
47394742
NODE_NAME_CASE(FMUL_VL)
47404743
NODE_NAME_CASE(FDIV_VL)
47414744
NODE_NAME_CASE(FNEG_VL)
4745+
NODE_NAME_CASE(FSQRT_VL)
47424746
NODE_NAME_CASE(FMA_VL)
47434747
NODE_NAME_CASE(SMIN_VL)
47444748
NODE_NAME_CASE(SMAX_VL)

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ enum NodeType : unsigned {
162162
FMUL_VL,
163163
FDIV_VL,
164164
FNEG_VL,
165+
FSQRT_VL,
165166
FMA_VL,
166167
SMIN_VL,
167168
SMAX_VL,

llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -51,28 +51,29 @@ def riscv_vle_vl : SDNode<"RISCVISD::VLE_VL", SDT_RISCVVLE_VL,
5151
def riscv_vse_vl : SDNode<"RISCVISD::VSE_VL", SDT_RISCVVSE_VL,
5252
[SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
5353

54-
def riscv_add_vl : SDNode<"RISCVISD::ADD_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
55-
def riscv_sub_vl : SDNode<"RISCVISD::SUB_VL", SDT_RISCVIntBinOp_VL>;
56-
def riscv_mul_vl : SDNode<"RISCVISD::MUL_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
57-
def riscv_and_vl : SDNode<"RISCVISD::AND_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
58-
def riscv_or_vl : SDNode<"RISCVISD::OR_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
59-
def riscv_xor_vl : SDNode<"RISCVISD::XOR_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
60-
def riscv_sdiv_vl : SDNode<"RISCVISD::SDIV_VL", SDT_RISCVIntBinOp_VL>;
61-
def riscv_srem_vl : SDNode<"RISCVISD::SREM_VL", SDT_RISCVIntBinOp_VL>;
62-
def riscv_udiv_vl : SDNode<"RISCVISD::UDIV_VL", SDT_RISCVIntBinOp_VL>;
63-
def riscv_urem_vl : SDNode<"RISCVISD::UREM_VL", SDT_RISCVIntBinOp_VL>;
64-
def riscv_shl_vl : SDNode<"RISCVISD::SHL_VL", SDT_RISCVIntBinOp_VL>;
65-
def riscv_sra_vl : SDNode<"RISCVISD::SRA_VL", SDT_RISCVIntBinOp_VL>;
66-
def riscv_srl_vl : SDNode<"RISCVISD::SRL_VL", SDT_RISCVIntBinOp_VL>;
67-
def riscv_smin_vl : SDNode<"RISCVISD::SMIN_VL", SDT_RISCVIntBinOp_VL>;
68-
def riscv_smax_vl : SDNode<"RISCVISD::SMAX_VL", SDT_RISCVIntBinOp_VL>;
69-
def riscv_umin_vl : SDNode<"RISCVISD::UMIN_VL", SDT_RISCVIntBinOp_VL>;
70-
def riscv_umax_vl : SDNode<"RISCVISD::UMAX_VL", SDT_RISCVIntBinOp_VL>;
71-
def riscv_fadd_vl : SDNode<"RISCVISD::FADD_VL", SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
72-
def riscv_fsub_vl : SDNode<"RISCVISD::FSUB_VL", SDT_RISCVFPBinOp_VL>;
73-
def riscv_fmul_vl : SDNode<"RISCVISD::FMUL_VL", SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
74-
def riscv_fdiv_vl : SDNode<"RISCVISD::FDIV_VL", SDT_RISCVFPBinOp_VL>;
75-
def riscv_fneg_vl : SDNode<"RISCVISD::FNEG_VL", SDT_RISCVFPUnOp_VL>;
54+
def riscv_add_vl : SDNode<"RISCVISD::ADD_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
55+
def riscv_sub_vl : SDNode<"RISCVISD::SUB_VL", SDT_RISCVIntBinOp_VL>;
56+
def riscv_mul_vl : SDNode<"RISCVISD::MUL_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
57+
def riscv_and_vl : SDNode<"RISCVISD::AND_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
58+
def riscv_or_vl : SDNode<"RISCVISD::OR_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
59+
def riscv_xor_vl : SDNode<"RISCVISD::XOR_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
60+
def riscv_sdiv_vl : SDNode<"RISCVISD::SDIV_VL", SDT_RISCVIntBinOp_VL>;
61+
def riscv_srem_vl : SDNode<"RISCVISD::SREM_VL", SDT_RISCVIntBinOp_VL>;
62+
def riscv_udiv_vl : SDNode<"RISCVISD::UDIV_VL", SDT_RISCVIntBinOp_VL>;
63+
def riscv_urem_vl : SDNode<"RISCVISD::UREM_VL", SDT_RISCVIntBinOp_VL>;
64+
def riscv_shl_vl : SDNode<"RISCVISD::SHL_VL", SDT_RISCVIntBinOp_VL>;
65+
def riscv_sra_vl : SDNode<"RISCVISD::SRA_VL", SDT_RISCVIntBinOp_VL>;
66+
def riscv_srl_vl : SDNode<"RISCVISD::SRL_VL", SDT_RISCVIntBinOp_VL>;
67+
def riscv_smin_vl : SDNode<"RISCVISD::SMIN_VL", SDT_RISCVIntBinOp_VL>;
68+
def riscv_smax_vl : SDNode<"RISCVISD::SMAX_VL", SDT_RISCVIntBinOp_VL>;
69+
def riscv_umin_vl : SDNode<"RISCVISD::UMIN_VL", SDT_RISCVIntBinOp_VL>;
70+
def riscv_umax_vl : SDNode<"RISCVISD::UMAX_VL", SDT_RISCVIntBinOp_VL>;
71+
def riscv_fadd_vl : SDNode<"RISCVISD::FADD_VL", SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
72+
def riscv_fsub_vl : SDNode<"RISCVISD::FSUB_VL", SDT_RISCVFPBinOp_VL>;
73+
def riscv_fmul_vl : SDNode<"RISCVISD::FMUL_VL", SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
74+
def riscv_fdiv_vl : SDNode<"RISCVISD::FDIV_VL", SDT_RISCVFPBinOp_VL>;
75+
def riscv_fneg_vl : SDNode<"RISCVISD::FNEG_VL", SDT_RISCVFPUnOp_VL>;
76+
def riscv_fsqrt_vl : SDNode<"RISCVISD::FSQRT_VL", SDT_RISCVFPUnOp_VL>;
7677

7778
def SDT_RISCVVecFMA_VL : SDTypeProfile<1, 5, [SDTCisSameAs<0, 1>,
7879
SDTCisSameAs<0, 2>,
@@ -440,9 +441,15 @@ foreach vti = AllFloatVectors in {
440441
GPR:$vl, vti.SEW)>;
441442
}
442443

443-
// 14.12. Vector Floating-Point Sign-Injection Instructions
444-
// Handle fneg with VFSGNJN using the same input for both operands.
445444
foreach vti = AllFloatVectors in {
445+
// 14.8. Vector Floating-Point Square-Root Instruction
446+
def : Pat<(riscv_fsqrt_vl (vti.Vector vti.RegClass:$rs2), (vti.Mask true_mask),
447+
(XLenVT (VLOp GPR:$vl))),
448+
(!cast<Instruction>("PseudoVFSQRT_V_"# vti.LMul.MX)
449+
vti.RegClass:$rs2, GPR:$vl, vti.SEW)>;
450+
451+
// 14.12. Vector Floating-Point Sign-Injection Instructions
452+
// Handle fneg with VFSGNJN using the same input for both operands.
446453
def : Pat<(riscv_fneg_vl (vti.Vector vti.RegClass:$rs), (vti.Mask true_mask),
447454
(XLenVT (VLOp GPR:$vl))),
448455
(!cast<Instruction>("PseudoVFSGNJN_VV_"# vti.LMul.MX)

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp.ll

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,54 @@ define void @fneg_v2f64(<2 x double>* %x) {
253253
ret void
254254
}
255255

256+
define void @sqrt_v8f16(<8 x half>* %x) {
257+
; CHECK-LABEL: sqrt_v8f16:
258+
; CHECK: # %bb.0:
259+
; CHECK-NEXT: addi a1, zero, 8
260+
; CHECK-NEXT: vsetvli a1, a1, e16,m1,ta,mu
261+
; CHECK-NEXT: vle16.v v25, (a0)
262+
; CHECK-NEXT: vfsqrt.v v25, v25
263+
; CHECK-NEXT: vse16.v v25, (a0)
264+
; CHECK-NEXT: ret
265+
%a = load <8 x half>, <8 x half>* %x
266+
%b = call <8 x half> @llvm.sqrt.v8f16(<8 x half> %a)
267+
store <8 x half> %b, <8 x half>* %x
268+
ret void
269+
}
270+
declare <8 x half> @llvm.sqrt.v8f16(<8 x half>)
271+
272+
define void @sqrt_v4f32(<4 x float>* %x) {
273+
; CHECK-LABEL: sqrt_v4f32:
274+
; CHECK: # %bb.0:
275+
; CHECK-NEXT: addi a1, zero, 4
276+
; CHECK-NEXT: vsetvli a1, a1, e32,m1,ta,mu
277+
; CHECK-NEXT: vle32.v v25, (a0)
278+
; CHECK-NEXT: vfsqrt.v v25, v25
279+
; CHECK-NEXT: vse32.v v25, (a0)
280+
; CHECK-NEXT: ret
281+
%a = load <4 x float>, <4 x float>* %x
282+
%b = call <4 x float> @llvm.sqrt.v4f32(<4 x float> %a)
283+
store <4 x float> %b, <4 x float>* %x
284+
ret void
285+
}
286+
declare <4 x float> @llvm.sqrt.v4f32(<4 x float>)
287+
288+
define void @sqrt_v2f64(<2 x double>* %x) {
289+
; CHECK-LABEL: sqrt_v2f64:
290+
; CHECK: # %bb.0:
291+
; CHECK-NEXT: addi a1, zero, 2
292+
; CHECK-NEXT: vsetvli a1, a1, e64,m1,ta,mu
293+
; CHECK-NEXT: vle64.v v25, (a0)
294+
; CHECK-NEXT: vfsqrt.v v25, v25
295+
; CHECK-NEXT: vse64.v v25, (a0)
296+
; CHECK-NEXT: ret
297+
%a = load <2 x double>, <2 x double>* %x
298+
%b = call <2 x double> @llvm.sqrt.v2f64(<2 x double> %a)
299+
store <2 x double> %b, <2 x double>* %x
300+
ret void
301+
}
302+
declare <2 x double> @llvm.sqrt.v2f64(<2 x double>)
303+
256304
define void @fma_v8f16(<8 x half>* %x, <8 x half>* %y, <8 x half>* %z) {
257305
; CHECK-LABEL: fma_v8f16:
258306
; CHECK: # %bb.0:

0 commit comments

Comments
 (0)