Skip to content

[LoongArch] Make ISD::FSQRT a legal operation with lsx/lasx feature #74795

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
setOperationAction({ISD::FADD, ISD::FSUB}, VT, Legal);
setOperationAction({ISD::FMUL, ISD::FDIV}, VT, Legal);
setOperationAction(ISD::FMA, VT, Legal);
setOperationAction(ISD::FSQRT, VT, Legal);
setCondCodeAction({ISD::SETGE, ISD::SETGT, ISD::SETOGE, ISD::SETOGT,
ISD::SETUGE, ISD::SETUGT},
VT, Expand);
Expand Down Expand Up @@ -309,6 +310,7 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
setOperationAction({ISD::FADD, ISD::FSUB}, VT, Legal);
setOperationAction({ISD::FMUL, ISD::FDIV}, VT, Legal);
setOperationAction(ISD::FMA, VT, Legal);
setOperationAction(ISD::FSQRT, VT, Legal);
setCondCodeAction({ISD::SETGE, ISD::SETGT, ISD::SETOGE, ISD::SETOGT,
ISD::SETUGE, ISD::SETUGT},
VT, Expand);
Expand Down
22 changes: 22 additions & 0 deletions llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -1092,6 +1092,13 @@ multiclass PatXr<SDPatternOperator OpNode, string Inst> {
(!cast<LAInst>(Inst#"_D") LASX256:$xj)>;
}

// Selection patterns for a unary floating-point operation on 256-bit LASX
// vectors. For a given OpNode this defines two patterns:
//   v8f32 operand/result -> the instruction named Inst#"_S"
//   v4f64 operand/result -> the instruction named Inst#"_D"
// Both take a single LASX256 register operand ($xj).
multiclass PatXrF<SDPatternOperator OpNode, string Inst> {
def : Pat<(v8f32 (OpNode (v8f32 LASX256:$xj))),
(!cast<LAInst>(Inst#"_S") LASX256:$xj)>;
def : Pat<(v4f64 (OpNode (v4f64 LASX256:$xj))),
(!cast<LAInst>(Inst#"_D") LASX256:$xj)>;
}

multiclass PatXrXr<SDPatternOperator OpNode, string Inst> {
def : Pat<(OpNode (v32i8 LASX256:$xj), (v32i8 LASX256:$xk)),
(!cast<LAInst>(Inst#"_B") LASX256:$xj, LASX256:$xk)>;
Expand Down Expand Up @@ -1448,6 +1455,21 @@ def : Pat<(fma v8f32:$xj, v8f32:$xk, v8f32:$xa),
def : Pat<(fma v4f64:$xj, v4f64:$xk, v4f64:$xa),
(XVFMADD_D v4f64:$xj, v4f64:$xk, v4f64:$xa)>;

// XVFSQRT_{S/D}
// fsqrt on v8f32 / v4f64 is selected through the PatXrF multiclass.
defm : PatXrF<fsqrt, "XVFSQRT">;

// XVFRECIP_{S/D}
// An fdiv whose dividend matches vsplatf{32,64}_fpimm_eq_1 (a splat of 1.0)
// is selected as a single reciprocal instruction on the divisor.
// NOTE(review): no fast-math flag is required by these patterns; this assumes
// the instruction is as accurate as a real division - confirm against the ISA.
def : Pat<(fdiv vsplatf32_fpimm_eq_1, v8f32:$xj),
(XVFRECIP_S v8f32:$xj)>;
def : Pat<(fdiv vsplatf64_fpimm_eq_1, v4f64:$xj),
(XVFRECIP_D v4f64:$xj)>;

// XVFRSQRT_{S/D}
// (splat 1.0) / fsqrt(x) is folded into a single reciprocal-square-root
// instruction operating directly on x.
// NOTE(review): same accuracy assumption as for XVFRECIP above - confirm.
def : Pat<(fdiv vsplatf32_fpimm_eq_1, (fsqrt v8f32:$xj)),
(XVFRSQRT_S v8f32:$xj)>;
def : Pat<(fdiv vsplatf64_fpimm_eq_1, (fsqrt v4f64:$xj)),
(XVFRSQRT_D v4f64:$xj)>;

// XVSEQ[I]_{B/H/W/D}
defm : PatCCXrSimm5<SETEQ, "XVSEQI">;
defm : PatCCXrXr<SETEQ, "XVSEQ">;
Expand Down
45 changes: 45 additions & 0 deletions llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,29 @@ def vsplati64_imm_eq_63 : PatFrags<(ops), [(build_vector),
Imm.getBitWidth() == EltTy.getSizeInBits() && Imm == 63;
}]>;

// Matches a bitconvert of a v4i32 or v8i32 build_vector for which the
// predicate below holds, i.e. the splat value reported by selectVSplat has the
// bit pattern of the single-precision constant +1.0f.
// NOTE(review): selectVSplat is defined elsewhere; presumably it succeeds only
// for constant splats and stores the splatted element in Imm - confirm.
def vsplatf32_fpimm_eq_1
: PatFrags<(ops), [(bitconvert (v4i32 (build_vector))),
(bitconvert (v8i32 (build_vector)))], [{
APInt Imm;
// The element type is read from the bitconvert's result type, before N is
// redirected to the bitconvert's operand.
EVT EltTy = N->getValueType(0).getVectorElementType();
// Step through the bitconvert to the node it wraps.
N = N->getOperand(0).getNode();

// Require a splat whose width equals the element width and whose bits equal
// those of +1.0f.
return selectVSplat(N, Imm, EltTy.getSizeInBits()) &&
Imm.getBitWidth() == EltTy.getSizeInBits() &&
Imm == APFloat(+1.0f).bitcastToAPInt();
}]>;
// Matches a bitconvert of a v2i64 or v4i64 build_vector for which the
// predicate below holds, i.e. the splat value reported by selectVSplat has the
// bit pattern of the double-precision constant +1.0.
// NOTE(review): selectVSplat is defined elsewhere; presumably it succeeds only
// for constant splats and stores the splatted element in Imm - confirm.
def vsplatf64_fpimm_eq_1
: PatFrags<(ops), [(bitconvert (v2i64 (build_vector))),
(bitconvert (v4i64 (build_vector)))], [{
APInt Imm;
// The element type is read from the bitconvert's result type, before N is
// redirected to the bitconvert's operand.
EVT EltTy = N->getValueType(0).getVectorElementType();
// Step through the bitconvert to the node it wraps.
N = N->getOperand(0).getNode();

// Require a splat whose width equals the element width and whose bits equal
// those of +1.0 (double).
return selectVSplat(N, Imm, EltTy.getSizeInBits()) &&
Imm.getBitWidth() == EltTy.getSizeInBits() &&
Imm == APFloat(+1.0).bitcastToAPInt();
}]>;

def vsplati8imm7 : PatFrag<(ops node:$reg),
(and node:$reg, vsplati8_imm_eq_7)>;
def vsplati16imm15 : PatFrag<(ops node:$reg),
Expand Down Expand Up @@ -1173,6 +1196,13 @@ multiclass PatVr<SDPatternOperator OpNode, string Inst> {
(!cast<LAInst>(Inst#"_D") LSX128:$vj)>;
}

// Selection patterns for a unary floating-point operation on 128-bit LSX
// vectors. For a given OpNode this defines two patterns:
//   v4f32 operand/result -> the instruction named Inst#"_S"
//   v2f64 operand/result -> the instruction named Inst#"_D"
// Both take a single LSX128 register operand ($vj).
multiclass PatVrF<SDPatternOperator OpNode, string Inst> {
def : Pat<(v4f32 (OpNode (v4f32 LSX128:$vj))),
(!cast<LAInst>(Inst#"_S") LSX128:$vj)>;
def : Pat<(v2f64 (OpNode (v2f64 LSX128:$vj))),
(!cast<LAInst>(Inst#"_D") LSX128:$vj)>;
}

multiclass PatVrVr<SDPatternOperator OpNode, string Inst> {
def : Pat<(OpNode (v16i8 LSX128:$vj), (v16i8 LSX128:$vk)),
(!cast<LAInst>(Inst#"_B") LSX128:$vj, LSX128:$vk)>;
Expand Down Expand Up @@ -1525,6 +1555,21 @@ def : Pat<(fma v4f32:$vj, v4f32:$vk, v4f32:$va),
def : Pat<(fma v2f64:$vj, v2f64:$vk, v2f64:$va),
(VFMADD_D v2f64:$vj, v2f64:$vk, v2f64:$va)>;

// VFSQRT_{S/D}
// fsqrt on v4f32 / v2f64 is selected through the PatVrF multiclass.
defm : PatVrF<fsqrt, "VFSQRT">;

// VFRECIP_{S/D}
// An fdiv whose dividend matches vsplatf{32,64}_fpimm_eq_1 (a splat of 1.0)
// is selected as a single reciprocal instruction on the divisor.
// NOTE(review): no fast-math flag is required by these patterns; this assumes
// the instruction is as accurate as a real division - confirm against the ISA.
def : Pat<(fdiv vsplatf32_fpimm_eq_1, v4f32:$vj),
(VFRECIP_S v4f32:$vj)>;
def : Pat<(fdiv vsplatf64_fpimm_eq_1, v2f64:$vj),
(VFRECIP_D v2f64:$vj)>;

// VFRSQRT_{S/D}
// (splat 1.0) / fsqrt(x) is folded into a single reciprocal-square-root
// instruction operating directly on x.
// NOTE(review): same accuracy assumption as for VFRECIP above - confirm.
def : Pat<(fdiv vsplatf32_fpimm_eq_1, (fsqrt v4f32:$vj)),
(VFRSQRT_S v4f32:$vj)>;
def : Pat<(fdiv vsplatf64_fpimm_eq_1, (fsqrt v2f64:$vj)),
(VFRSQRT_D v2f64:$vj)>;

// VSEQ[I]_{B/H/W/D}
defm : PatCCVrSimm5<SETEQ, "VSEQI">;
defm : PatCCVrVr<SETEQ, "VSEQ">;
Expand Down
65 changes: 65 additions & 0 deletions llvm/test/CodeGen/LoongArch/lasx/fsqrt.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc --mtriple=loongarch64 --mattr=+lasx < %s | FileCheck %s

;; fsqrt
;; llvm.sqrt on <8 x float>: the CHECK lines expect exactly one xvfsqrt.s
;; between the vector load from %a0 and the vector store to %res.
define void @sqrt_v8f32(ptr %res, ptr %a0) nounwind {
; CHECK-LABEL: sqrt_v8f32:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: xvld $xr0, $a1, 0
; CHECK-NEXT: xvfsqrt.s $xr0, $xr0
; CHECK-NEXT: xvst $xr0, $a0, 0
; CHECK-NEXT: ret
entry:
%v0 = load <8 x float>, ptr %a0, align 16
%sqrt = call <8 x float> @llvm.sqrt.v8f32 (<8 x float> %v0)
store <8 x float> %sqrt, ptr %res, align 16
ret void
}

;; llvm.sqrt on <4 x double>: the CHECK lines expect exactly one xvfsqrt.d
;; between the vector load from %a0 and the vector store to %res.
define void @sqrt_v4f64(ptr %res, ptr %a0) nounwind {
; CHECK-LABEL: sqrt_v4f64:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: xvld $xr0, $a1, 0
; CHECK-NEXT: xvfsqrt.d $xr0, $xr0
; CHECK-NEXT: xvst $xr0, $a0, 0
; CHECK-NEXT: ret
entry:
%v0 = load <4 x double>, ptr %a0, align 16
%sqrt = call <4 x double> @llvm.sqrt.v4f64 (<4 x double> %v0)
store <4 x double> %sqrt, ptr %res, align 16
ret void
}

;; 1.0 / (fsqrt vec)
;; Divides an all-1.0 <8 x float> constant by llvm.sqrt(%v0): the CHECK lines
;; expect the sqrt + fdiv pair to become a single xvfrsqrt.s.
define void @one_div_sqrt_v8f32(ptr %res, ptr %a0) nounwind {
; CHECK-LABEL: one_div_sqrt_v8f32:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: xvld $xr0, $a1, 0
; CHECK-NEXT: xvfrsqrt.s $xr0, $xr0
; CHECK-NEXT: xvst $xr0, $a0, 0
; CHECK-NEXT: ret
entry:
%v0 = load <8 x float>, ptr %a0, align 16
%sqrt = call <8 x float> @llvm.sqrt.v8f32 (<8 x float> %v0)
%div = fdiv <8 x float> <float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0>, %sqrt
store <8 x float> %div, ptr %res, align 16
ret void
}

;; Divides an all-1.0 <4 x double> constant by llvm.sqrt(%v0): the CHECK lines
;; expect the sqrt + fdiv pair to become a single xvfrsqrt.d.
define void @one_div_sqrt_v4f64(ptr %res, ptr %a0) nounwind {
; CHECK-LABEL: one_div_sqrt_v4f64:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: xvld $xr0, $a1, 0
; CHECK-NEXT: xvfrsqrt.d $xr0, $xr0
; CHECK-NEXT: xvst $xr0, $a0, 0
; CHECK-NEXT: ret
entry:
%v0 = load <4 x double>, ptr %a0, align 16
%sqrt = call <4 x double> @llvm.sqrt.v4f64 (<4 x double> %v0)
%div = fdiv <4 x double> <double 1.0, double 1.0, double 1.0, double 1.0>, %sqrt
store <4 x double> %div, ptr %res, align 16
ret void
}

declare <8 x float> @llvm.sqrt.v8f32(<8 x float>)
declare <4 x double> @llvm.sqrt.v4f64(<4 x double>)
29 changes: 29 additions & 0 deletions llvm/test/CodeGen/LoongArch/lasx/ir-instruction/fdiv.ll
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,32 @@ entry:
store <4 x double> %v2, ptr %res
ret void
}

;; 1.0 / vec
;; Divides an all-1.0 <8 x float> constant by the loaded vector: the CHECK
;; lines expect the fdiv to become a single xvfrecip.s.
define void @one_fdiv_v8f32(ptr %res, ptr %a0) nounwind {
; CHECK-LABEL: one_fdiv_v8f32:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: xvld $xr0, $a1, 0
; CHECK-NEXT: xvfrecip.s $xr0, $xr0
; CHECK-NEXT: xvst $xr0, $a0, 0
; CHECK-NEXT: ret
entry:
%v0 = load <8 x float>, ptr %a0
%div = fdiv <8 x float> <float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0>, %v0
store <8 x float> %div, ptr %res
ret void
}

;; Divides an all-1.0 <4 x double> constant by the loaded vector: the CHECK
;; lines expect the fdiv to become a single xvfrecip.d.
define void @one_fdiv_v4f64(ptr %res, ptr %a0) nounwind {
; CHECK-LABEL: one_fdiv_v4f64:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: xvld $xr0, $a1, 0
; CHECK-NEXT: xvfrecip.d $xr0, $xr0
; CHECK-NEXT: xvst $xr0, $a0, 0
; CHECK-NEXT: ret
entry:
%v0 = load <4 x double>, ptr %a0
%div = fdiv <4 x double> <double 1.0, double 1.0, double 1.0, double 1.0>, %v0
store <4 x double> %div, ptr %res
ret void
}
65 changes: 65 additions & 0 deletions llvm/test/CodeGen/LoongArch/lsx/fsqrt.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc --mtriple=loongarch64 --mattr=+lsx < %s | FileCheck %s

;; fsqrt
;; llvm.sqrt on <4 x float>: the CHECK lines expect exactly one vfsqrt.s
;; between the vector load from %a0 and the vector store to %res.
define void @sqrt_v4f32(ptr %res, ptr %a0) nounwind {
; CHECK-LABEL: sqrt_v4f32:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: vld $vr0, $a1, 0
; CHECK-NEXT: vfsqrt.s $vr0, $vr0
; CHECK-NEXT: vst $vr0, $a0, 0
; CHECK-NEXT: ret
entry:
%v0 = load <4 x float>, ptr %a0, align 16
%sqrt = call <4 x float> @llvm.sqrt.v4f32 (<4 x float> %v0)
store <4 x float> %sqrt, ptr %res, align 16
ret void
}

;; llvm.sqrt on <2 x double>: the CHECK lines expect exactly one vfsqrt.d
;; between the vector load from %a0 and the vector store to %res.
define void @sqrt_v2f64(ptr %res, ptr %a0) nounwind {
; CHECK-LABEL: sqrt_v2f64:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: vld $vr0, $a1, 0
; CHECK-NEXT: vfsqrt.d $vr0, $vr0
; CHECK-NEXT: vst $vr0, $a0, 0
; CHECK-NEXT: ret
entry:
%v0 = load <2 x double>, ptr %a0, align 16
%sqrt = call <2 x double> @llvm.sqrt.v2f64 (<2 x double> %v0)
store <2 x double> %sqrt, ptr %res, align 16
ret void
}

;; 1.0 / (fsqrt vec)
;; Divides an all-1.0 <4 x float> constant by llvm.sqrt(%v0): the CHECK lines
;; expect the sqrt + fdiv pair to become a single vfrsqrt.s.
define void @one_div_sqrt_v4f32(ptr %res, ptr %a0) nounwind {
; CHECK-LABEL: one_div_sqrt_v4f32:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: vld $vr0, $a1, 0
; CHECK-NEXT: vfrsqrt.s $vr0, $vr0
; CHECK-NEXT: vst $vr0, $a0, 0
; CHECK-NEXT: ret
entry:
%v0 = load <4 x float>, ptr %a0, align 16
%sqrt = call <4 x float> @llvm.sqrt.v4f32 (<4 x float> %v0)
%div = fdiv <4 x float> <float 1.0, float 1.0, float 1.0, float 1.0>, %sqrt
store <4 x float> %div, ptr %res, align 16
ret void
}

;; Divides an all-1.0 <2 x double> constant by llvm.sqrt(%v0): the CHECK lines
;; expect the sqrt + fdiv pair to become a single vfrsqrt.d.
define void @one_div_sqrt_v2f64(ptr %res, ptr %a0) nounwind {
; CHECK-LABEL: one_div_sqrt_v2f64:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: vld $vr0, $a1, 0
; CHECK-NEXT: vfrsqrt.d $vr0, $vr0
; CHECK-NEXT: vst $vr0, $a0, 0
; CHECK-NEXT: ret
entry:
%v0 = load <2 x double>, ptr %a0, align 16
%sqrt = call <2 x double> @llvm.sqrt.v2f64 (<2 x double> %v0)
%div = fdiv <2 x double> <double 1.0, double 1.0>, %sqrt
store <2 x double> %div, ptr %res, align 16
ret void
}

declare <4 x float> @llvm.sqrt.v4f32(<4 x float>)
declare <2 x double> @llvm.sqrt.v2f64(<2 x double>)
29 changes: 29 additions & 0 deletions llvm/test/CodeGen/LoongArch/lsx/ir-instruction/fdiv.ll
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,32 @@ entry:
store <2 x double> %v2, ptr %res
ret void
}

;; 1.0 / vec
;; Divides an all-1.0 <4 x float> constant by the loaded vector: the CHECK
;; lines expect the fdiv to become a single vfrecip.s.
define void @one_fdiv_v4f32(ptr %res, ptr %a0) nounwind {
; CHECK-LABEL: one_fdiv_v4f32:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: vld $vr0, $a1, 0
; CHECK-NEXT: vfrecip.s $vr0, $vr0
; CHECK-NEXT: vst $vr0, $a0, 0
; CHECK-NEXT: ret
entry:
%v0 = load <4 x float>, ptr %a0
%div = fdiv <4 x float> <float 1.0, float 1.0, float 1.0, float 1.0>, %v0
store <4 x float> %div, ptr %res
ret void
}

;; Divides an all-1.0 <2 x double> constant by the loaded vector: the CHECK
;; lines expect the fdiv to become a single vfrecip.d.
define void @one_fdiv_v2f64(ptr %res, ptr %a0) nounwind {
; CHECK-LABEL: one_fdiv_v2f64:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: vld $vr0, $a1, 0
; CHECK-NEXT: vfrecip.d $vr0, $vr0
; CHECK-NEXT: vst $vr0, $a0, 0
; CHECK-NEXT: ret
entry:
%v0 = load <2 x double>, ptr %a0
%div = fdiv <2 x double> <double 1.0, double 1.0>, %v0
store <2 x double> %div, ptr %res
ret void
}