Skip to content

[LLVM][CodeGen] Add lowering for scalable vector bfloat operations. #109803

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 7, 2024

Conversation

paulwalker-arm
Copy link
Collaborator

Specifically:
fabs, fadd, fceil, fdiv, ffloor, fma, fmax, fmaxnm, fmin, fminnm,
fmul, fnearbyint, fneg, frint, fround, froundeven, fsub, fsqrt &
ftrunc

@llvmbot llvmbot added backend:AArch64 llvm:SelectionDAG SelectionDAGISel as well labels Sep 24, 2024
@llvmbot
Copy link
Member

llvmbot commented Sep 24, 2024

@llvm/pr-subscribers-backend-amdgpu
@llvm/pr-subscribers-llvm-selectiondag

@llvm/pr-subscribers-backend-aarch64

Author: Paul Walker (paulwalker-arm)

Changes

Specifically:
fabs, fadd, fceil, fdiv, ffloor, fma, fmax, fmaxnm, fmin, fminnm,
fmul, fnearbyint, fneg, frint, fround, froundeven, fsub, fsqrt &
ftrunc


Patch is 51.46 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/109803.diff

8 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/SelectionDAG.h (+6)
  • (modified) llvm/include/llvm/CodeGen/SelectionDAGNodes.h (+3)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+88)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.h (+1)
  • (modified) llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td (+7)
  • (modified) llvm/lib/Target/AArch64/SVEInstrFormats.td (+6)
  • (added) llvm/test/CodeGen/AArch64/sve-bf16-arith.ll (+752)
  • (added) llvm/test/CodeGen/AArch64/sve-bf16-rounding.ll (+355)
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index d6c2c36a0d482a..c7e0c704efceff 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1565,6 +1565,12 @@ class SelectionDAG {
   SDValue getSetFPEnv(SDValue Chain, const SDLoc &dl, SDValue Ptr, EVT MemVT,
                       MachineMemOperand *MMO);
 
+  SDValue getExtractSubvector(const SDLoc &DL, EVT VT, SDValue V,
+                              uint64_t Idx) {
+    return getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, V,
+                   getVectorIdxConstant(Idx, DL));
+  }
+
   /// Construct a node to track a Value* through the backend.
   SDValue getSrcValue(const Value *v);
 
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 6067b3b29ea181..82bba661dba0f9 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -191,6 +191,9 @@ class SDValue {
     return getValueType().getSimpleVT();
   }
 
+  /// Return the scalar ValueType of the referenced return value.
+  EVT getScalarValueType() const { return getValueType().getScalarType(); }
+
   /// Returns the size of the value in bits.
   ///
   /// If the value type is a scalable vector type, the scalable property will
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 4166d9bd22bc01..c77d9631b5ffab 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1663,12 +1663,32 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) {
       setOperationAction(ISD::BITCAST, VT, Custom);
       setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
+      setOperationAction(ISD::FCEIL, VT, Custom);
+      setOperationAction(ISD::FDIV, VT, Custom);
+      setOperationAction(ISD::FFLOOR, VT, Custom);
+      setOperationAction(ISD::FMA, VT, Custom);
+      setOperationAction(ISD::FMAXIMUM, VT, Custom);
+      setOperationAction(ISD::FMAXNUM, VT, Custom);
+      setOperationAction(ISD::FMINIMUM, VT, Custom);
+      setOperationAction(ISD::FMINNUM, VT, Custom);
+      setOperationAction(ISD::FNEARBYINT, VT, Custom);
       setOperationAction(ISD::FP_EXTEND, VT, Custom);
       setOperationAction(ISD::FP_ROUND, VT, Custom);
+      setOperationAction(ISD::FRINT, VT, Custom);
+      setOperationAction(ISD::FROUND, VT, Custom);
+      setOperationAction(ISD::FROUNDEVEN, VT, Custom);
+      setOperationAction(ISD::FSQRT, VT, Custom);
+      setOperationAction(ISD::FTRUNC, VT, Custom);
       setOperationAction(ISD::MLOAD, VT, Custom);
       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
       setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
       setOperationAction(ISD::VECTOR_SPLICE, VT, Custom);
+
+      if (!Subtarget->hasSVEB16B16()) {
+        setOperationAction(ISD::FADD, VT, Custom);
+        setOperationAction(ISD::FMUL, VT, Custom);
+        setOperationAction(ISD::FSUB, VT, Custom);
+      }
     }
 
     setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
@@ -7051,32 +7071,58 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
   case ISD::UMULO:
     return LowerXALUO(Op, DAG);
   case ISD::FADD:
+    if (Op.getScalarValueType() == MVT::bf16 && !Subtarget->hasSVEB16B16())
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FADD_PRED);
   case ISD::FSUB:
+    if (Op.getScalarValueType() == MVT::bf16 && !Subtarget->hasSVEB16B16())
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSUB_PRED);
   case ISD::FMUL:
+    if (Op.getScalarValueType() == MVT::bf16 && !Subtarget->hasSVEB16B16())
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
   case ISD::FMA:
+    if (Op.getScalarValueType() == MVT::bf16 && !Subtarget->hasSVEB16B16())
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
   case ISD::FDIV:
+    if (Op.getScalarValueType() == MVT::bf16)
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FDIV_PRED);
   case ISD::FNEG:
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FNEG_MERGE_PASSTHRU);
   case ISD::FCEIL:
+    if (Op.getScalarValueType() == MVT::bf16)
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FCEIL_MERGE_PASSTHRU);
   case ISD::FFLOOR:
+    if (Op.getScalarValueType() == MVT::bf16)
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FFLOOR_MERGE_PASSTHRU);
   case ISD::FNEARBYINT:
+    if (Op.getScalarValueType() == MVT::bf16)
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FNEARBYINT_MERGE_PASSTHRU);
   case ISD::FRINT:
+    if (Op.getScalarValueType() == MVT::bf16)
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FRINT_MERGE_PASSTHRU);
   case ISD::FROUND:
+    if (Op.getScalarValueType() == MVT::bf16)
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FROUND_MERGE_PASSTHRU);
   case ISD::FROUNDEVEN:
+    if (Op.getScalarValueType() == MVT::bf16)
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU);
   case ISD::FTRUNC:
+    if (Op.getScalarValueType() == MVT::bf16)
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FTRUNC_MERGE_PASSTHRU);
   case ISD::FSQRT:
+    if (Op.getScalarValueType() == MVT::bf16)
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSQRT_MERGE_PASSTHRU);
   case ISD::FABS:
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FABS_MERGE_PASSTHRU);
@@ -7242,12 +7288,20 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
   case ISD::SUB:
     return LowerToScalableOp(Op, DAG);
   case ISD::FMAXIMUM:
+    if (Op.getScalarValueType() == MVT::bf16 && !Subtarget->hasSVEB16B16())
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMAX_PRED);
   case ISD::FMAXNUM:
+    if (Op.getScalarValueType() == MVT::bf16 && !Subtarget->hasSVEB16B16())
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMAXNM_PRED);
   case ISD::FMINIMUM:
+    if (Op.getScalarValueType() == MVT::bf16 && !Subtarget->hasSVEB16B16())
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMIN_PRED);
   case ISD::FMINNUM:
+    if (Op.getScalarValueType() == MVT::bf16 && !Subtarget->hasSVEB16B16())
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMINNM_PRED);
   case ISD::VSELECT:
     return LowerFixedLengthVectorSelectToSVE(Op, DAG);
@@ -28466,6 +28520,40 @@ SDValue AArch64TargetLowering::LowerFixedLengthInsertVectorElt(
   return convertFromScalableVector(DAG, VT, ScalableRes);
 }
 
+// Lower bfloat16 operations by upcasting to float32, performing the operation
+// and then downcasting the result back to bfloat16.
+SDValue AArch64TargetLowering::LowerBFloatOp(SDValue Op,
+                                             SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  EVT VT = Op.getValueType();
+  assert(isTypeLegal(VT) && VT.isScalableVector() && "Unexpected type!");
+
+  // Split the vector and try again.
+  if (VT == MVT::nxv8bf16) {
+    SmallVector<SDValue, 4> LoOps, HiOps;
+    for (const SDValue &V : Op->op_values()) {
+      LoOps.push_back(DAG.getExtractSubvector(DL, MVT::nxv4bf16, V, 0));
+      HiOps.push_back(DAG.getExtractSubvector(DL, MVT::nxv4bf16, V, 4));
+    }
+
+    unsigned Opc = Op.getOpcode();
+    SDValue SplitOpLo = DAG.getNode(Opc, DL, MVT::nxv4bf16, LoOps);
+    SDValue SplitOpHi = DAG.getNode(Opc, DL, MVT::nxv4bf16, HiOps);
+    return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, SplitOpLo, SplitOpHi);
+  }
+
+  // Promote to float and try again.
+  EVT PromoteVT = VT.changeVectorElementType(MVT::f32);
+
+  SmallVector<SDValue, 4> Ops;
+  for (const SDValue &V : Op->op_values())
+    Ops.push_back(DAG.getNode(ISD::FP_EXTEND, DL, PromoteVT, V));
+
+  SDValue PromotedOp = DAG.getNode(Op.getOpcode(), DL, PromoteVT, Ops);
+  return DAG.getNode(ISD::FP_ROUND, DL, VT, PromotedOp,
+                     DAG.getIntPtrConstant(0, DL, true));
+}
+
 // Convert vector operation 'Op' to an equivalent predicated operation whereby
 // the original operation's type is used to construct a suitable predicate.
 // NOTE: The results for inactive lanes are undefined.
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 480bf60360bf55..8c06214bba5b54 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1224,6 +1224,7 @@ class AArch64TargetLowering : public TargetLowering {
   SDValue LowerWindowsDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerInlineDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerBFloatOp(SDValue Op, SelectionDAG &DAG) const;
 
   SDValue LowerAVG(SDValue Op, SelectionDAG &DAG, unsigned NewOp) const;
 
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 7240f6a22a87bd..078f4f2e14cabf 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -663,6 +663,13 @@ let Predicates = [HasSVEorSME] in {
   defm FABS_ZPmZ : sve_int_un_pred_arit_1_fp<0b100, "fabs", AArch64fabs_mt>;
   defm FNEG_ZPmZ : sve_int_un_pred_arit_1_fp<0b101, "fneg", AArch64fneg_mt>;
 
+  foreach VT = [nxv2bf16, nxv4bf16, nxv8bf16] in {
+    def : Pat<(VT (fabs VT:$op)),
+              (AND_ZI $op, (i64 (logical_imm64_XFORM(i64 0x7fff7fff7fff7fff))))>;
+    def : Pat<(VT (fneg VT:$op)),
+              (EOR_ZI $op, (i64 (logical_imm64_XFORM(i64 0x8000800080008000))))>;
+  }
+
   // zext(cmpeq(x, splat(0))) -> cnot(x)
   def : Pat<(nxv16i8 (zext (nxv16i1 (AArch64setcc_z (nxv16i1 (SVEAllActive):$Pg), nxv16i8:$Op2, (SVEDup0), SETEQ)))),
             (CNOT_ZPmZ_B $Op2, $Pg, $Op2)>;
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index 0bfac6465a1f30..c7059b8e4e8d4a 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -2299,6 +2299,8 @@ multiclass sve_fp_3op_u_zd_bfloat<bits<3> opc, string asm, SDPatternOperator op>
   def NAME : sve_fp_3op_u_zd<0b00, opc, asm, ZPR16>;
 
   def : SVE_2_Op_Pat<nxv8bf16, op, nxv8bf16, nxv8bf16, !cast<Instruction>(NAME)>;
+  def : SVE_2_Op_Pat<nxv4bf16, op, nxv4bf16, nxv4bf16, !cast<Instruction>(NAME)>;
+  def : SVE_2_Op_Pat<nxv2bf16, op, nxv2bf16, nxv2bf16, !cast<Instruction>(NAME)>;
 }
 
 multiclass sve_fp_3op_u_zd_ftsmul<bits<3> opc, string asm, SDPatternOperator op> {
@@ -9078,6 +9080,8 @@ multiclass sve_fp_bin_pred_bfloat<SDPatternOperator op> {
   def _UNDEF : PredTwoOpPseudo<NAME, ZPR16, FalseLanesUndef>;
 
   def : SVE_3_Op_Pat<nxv8bf16, op, nxv8i1,  nxv8bf16, nxv8bf16, !cast<Pseudo>(NAME # _UNDEF)>;
+  def : SVE_3_Op_Pat<nxv4bf16, op, nxv4i1,  nxv4bf16, nxv4bf16, !cast<Pseudo>(NAME # _UNDEF)>;
+  def : SVE_3_Op_Pat<nxv2bf16, op, nxv2i1,  nxv2bf16, nxv2bf16, !cast<Pseudo>(NAME # _UNDEF)>;
 }
 
 // Predicated pseudo floating point three operand instructions.
@@ -9099,6 +9103,8 @@ multiclass sve_fp_3op_pred_bfloat<SDPatternOperator op> {
   def _UNDEF : PredThreeOpPseudo<NAME, ZPR16, FalseLanesUndef>;
 
   def : SVE_4_Op_Pat<nxv8bf16, op, nxv8i1, nxv8bf16, nxv8bf16, nxv8bf16, !cast<Instruction>(NAME # _UNDEF)>;
+  def : SVE_4_Op_Pat<nxv4bf16, op, nxv4i1, nxv4bf16, nxv4bf16, nxv4bf16, !cast<Instruction>(NAME # _UNDEF)>;
+  def : SVE_4_Op_Pat<nxv2bf16, op, nxv2i1, nxv2bf16, nxv2bf16, nxv2bf16, !cast<Instruction>(NAME # _UNDEF)>;
 }
 
 // Predicated pseudo integer two operand instructions.
diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll b/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll
new file mode 100644
index 00000000000000..e8468ddfeed181
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll
@@ -0,0 +1,752 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mattr=+sve,+bf16                        < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16
+; RUN: llc -mattr=+sve,+bf16,+sve-b16b16            < %s | FileCheck %s --check-prefixes=CHECK,B16B16
+; RUN: llc -mattr=+sme -force-streaming             < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16
+; RUN: llc -mattr=+sme,+sve-b16b16 -force-streaming < %s | FileCheck %s --check-prefixes=CHECK,B16B16
+
+target triple = "aarch64-unknown-linux-gnu"
+
+;
+; FABS
+;
+
+define <vscale x 2 x bfloat> @fabs_nxv2bf16(<vscale x 2 x bfloat> %a) {
+; CHECK-LABEL: fabs_nxv2bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    and z0.h, z0.h, #0x7fff
+; CHECK-NEXT:    ret
+  %res = call <vscale x 2 x bfloat> @llvm.fabs.nxv2bf16(<vscale x 2 x bfloat> %a)
+  ret <vscale x 2 x bfloat> %res
+}
+
+define <vscale x 4 x bfloat> @fabs_nxv4bf16(<vscale x 4 x bfloat> %a) {
+; CHECK-LABEL: fabs_nxv4bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    and z0.h, z0.h, #0x7fff
+; CHECK-NEXT:    ret
+  %res = call <vscale x 4 x bfloat> @llvm.fabs.nxv4bf16(<vscale x 4 x bfloat> %a)
+  ret <vscale x 4 x bfloat> %res
+}
+
+define <vscale x 8 x bfloat> @fabs_nxv8bf16(<vscale x 8 x bfloat> %a) {
+; CHECK-LABEL: fabs_nxv8bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    and z0.h, z0.h, #0x7fff
+; CHECK-NEXT:    ret
+  %res = call <vscale x 8 x bfloat> @llvm.fabs.nxv8bf16(<vscale x 8 x bfloat> %a)
+  ret <vscale x 8 x bfloat> %res
+}
+
+;
+; FADD
+;
+
+define <vscale x 2 x bfloat> @fadd_nxv2bf16(<vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b) {
+; NOB16B16-LABEL: fadd_nxv2bf16:
+; NOB16B16:       // %bb.0:
+; NOB16B16-NEXT:    lsl z1.s, z1.s, #16
+; NOB16B16-NEXT:    lsl z0.s, z0.s, #16
+; NOB16B16-NEXT:    ptrue p0.d
+; NOB16B16-NEXT:    fadd z0.s, p0/m, z0.s, z1.s
+; NOB16B16-NEXT:    bfcvt z0.h, p0/m, z0.s
+; NOB16B16-NEXT:    ret
+;
+; B16B16-LABEL: fadd_nxv2bf16:
+; B16B16:       // %bb.0:
+; B16B16-NEXT:    bfadd z0.h, z0.h, z1.h
+; B16B16-NEXT:    ret
+  %res = fadd <vscale x 2 x bfloat> %a, %b
+  ret <vscale x 2 x bfloat> %res
+}
+
+define <vscale x 4 x bfloat> @fadd_nxv4bf16(<vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b) {
+; NOB16B16-LABEL: fadd_nxv4bf16:
+; NOB16B16:       // %bb.0:
+; NOB16B16-NEXT:    lsl z1.s, z1.s, #16
+; NOB16B16-NEXT:    lsl z0.s, z0.s, #16
+; NOB16B16-NEXT:    ptrue p0.s
+; NOB16B16-NEXT:    fadd z0.s, z0.s, z1.s
+; NOB16B16-NEXT:    bfcvt z0.h, p0/m, z0.s
+; NOB16B16-NEXT:    ret
+;
+; B16B16-LABEL: fadd_nxv4bf16:
+; B16B16:       // %bb.0:
+; B16B16-NEXT:    bfadd z0.h, z0.h, z1.h
+; B16B16-NEXT:    ret
+  %res = fadd <vscale x 4 x bfloat> %a, %b
+  ret <vscale x 4 x bfloat> %res
+}
+
+define <vscale x 8 x bfloat> @fadd_nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b) {
+; NOB16B16-LABEL: fadd_nxv8bf16:
+; NOB16B16:       // %bb.0:
+; NOB16B16-NEXT:    uunpkhi z2.s, z1.h
+; NOB16B16-NEXT:    uunpkhi z3.s, z0.h
+; NOB16B16-NEXT:    uunpklo z1.s, z1.h
+; NOB16B16-NEXT:    uunpklo z0.s, z0.h
+; NOB16B16-NEXT:    ptrue p0.s
+; NOB16B16-NEXT:    lsl z2.s, z2.s, #16
+; NOB16B16-NEXT:    lsl z3.s, z3.s, #16
+; NOB16B16-NEXT:    lsl z1.s, z1.s, #16
+; NOB16B16-NEXT:    lsl z0.s, z0.s, #16
+; NOB16B16-NEXT:    fadd z2.s, z3.s, z2.s
+; NOB16B16-NEXT:    fadd z0.s, z0.s, z1.s
+; NOB16B16-NEXT:    bfcvt z1.h, p0/m, z2.s
+; NOB16B16-NEXT:    bfcvt z0.h, p0/m, z0.s
+; NOB16B16-NEXT:    uzp1 z0.h, z0.h, z1.h
+; NOB16B16-NEXT:    ret
+;
+; B16B16-LABEL: fadd_nxv8bf16:
+; B16B16:       // %bb.0:
+; B16B16-NEXT:    bfadd z0.h, z0.h, z1.h
+; B16B16-NEXT:    ret
+  %res = fadd <vscale x 8 x bfloat> %a, %b
+  ret <vscale x 8 x bfloat> %res
+}
+
+;
+; FDIV
+;
+
+define <vscale x 2 x bfloat> @fdiv_nxv2bf16(<vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b) {
+; CHECK-LABEL: fdiv_nxv2bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z1.s, z1.s, #16
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fdiv z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    bfcvt z0.h, p0/m, z0.s
+; CHECK-NEXT:    ret
+  %res = fdiv <vscale x 2 x bfloat> %a, %b
+  ret <vscale x 2 x bfloat> %res
+}
+
+define <vscale x 4 x bfloat> @fdiv_nxv4bf16(<vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b) {
+; CHECK-LABEL: fdiv_nxv4bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z1.s, z1.s, #16
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fdiv z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    bfcvt z0.h, p0/m, z0.s
+; CHECK-NEXT:    ret
+  %res = fdiv <vscale x 4 x bfloat> %a, %b
+  ret <vscale x 4 x bfloat> %res
+}
+
+define <vscale x 8 x bfloat> @fdiv_nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b) {
+; CHECK-LABEL: fdiv_nxv8bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uunpkhi z2.s, z1.h
+; CHECK-NEXT:    uunpkhi z3.s, z0.h
+; CHECK-NEXT:    uunpklo z1.s, z1.h
+; CHECK-NEXT:    uunpklo z0.s, z0.h
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    lsl z2.s, z2.s, #16
+; CHECK-NEXT:    lsl z3.s, z3.s, #16
+; CHECK-NEXT:    lsl z1.s, z1.s, #16
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    fdivr z2.s, p0/m, z2.s, z3.s
+; CHECK-NEXT:    fdiv z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    bfcvt z1.h, p0/m, z2.s
+; CHECK-NEXT:    bfcvt z0.h, p0/m, z0.s
+; CHECK-NEXT:    uzp1 z0.h, z0.h, z1.h
+; CHECK-NEXT:    ret
+  %res = fdiv <vscale x 8 x bfloat> %a, %b
+  ret <vscale x 8 x bfloat> %res
+}
+
+;
+; FMAX
+;
+
+define <vscale x 2 x bfloat> @fmax_nxv2bf16(<vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b) {
+; NOB16B16-LABEL: fmax_nxv2bf16:
+; NOB16B16:       // %bb.0:
+; NOB16B16-NEXT:    lsl z1.s, z1.s, #16
+; NOB16B16-NEXT:    lsl z0.s, z0.s, #16
+; NOB16B16-NEXT:    ptrue p0.d
+; NOB16B16-NEXT:    fmax z0.s, p0/m, z0.s, z1.s
+; NOB16B16-NEXT:    bfcvt z0.h, p0/m, z0.s
+; NOB16B16-NEXT:    ret
+;
+; B16B16-LABEL: fmax_nxv2bf16:
+; B16B16:       // %bb.0:
+; B16B16-NEXT:    ptrue p0.d
+; B16B16-NEXT:    bfmax z0.h, p0/m, z0.h, z1.h
+; B16B16-NEXT:    ret
+  %res = call <vscale x 2 x bfloat> @llvm.maximum.nxv2bf16(<vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b)
+  ret <vscale x 2 x bfloat> %res
+}
+
+define <vscale x 4 x bfloat> @fmax_nxv4bf16(<vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b) {
+; NOB16B16-LABEL: fmax_nxv4bf16:
+; NOB16B16:       // %bb.0:
+; NOB16B16-NEXT:    lsl z1.s, z1.s, #16
+; NOB16B16-NEXT:    lsl z0.s, z0.s, #16
+; NOB16B16-NEXT:    ptrue p0.s
+; NOB16B16-NEXT:    fmax z0.s, p0/m, z0.s, z1.s
+; NOB16B16-NEXT:    bfcvt z0.h, p0/m, z0.s
+; NOB16B16-NEXT:    ret
+;
+; B16B16-LABEL: fmax_nxv4bf16:
+; B16B16:       // %bb.0:
+; B16B16-NEXT:    ptrue p0.s
+; B16B16-NEXT:    bfmax z0.h, p0/m, z0.h, z1.h
+; B16B16-NEXT:    ret
+  %res = call <vscale x 4 x bfloat> @llvm.maximum.nxv4bf16(<vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b)
+  ret <vscale x 4 x bfloat> %res
+}
+
+define <vscale x 8 x bfloat> @fmax_nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b) {
+; NOB16B16-LABEL: fmax_nxv8bf16:
+; NOB16B16:       // %bb.0:
+; NOB16B16-NEXT:    uunpkhi z2.s, z1.h
+; NOB16B16-NEXT:    uunpkhi z3.s, z0.h
+; NOB16B16-NEXT:    uunpklo z1.s, z1.h
+; NOB16B16-NEXT:    uunpklo z0.s, z0.h
+; NOB16B16-NEXT:    ptrue p0.s
+; NOB16B16-NEXT:    lsl z2.s, z2.s, #16
+; NOB16B16-NEXT:    lsl z3.s, z3.s, #16
+; NOB16B16-NEXT:    lsl z1.s, z1.s, #16
+; NOB16B16-NEXT:    lsl z0.s, z0.s, #16
+; NOB16B16-NEXT:    fmax z2.s, p0/m, z2.s, z3.s
+; NOB16B16-NEXT:    fmax z0.s, p0/m, z0.s, z1.s
+; NOB16B16-NEXT:    bfcvt z1.h, p0/m, z2.s
+; NOB16B16-NEXT:    bfcvt z0.h, p0/m, z0.s
+; NOB16B16-NEXT:    uzp1 z0.h, z0.h, z1.h
+; NOB16B16-NEXT:    ret
+;
+; B16B16-LABEL: fmax_nxv8bf16:
+; B16B16:       // %bb.0:
+; B16B16-NEXT:    ptrue p0.h
+; B16B16-NEXT:    bfmax z0.h, p0/m, z0.h, z1.h
+; B16B16-NEXT:    ret
+  %res = call <vscale x 8 x bfloat> @llvm.maximum.nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b)
+  ret <vscale x 8 x bfloat> %res
+}
+
+;
...
[truncated]

@arsenm arsenm added the floating-point Floating-point math label Sep 24, 2024
Comment on lines 28533 to 28543
SmallVector<SDValue, 4> LoOps, HiOps;
for (const SDValue &V : Op->op_values()) {
LoOps.push_back(DAG.getExtractSubvector(DL, MVT::nxv4bf16, V, 0));
HiOps.push_back(DAG.getExtractSubvector(DL, MVT::nxv4bf16, V, 4));
}

unsigned Opc = Op.getOpcode();
SDValue SplitOpLo = DAG.getNode(Opc, DL, MVT::nxv4bf16, LoOps);
SDValue SplitOpHi = DAG.getNode(Opc, DL, MVT::nxv4bf16, HiOps);
return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, SplitOpLo, SplitOpHi);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like the generic vector legalizer could have gotten this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was just me not looking for trouble :)

I've got a target neutral implementation but as you'd expect the effect is more widespread (all positive so far), so I'll pull it into a separate PR and then update this one if/when it lands.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was perhaps being too adventurous but by restricting the expansion to bf16 vectors the fallout shrank significantly so I've kept it with this PR to maintain sufficient testing.

Comment on lines 28548 to 28554
SmallVector<SDValue, 4> Ops;
for (const SDValue &V : Op->op_values())
Ops.push_back(DAG.getNode(ISD::FP_EXTEND, DL, PromoteVT, V));

SDValue PromotedOp = DAG.getNode(Op.getOpcode(), DL, PromoteVT, Ops);
return DAG.getNode(ISD::FP_ROUND, DL, VT, PromotedOp,
DAG.getIntPtrConstant(0, DL, true));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very generic and I would hope doesn't need repeating in target code

@@ -1070,13 +1071,21 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
break;
case ISD::FMINNUM:
case ISD::FMAXNUM:
if (SDValue Expanded = ExpandBF16Arith(Node)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably shouldn't just bypass the dedicated function to lower something

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was a little unsure how generic expandFMINNUM_FMAXNUM actually is. The description seemed pretty specific alongside a use in SIISelLowering, so wasn't sure if I could just add the splitting code into it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've moved the expansion function into TargetLowering and updated the two expandFMIN###_FMAXNUM### functions to make us of it.

Comment on lines 2169 to 2173
// Try to lower BFloat arithmetic by performing the same operation on operands
// that have been promoted to Float32, the result of which is then truncated.
// If promotion requires non-legal types the operation is split with the
// promotion occuring during a successive call to this function.
SDValue VectorLegalizer::ExpandBF16Arith(SDNode *Node) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just the promote action? It shouldn't be getting used for expand?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I misunderstood your original question. I hadn't realised promotion was an option for floating point types. I will investigate further.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've removed all the promotion code and am instead making use of common code via setOperationPromotedToType.

Specifically:
  fabs, fadd, fceil, fdiv, ffloor, fma, fmax, fmaxnm, fmin, fminnm,
  fmul, fnearbyint, fneg, frint, fround, froundeven, fsub, fsqrt &
  ftrunc
Copy link
Collaborator

@huntergr-arm huntergr-arm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@@ -663,6 +663,13 @@ let Predicates = [HasSVEorSME] in {
defm FABS_ZPmZ : sve_int_un_pred_arit_1_fp<0b100, "fabs", AArch64fabs_mt>;
defm FNEG_ZPmZ : sve_int_un_pred_arit_1_fp<0b101, "fneg", AArch64fneg_mt>;

foreach VT = [nxv2bf16, nxv4bf16, nxv8bf16] in {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: May be worth a comment clarifying that you're zeroing or inverting the sign bit.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

return SDValue();

EVT LoVT, HiVT;
std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a request, more of an style option to match the use of auto elsewhere in the patch for you to consider.

Suggested change
std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT);
auto [LoVT, HiVT] = DAG.GetSplitDestVTs(VT);

That lets you get rid of the separate declaration line.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh nice. Done.

SmallVector<SDValue, 4> LoOps, HiOps;
for (const SDValue &V : Node->op_values()) {
SDValue Lo, Hi;
std::tie(Lo, Hi) = DAG.SplitVector(V, DL, LoVT, HiVT);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same structured binding style can be applied here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@paulwalker-arm paulwalker-arm merged commit 02dd6b1 into llvm:main Oct 7, 2024
8 checks passed
@paulwalker-arm paulwalker-arm deleted the sve-bfloat-lowering branch October 7, 2024 12:02
@llvm-ci
Copy link
Collaborator

llvm-ci commented Oct 7, 2024

LLVM Buildbot has detected a new failure on builder lldb-aarch64-ubuntu running on linaro-lldb-aarch64-ubuntu while building llvm at step 6 "test".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/59/builds/6177

Here is the relevant piece of the build log for the reference
Step 6 (test) failure: build (failure)
...
PASS: lldb-unit :: ValueObject/./LLDBValueObjectTests/1/3 (2024 of 2033)
PASS: lldb-unit :: ValueObject/./LLDBValueObjectTests/2/3 (2025 of 2033)
PASS: lldb-unit :: tools/lldb-server/tests/./LLDBServerTests/0/2 (2026 of 2033)
PASS: lldb-unit :: tools/lldb-server/tests/./LLDBServerTests/1/2 (2027 of 2033)
PASS: lldb-unit :: Utility/./UtilityTests/7/8 (2028 of 2033)
PASS: lldb-unit :: Host/./HostTests/11/12 (2029 of 2033)
PASS: lldb-unit :: Target/./TargetTests/11/14 (2030 of 2033)
PASS: lldb-unit :: Host/./HostTests/3/12 (2031 of 2033)
PASS: lldb-unit :: Process/gdb-remote/./ProcessGdbRemoteTests/8/9 (2032 of 2033)
UNRESOLVED: lldb-api :: tools/lldb-server/TestLldbGdbServer.py (2033 of 2033)
******************** TEST 'lldb-api :: tools/lldb-server/TestLldbGdbServer.py' FAILED ********************
Script:
--
/usr/bin/python3.10 /home/tcwg-buildbot/worker/lldb-aarch64-ubuntu/llvm-project/lldb/test/API/dotest.py -u CXXFLAGS -u CFLAGS --env ARCHIVER=/usr/local/bin/llvm-ar --env OBJCOPY=/usr/bin/llvm-objcopy --env LLVM_LIBS_DIR=/home/tcwg-buildbot/worker/lldb-aarch64-ubuntu/build/./lib --env LLVM_INCLUDE_DIR=/home/tcwg-buildbot/worker/lldb-aarch64-ubuntu/build/include --env LLVM_TOOLS_DIR=/home/tcwg-buildbot/worker/lldb-aarch64-ubuntu/build/./bin --arch aarch64 --build-dir /home/tcwg-buildbot/worker/lldb-aarch64-ubuntu/build/lldb-test-build.noindex --lldb-module-cache-dir /home/tcwg-buildbot/worker/lldb-aarch64-ubuntu/build/lldb-test-build.noindex/module-cache-lldb/lldb-api --clang-module-cache-dir /home/tcwg-buildbot/worker/lldb-aarch64-ubuntu/build/lldb-test-build.noindex/module-cache-clang/lldb-api --executable /home/tcwg-buildbot/worker/lldb-aarch64-ubuntu/build/./bin/lldb --compiler /home/tcwg-buildbot/worker/lldb-aarch64-ubuntu/build/./bin/clang --dsymutil /home/tcwg-buildbot/worker/lldb-aarch64-ubuntu/build/./bin/dsymutil --llvm-tools-dir /home/tcwg-buildbot/worker/lldb-aarch64-ubuntu/build/./bin --lldb-obj-root /home/tcwg-buildbot/worker/lldb-aarch64-ubuntu/build/tools/lldb --lldb-libs-dir /home/tcwg-buildbot/worker/lldb-aarch64-ubuntu/build/./lib /home/tcwg-buildbot/worker/lldb-aarch64-ubuntu/llvm-project/lldb/test/API/tools/lldb-server -p TestLldbGdbServer.py
--
Exit Code: 1

Command Output (stdout):
--
lldb version 20.0.0git (https://github.com/llvm/llvm-project.git revision 02dd6b1014c708591fa0e8d46efb328c513a86e7)
  clang revision 02dd6b1014c708591fa0e8d46efb328c513a86e7
  llvm revision 02dd6b1014c708591fa0e8d46efb328c513a86e7
Skipping the following test categories: ['libc++', 'dsym', 'gmodules', 'debugserver', 'objc']

--
Command Output (stderr):
--
UNSUPPORTED: LLDB (/home/tcwg-buildbot/worker/lldb-aarch64-ubuntu/build/bin/clang-aarch64) :: test_Hc_then_Csignal_signals_correct_thread_launch_debugserver (TestLldbGdbServer.LldbGdbServerTestCase) (test case does not fall in any category of interest for this run) 
PASS: LLDB (/home/tcwg-buildbot/worker/lldb-aarch64-ubuntu/build/bin/clang-aarch64) :: test_Hc_then_Csignal_signals_correct_thread_launch_llgs (TestLldbGdbServer.LldbGdbServerTestCase)
PASS: LLDB (/home/tcwg-buildbot/worker/lldb-aarch64-ubuntu/build/bin/clang-aarch64) :: test_Hg_fails_on_another_pid_llgs (TestLldbGdbServer.LldbGdbServerTestCase)
PASS: LLDB (/home/tcwg-buildbot/worker/lldb-aarch64-ubuntu/build/bin/clang-aarch64) :: test_Hg_fails_on_minus_one_pid_llgs (TestLldbGdbServer.LldbGdbServerTestCase)
PASS: LLDB (/home/tcwg-buildbot/worker/lldb-aarch64-ubuntu/build/bin/clang-aarch64) :: test_Hg_fails_on_zero_pid_llgs (TestLldbGdbServer.LldbGdbServerTestCase)
UNSUPPORTED: LLDB (/home/tcwg-buildbot/worker/lldb-aarch64-ubuntu/build/bin/clang-aarch64) :: test_Hg_switches_to_3_threads_launch_debugserver (TestLldbGdbServer.LldbGdbServerTestCase) (test case does not fall in any category of interest for this run) 
PASS: LLDB (/home/tcwg-buildbot/worker/lldb-aarch64-ubuntu/build/bin/clang-aarch64) :: test_Hg_switches_to_3_threads_launch_llgs (TestLldbGdbServer.LldbGdbServerTestCase)
UNSUPPORTED: LLDB (/home/tcwg-buildbot/worker/lldb-aarch64-ubuntu/build/bin/clang-aarch64) :: test_P_and_p_thread_suffix_work_debugserver (TestLldbGdbServer.LldbGdbServerTestCase) (test case does not fall in any category of interest for this run) 
PASS: LLDB (/home/tcwg-buildbot/worker/lldb-aarch64-ubuntu/build/bin/clang-aarch64) :: test_P_and_p_thread_suffix_work_llgs (TestLldbGdbServer.LldbGdbServerTestCase)
UNSUPPORTED: LLDB (/home/tcwg-buildbot/worker/lldb-aarch64-ubuntu/build/bin/clang-aarch64) :: test_P_writes_all_gpr_registers_debugserver (TestLldbGdbServer.LldbGdbServerTestCase) (test case does not fall in any category of interest for this run) 
PASS: LLDB (/home/tcwg-buildbot/worker/lldb-aarch64-ubuntu/build/bin/clang-aarch64) :: test_P_writes_all_gpr_registers_llgs (TestLldbGdbServer.LldbGdbServerTestCase)
UNSUPPORTED: LLDB (/home/tcwg-buildbot/worker/lldb-aarch64-ubuntu/build/bin/clang-aarch64) :: test_attach_commandline_continue_app_exits_debugserver (TestLldbGdbServer.LldbGdbServerTestCase) (test case does not fall in any category of interest for this run) 
Program aborted due to an unhandled Error:
Operation not permitted
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.	Program arguments: /home/tcwg-buildbot/worker/lldb-aarch64-ubuntu/build/bin/lldb-server gdbserver --attach=3810585 --reverse-connect [127.0.0.1]:49795
 #0 0x0000aaaab0acaf54 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (/home/tcwg-buildbot/worker/lldb-aarch64-ubuntu/build/bin/lldb-server+0xb8af54)
 #1 0x0000aaaab0ac8f84 llvm::sys::RunSignalHandlers() (/home/tcwg-buildbot/worker/lldb-aarch64-ubuntu/build/bin/lldb-server+0xb88f84)
 #2 0x0000aaaab0acb664 SignalHandler(int) Signals.cpp:0:0
 #3 0x0000ffffab9737dc (linux-vdso.so.1+0x7dc)
 #4 0x0000ffffab17f200 __pthread_kill_implementation ./nptl/pthread_kill.c:44:76

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants