[LLVM][CodeGen] Add lowering for scalable vector bfloat operations. #109803
Conversation
@llvm/pr-subscribers-backend-amdgpu @llvm/pr-subscribers-backend-aarch64

Author: Paul Walker (paulwalker-arm)

Changes

Specifically: fabs, fadd, fceil, fdiv, ffloor, fma, fmax, fmaxnm, fmin, fminnm, fmul, fnearbyint, fneg, frint, fround, froundeven, fsub, fsqrt & ftrunc.

Patch is 51.46 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/109803.diff

8 Files Affected:
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index d6c2c36a0d482a..c7e0c704efceff 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1565,6 +1565,12 @@ class SelectionDAG {
SDValue getSetFPEnv(SDValue Chain, const SDLoc &dl, SDValue Ptr, EVT MemVT,
MachineMemOperand *MMO);
+ SDValue getExtractSubvector(const SDLoc &DL, EVT VT, SDValue V,
+ uint64_t Idx) {
+ return getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, V,
+ getVectorIdxConstant(Idx, DL));
+ }
+
/// Construct a node to track a Value* through the backend.
SDValue getSrcValue(const Value *v);
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 6067b3b29ea181..82bba661dba0f9 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -191,6 +191,9 @@ class SDValue {
return getValueType().getSimpleVT();
}
+ /// Return the scalar ValueType of the referenced return value.
+ EVT getScalarValueType() const { return getValueType().getScalarType(); }
+
/// Returns the size of the value in bits.
///
/// If the value type is a scalable vector type, the scalable property will
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 4166d9bd22bc01..c77d9631b5ffab 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1663,12 +1663,32 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) {
setOperationAction(ISD::BITCAST, VT, Custom);
setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
+ setOperationAction(ISD::FCEIL, VT, Custom);
+ setOperationAction(ISD::FDIV, VT, Custom);
+ setOperationAction(ISD::FFLOOR, VT, Custom);
+ setOperationAction(ISD::FMA, VT, Custom);
+ setOperationAction(ISD::FMAXIMUM, VT, Custom);
+ setOperationAction(ISD::FMAXNUM, VT, Custom);
+ setOperationAction(ISD::FMINIMUM, VT, Custom);
+ setOperationAction(ISD::FMINNUM, VT, Custom);
+ setOperationAction(ISD::FNEARBYINT, VT, Custom);
setOperationAction(ISD::FP_EXTEND, VT, Custom);
setOperationAction(ISD::FP_ROUND, VT, Custom);
+ setOperationAction(ISD::FRINT, VT, Custom);
+ setOperationAction(ISD::FROUND, VT, Custom);
+ setOperationAction(ISD::FROUNDEVEN, VT, Custom);
+ setOperationAction(ISD::FSQRT, VT, Custom);
+ setOperationAction(ISD::FTRUNC, VT, Custom);
setOperationAction(ISD::MLOAD, VT, Custom);
setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
setOperationAction(ISD::VECTOR_SPLICE, VT, Custom);
+
+ if (!Subtarget->hasSVEB16B16()) {
+ setOperationAction(ISD::FADD, VT, Custom);
+ setOperationAction(ISD::FMUL, VT, Custom);
+ setOperationAction(ISD::FSUB, VT, Custom);
+ }
}
setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
@@ -7051,32 +7071,58 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
case ISD::UMULO:
return LowerXALUO(Op, DAG);
case ISD::FADD:
+ if (Op.getScalarValueType() == MVT::bf16 && !Subtarget->hasSVEB16B16())
+ return LowerBFloatOp(Op, DAG);
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FADD_PRED);
case ISD::FSUB:
+ if (Op.getScalarValueType() == MVT::bf16 && !Subtarget->hasSVEB16B16())
+ return LowerBFloatOp(Op, DAG);
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSUB_PRED);
case ISD::FMUL:
+ if (Op.getScalarValueType() == MVT::bf16 && !Subtarget->hasSVEB16B16())
+ return LowerBFloatOp(Op, DAG);
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
case ISD::FMA:
+ if (Op.getScalarValueType() == MVT::bf16 && !Subtarget->hasSVEB16B16())
+ return LowerBFloatOp(Op, DAG);
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
case ISD::FDIV:
+ if (Op.getScalarValueType() == MVT::bf16)
+ return LowerBFloatOp(Op, DAG);
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FDIV_PRED);
case ISD::FNEG:
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FNEG_MERGE_PASSTHRU);
case ISD::FCEIL:
+ if (Op.getScalarValueType() == MVT::bf16)
+ return LowerBFloatOp(Op, DAG);
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FCEIL_MERGE_PASSTHRU);
case ISD::FFLOOR:
+ if (Op.getScalarValueType() == MVT::bf16)
+ return LowerBFloatOp(Op, DAG);
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FFLOOR_MERGE_PASSTHRU);
case ISD::FNEARBYINT:
+ if (Op.getScalarValueType() == MVT::bf16)
+ return LowerBFloatOp(Op, DAG);
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FNEARBYINT_MERGE_PASSTHRU);
case ISD::FRINT:
+ if (Op.getScalarValueType() == MVT::bf16)
+ return LowerBFloatOp(Op, DAG);
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FRINT_MERGE_PASSTHRU);
case ISD::FROUND:
+ if (Op.getScalarValueType() == MVT::bf16)
+ return LowerBFloatOp(Op, DAG);
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FROUND_MERGE_PASSTHRU);
case ISD::FROUNDEVEN:
+ if (Op.getScalarValueType() == MVT::bf16)
+ return LowerBFloatOp(Op, DAG);
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU);
case ISD::FTRUNC:
+ if (Op.getScalarValueType() == MVT::bf16)
+ return LowerBFloatOp(Op, DAG);
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FTRUNC_MERGE_PASSTHRU);
case ISD::FSQRT:
+ if (Op.getScalarValueType() == MVT::bf16)
+ return LowerBFloatOp(Op, DAG);
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSQRT_MERGE_PASSTHRU);
case ISD::FABS:
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FABS_MERGE_PASSTHRU);
@@ -7242,12 +7288,20 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
case ISD::SUB:
return LowerToScalableOp(Op, DAG);
case ISD::FMAXIMUM:
+ if (Op.getScalarValueType() == MVT::bf16 && !Subtarget->hasSVEB16B16())
+ return LowerBFloatOp(Op, DAG);
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMAX_PRED);
case ISD::FMAXNUM:
+ if (Op.getScalarValueType() == MVT::bf16 && !Subtarget->hasSVEB16B16())
+ return LowerBFloatOp(Op, DAG);
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMAXNM_PRED);
case ISD::FMINIMUM:
+ if (Op.getScalarValueType() == MVT::bf16 && !Subtarget->hasSVEB16B16())
+ return LowerBFloatOp(Op, DAG);
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMIN_PRED);
case ISD::FMINNUM:
+ if (Op.getScalarValueType() == MVT::bf16 && !Subtarget->hasSVEB16B16())
+ return LowerBFloatOp(Op, DAG);
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMINNM_PRED);
case ISD::VSELECT:
return LowerFixedLengthVectorSelectToSVE(Op, DAG);
@@ -28466,6 +28520,40 @@ SDValue AArch64TargetLowering::LowerFixedLengthInsertVectorElt(
return convertFromScalableVector(DAG, VT, ScalableRes);
}
+// Lower bfloat16 operations by upcasting to float32, performing the operation
+// and then downcasting the result back to bfloat16.
+SDValue AArch64TargetLowering::LowerBFloatOp(SDValue Op,
+ SelectionDAG &DAG) const {
+ SDLoc DL(Op);
+ EVT VT = Op.getValueType();
+ assert(isTypeLegal(VT) && VT.isScalableVector() && "Unexpected type!");
+
+ // Split the vector and try again.
+ if (VT == MVT::nxv8bf16) {
+ SmallVector<SDValue, 4> LoOps, HiOps;
+ for (const SDValue &V : Op->op_values()) {
+ LoOps.push_back(DAG.getExtractSubvector(DL, MVT::nxv4bf16, V, 0));
+ HiOps.push_back(DAG.getExtractSubvector(DL, MVT::nxv4bf16, V, 4));
+ }
+
+ unsigned Opc = Op.getOpcode();
+ SDValue SplitOpLo = DAG.getNode(Opc, DL, MVT::nxv4bf16, LoOps);
+ SDValue SplitOpHi = DAG.getNode(Opc, DL, MVT::nxv4bf16, HiOps);
+ return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, SplitOpLo, SplitOpHi);
+ }
+
+ // Promote to float and try again.
+ EVT PromoteVT = VT.changeVectorElementType(MVT::f32);
+
+ SmallVector<SDValue, 4> Ops;
+ for (const SDValue &V : Op->op_values())
+ Ops.push_back(DAG.getNode(ISD::FP_EXTEND, DL, PromoteVT, V));
+
+ SDValue PromotedOp = DAG.getNode(Op.getOpcode(), DL, PromoteVT, Ops);
+ return DAG.getNode(ISD::FP_ROUND, DL, VT, PromotedOp,
+ DAG.getIntPtrConstant(0, DL, true));
+}
+
// Convert vector operation 'Op' to an equivalent predicated operation whereby
// the original operation's type is used to construct a suitable predicate.
// NOTE: The results for inactive lanes are undefined.
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 480bf60360bf55..8c06214bba5b54 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1224,6 +1224,7 @@ class AArch64TargetLowering : public TargetLowering {
SDValue LowerWindowsDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerInlineDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerBFloatOp(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerAVG(SDValue Op, SelectionDAG &DAG, unsigned NewOp) const;
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 7240f6a22a87bd..078f4f2e14cabf 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -663,6 +663,13 @@ let Predicates = [HasSVEorSME] in {
defm FABS_ZPmZ : sve_int_un_pred_arit_1_fp<0b100, "fabs", AArch64fabs_mt>;
defm FNEG_ZPmZ : sve_int_un_pred_arit_1_fp<0b101, "fneg", AArch64fneg_mt>;
+ foreach VT = [nxv2bf16, nxv4bf16, nxv8bf16] in {
+ def : Pat<(VT (fabs VT:$op)),
+ (AND_ZI $op, (i64 (logical_imm64_XFORM(i64 0x7fff7fff7fff7fff))))>;
+ def : Pat<(VT (fneg VT:$op)),
+ (EOR_ZI $op, (i64 (logical_imm64_XFORM(i64 0x8000800080008000))))>;
+ }
+
// zext(cmpeq(x, splat(0))) -> cnot(x)
def : Pat<(nxv16i8 (zext (nxv16i1 (AArch64setcc_z (nxv16i1 (SVEAllActive):$Pg), nxv16i8:$Op2, (SVEDup0), SETEQ)))),
(CNOT_ZPmZ_B $Op2, $Pg, $Op2)>;
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index 0bfac6465a1f30..c7059b8e4e8d4a 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -2299,6 +2299,8 @@ multiclass sve_fp_3op_u_zd_bfloat<bits<3> opc, string asm, SDPatternOperator op>
def NAME : sve_fp_3op_u_zd<0b00, opc, asm, ZPR16>;
def : SVE_2_Op_Pat<nxv8bf16, op, nxv8bf16, nxv8bf16, !cast<Instruction>(NAME)>;
+ def : SVE_2_Op_Pat<nxv4bf16, op, nxv4bf16, nxv4bf16, !cast<Instruction>(NAME)>;
+ def : SVE_2_Op_Pat<nxv2bf16, op, nxv2bf16, nxv2bf16, !cast<Instruction>(NAME)>;
}
multiclass sve_fp_3op_u_zd_ftsmul<bits<3> opc, string asm, SDPatternOperator op> {
@@ -9078,6 +9080,8 @@ multiclass sve_fp_bin_pred_bfloat<SDPatternOperator op> {
def _UNDEF : PredTwoOpPseudo<NAME, ZPR16, FalseLanesUndef>;
def : SVE_3_Op_Pat<nxv8bf16, op, nxv8i1, nxv8bf16, nxv8bf16, !cast<Pseudo>(NAME # _UNDEF)>;
+ def : SVE_3_Op_Pat<nxv4bf16, op, nxv4i1, nxv4bf16, nxv4bf16, !cast<Pseudo>(NAME # _UNDEF)>;
+ def : SVE_3_Op_Pat<nxv2bf16, op, nxv2i1, nxv2bf16, nxv2bf16, !cast<Pseudo>(NAME # _UNDEF)>;
}
// Predicated pseudo floating point three operand instructions.
@@ -9099,6 +9103,8 @@ multiclass sve_fp_3op_pred_bfloat<SDPatternOperator op> {
def _UNDEF : PredThreeOpPseudo<NAME, ZPR16, FalseLanesUndef>;
def : SVE_4_Op_Pat<nxv8bf16, op, nxv8i1, nxv8bf16, nxv8bf16, nxv8bf16, !cast<Instruction>(NAME # _UNDEF)>;
+ def : SVE_4_Op_Pat<nxv4bf16, op, nxv4i1, nxv4bf16, nxv4bf16, nxv4bf16, !cast<Instruction>(NAME # _UNDEF)>;
+ def : SVE_4_Op_Pat<nxv2bf16, op, nxv2i1, nxv2bf16, nxv2bf16, nxv2bf16, !cast<Instruction>(NAME # _UNDEF)>;
}
// Predicated pseudo integer two operand instructions.
diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll b/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll
new file mode 100644
index 00000000000000..e8468ddfeed181
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll
@@ -0,0 +1,752 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mattr=+sve,+bf16 < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16
+; RUN: llc -mattr=+sve,+bf16,+sve-b16b16 < %s | FileCheck %s --check-prefixes=CHECK,B16B16
+; RUN: llc -mattr=+sme -force-streaming < %s | FileCheck %s --check-prefixes=CHECK,NOB16B16
+; RUN: llc -mattr=+sme,+sve-b16b16 -force-streaming < %s | FileCheck %s --check-prefixes=CHECK,B16B16
+
+target triple = "aarch64-unknown-linux-gnu"
+
+;
+; FABS
+;
+
+define <vscale x 2 x bfloat> @fabs_nxv2bf16(<vscale x 2 x bfloat> %a) {
+; CHECK-LABEL: fabs_nxv2bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: and z0.h, z0.h, #0x7fff
+; CHECK-NEXT: ret
+ %res = call <vscale x 2 x bfloat> @llvm.fabs.nxv2bf16(<vscale x 2 x bfloat> %a)
+ ret <vscale x 2 x bfloat> %res
+}
+
+define <vscale x 4 x bfloat> @fabs_nxv4bf16(<vscale x 4 x bfloat> %a) {
+; CHECK-LABEL: fabs_nxv4bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: and z0.h, z0.h, #0x7fff
+; CHECK-NEXT: ret
+ %res = call <vscale x 4 x bfloat> @llvm.fabs.nxv4bf16(<vscale x 4 x bfloat> %a)
+ ret <vscale x 4 x bfloat> %res
+}
+
+define <vscale x 8 x bfloat> @fabs_nxv8bf16(<vscale x 8 x bfloat> %a) {
+; CHECK-LABEL: fabs_nxv8bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: and z0.h, z0.h, #0x7fff
+; CHECK-NEXT: ret
+ %res = call <vscale x 8 x bfloat> @llvm.fabs.nxv8bf16(<vscale x 8 x bfloat> %a)
+ ret <vscale x 8 x bfloat> %res
+}
+
+;
+; FADD
+;
+
+define <vscale x 2 x bfloat> @fadd_nxv2bf16(<vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b) {
+; NOB16B16-LABEL: fadd_nxv2bf16:
+; NOB16B16: // %bb.0:
+; NOB16B16-NEXT: lsl z1.s, z1.s, #16
+; NOB16B16-NEXT: lsl z0.s, z0.s, #16
+; NOB16B16-NEXT: ptrue p0.d
+; NOB16B16-NEXT: fadd z0.s, p0/m, z0.s, z1.s
+; NOB16B16-NEXT: bfcvt z0.h, p0/m, z0.s
+; NOB16B16-NEXT: ret
+;
+; B16B16-LABEL: fadd_nxv2bf16:
+; B16B16: // %bb.0:
+; B16B16-NEXT: bfadd z0.h, z0.h, z1.h
+; B16B16-NEXT: ret
+ %res = fadd <vscale x 2 x bfloat> %a, %b
+ ret <vscale x 2 x bfloat> %res
+}
+
+define <vscale x 4 x bfloat> @fadd_nxv4bf16(<vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b) {
+; NOB16B16-LABEL: fadd_nxv4bf16:
+; NOB16B16: // %bb.0:
+; NOB16B16-NEXT: lsl z1.s, z1.s, #16
+; NOB16B16-NEXT: lsl z0.s, z0.s, #16
+; NOB16B16-NEXT: ptrue p0.s
+; NOB16B16-NEXT: fadd z0.s, z0.s, z1.s
+; NOB16B16-NEXT: bfcvt z0.h, p0/m, z0.s
+; NOB16B16-NEXT: ret
+;
+; B16B16-LABEL: fadd_nxv4bf16:
+; B16B16: // %bb.0:
+; B16B16-NEXT: bfadd z0.h, z0.h, z1.h
+; B16B16-NEXT: ret
+ %res = fadd <vscale x 4 x bfloat> %a, %b
+ ret <vscale x 4 x bfloat> %res
+}
+
+define <vscale x 8 x bfloat> @fadd_nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b) {
+; NOB16B16-LABEL: fadd_nxv8bf16:
+; NOB16B16: // %bb.0:
+; NOB16B16-NEXT: uunpkhi z2.s, z1.h
+; NOB16B16-NEXT: uunpkhi z3.s, z0.h
+; NOB16B16-NEXT: uunpklo z1.s, z1.h
+; NOB16B16-NEXT: uunpklo z0.s, z0.h
+; NOB16B16-NEXT: ptrue p0.s
+; NOB16B16-NEXT: lsl z2.s, z2.s, #16
+; NOB16B16-NEXT: lsl z3.s, z3.s, #16
+; NOB16B16-NEXT: lsl z1.s, z1.s, #16
+; NOB16B16-NEXT: lsl z0.s, z0.s, #16
+; NOB16B16-NEXT: fadd z2.s, z3.s, z2.s
+; NOB16B16-NEXT: fadd z0.s, z0.s, z1.s
+; NOB16B16-NEXT: bfcvt z1.h, p0/m, z2.s
+; NOB16B16-NEXT: bfcvt z0.h, p0/m, z0.s
+; NOB16B16-NEXT: uzp1 z0.h, z0.h, z1.h
+; NOB16B16-NEXT: ret
+;
+; B16B16-LABEL: fadd_nxv8bf16:
+; B16B16: // %bb.0:
+; B16B16-NEXT: bfadd z0.h, z0.h, z1.h
+; B16B16-NEXT: ret
+ %res = fadd <vscale x 8 x bfloat> %a, %b
+ ret <vscale x 8 x bfloat> %res
+}
+
+;
+; FDIV
+;
+
+define <vscale x 2 x bfloat> @fdiv_nxv2bf16(<vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b) {
+; CHECK-LABEL: fdiv_nxv2bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: lsl z1.s, z1.s, #16
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: fdiv z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT: bfcvt z0.h, p0/m, z0.s
+; CHECK-NEXT: ret
+ %res = fdiv <vscale x 2 x bfloat> %a, %b
+ ret <vscale x 2 x bfloat> %res
+}
+
+define <vscale x 4 x bfloat> @fdiv_nxv4bf16(<vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b) {
+; CHECK-LABEL: fdiv_nxv4bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: lsl z1.s, z1.s, #16
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: fdiv z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT: bfcvt z0.h, p0/m, z0.s
+; CHECK-NEXT: ret
+ %res = fdiv <vscale x 4 x bfloat> %a, %b
+ ret <vscale x 4 x bfloat> %res
+}
+
+define <vscale x 8 x bfloat> @fdiv_nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b) {
+; CHECK-LABEL: fdiv_nxv8bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: uunpkhi z2.s, z1.h
+; CHECK-NEXT: uunpkhi z3.s, z0.h
+; CHECK-NEXT: uunpklo z1.s, z1.h
+; CHECK-NEXT: uunpklo z0.s, z0.h
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: lsl z2.s, z2.s, #16
+; CHECK-NEXT: lsl z3.s, z3.s, #16
+; CHECK-NEXT: lsl z1.s, z1.s, #16
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: fdivr z2.s, p0/m, z2.s, z3.s
+; CHECK-NEXT: fdiv z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT: bfcvt z1.h, p0/m, z2.s
+; CHECK-NEXT: bfcvt z0.h, p0/m, z0.s
+; CHECK-NEXT: uzp1 z0.h, z0.h, z1.h
+; CHECK-NEXT: ret
+ %res = fdiv <vscale x 8 x bfloat> %a, %b
+ ret <vscale x 8 x bfloat> %res
+}
+
+;
+; FMAX
+;
+
+define <vscale x 2 x bfloat> @fmax_nxv2bf16(<vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b) {
+; NOB16B16-LABEL: fmax_nxv2bf16:
+; NOB16B16: // %bb.0:
+; NOB16B16-NEXT: lsl z1.s, z1.s, #16
+; NOB16B16-NEXT: lsl z0.s, z0.s, #16
+; NOB16B16-NEXT: ptrue p0.d
+; NOB16B16-NEXT: fmax z0.s, p0/m, z0.s, z1.s
+; NOB16B16-NEXT: bfcvt z0.h, p0/m, z0.s
+; NOB16B16-NEXT: ret
+;
+; B16B16-LABEL: fmax_nxv2bf16:
+; B16B16: // %bb.0:
+; B16B16-NEXT: ptrue p0.d
+; B16B16-NEXT: bfmax z0.h, p0/m, z0.h, z1.h
+; B16B16-NEXT: ret
+ %res = call <vscale x 2 x bfloat> @llvm.maximum.nxv2bf16(<vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b)
+ ret <vscale x 2 x bfloat> %res
+}
+
+define <vscale x 4 x bfloat> @fmax_nxv4bf16(<vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b) {
+; NOB16B16-LABEL: fmax_nxv4bf16:
+; NOB16B16: // %bb.0:
+; NOB16B16-NEXT: lsl z1.s, z1.s, #16
+; NOB16B16-NEXT: lsl z0.s, z0.s, #16
+; NOB16B16-NEXT: ptrue p0.s
+; NOB16B16-NEXT: fmax z0.s, p0/m, z0.s, z1.s
+; NOB16B16-NEXT: bfcvt z0.h, p0/m, z0.s
+; NOB16B16-NEXT: ret
+;
+; B16B16-LABEL: fmax_nxv4bf16:
+; B16B16: // %bb.0:
+; B16B16-NEXT: ptrue p0.s
+; B16B16-NEXT: bfmax z0.h, p0/m, z0.h, z1.h
+; B16B16-NEXT: ret
+ %res = call <vscale x 4 x bfloat> @llvm.maximum.nxv4bf16(<vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b)
+ ret <vscale x 4 x bfloat> %res
+}
+
+define <vscale x 8 x bfloat> @fmax_nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b) {
+; NOB16B16-LABEL: fmax_nxv8bf16:
+; NOB16B16: // %bb.0:
+; NOB16B16-NEXT: uunpkhi z2.s, z1.h
+; NOB16B16-NEXT: uunpkhi z3.s, z0.h
+; NOB16B16-NEXT: uunpklo z1.s, z1.h
+; NOB16B16-NEXT: uunpklo z0.s, z0.h
+; NOB16B16-NEXT: ptrue p0.s
+; NOB16B16-NEXT: lsl z2.s, z2.s, #16
+; NOB16B16-NEXT: lsl z3.s, z3.s, #16
+; NOB16B16-NEXT: lsl z1.s, z1.s, #16
+; NOB16B16-NEXT: lsl z0.s, z0.s, #16
+; NOB16B16-NEXT: fmax z2.s, p0/m, z2.s, z3.s
+; NOB16B16-NEXT: fmax z0.s, p0/m, z0.s, z1.s
+; NOB16B16-NEXT: bfcvt z1.h, p0/m, z2.s
+; NOB16B16-NEXT: bfcvt z0.h, p0/m, z0.s
+; NOB16B16-NEXT: uzp1 z0.h, z0.h, z1.h
+; NOB16B16-NEXT: ret
+;
+; B16B16-LABEL: fmax_nxv8bf16:
+; B16B16: // %bb.0:
+; B16B16-NEXT: ptrue p0.h
+; B16B16-NEXT: bfmax z0.h, p0/m, z0.h, z1.h
+; B16B16-NEXT: ret
+ %res = call <vscale x 8 x bfloat> @llvm.maximum.nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b)
+ ret <vscale x 8 x bfloat> %res
+}
+
+;
...
[truncated]
    SmallVector<SDValue, 4> LoOps, HiOps;
    for (const SDValue &V : Op->op_values()) {
      LoOps.push_back(DAG.getExtractSubvector(DL, MVT::nxv4bf16, V, 0));
      HiOps.push_back(DAG.getExtractSubvector(DL, MVT::nxv4bf16, V, 4));
    }

    unsigned Opc = Op.getOpcode();
    SDValue SplitOpLo = DAG.getNode(Opc, DL, MVT::nxv4bf16, LoOps);
    SDValue SplitOpHi = DAG.getNode(Opc, DL, MVT::nxv4bf16, HiOps);
    return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, SplitOpLo, SplitOpHi);
  }
Seems like the generic vector legalizer could have gotten this?
This was just me not looking for trouble :)
I've got a target neutral implementation but as you'd expect the effect is more widespread (all positive so far), so I'll pull it into a separate PR and then update this one if/when it lands.
I was perhaps being too adventurous but by restricting the expansion to bf16 vectors the fallout shrank significantly so I've kept it with this PR to maintain sufficient testing.
  SmallVector<SDValue, 4> Ops;
  for (const SDValue &V : Op->op_values())
    Ops.push_back(DAG.getNode(ISD::FP_EXTEND, DL, PromoteVT, V));

  SDValue PromotedOp = DAG.getNode(Op.getOpcode(), DL, PromoteVT, Ops);
  return DAG.getNode(ISD::FP_ROUND, DL, VT, PromotedOp,
                     DAG.getIntPtrConstant(0, DL, true));
This is very generic and I would hope doesn't need repeating in target code
Force-pushed c638ddd to 26b71a3.
@@ -1070,13 +1071,21 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
      break;
    case ISD::FMINNUM:
    case ISD::FMAXNUM:
      if (SDValue Expanded = ExpandBF16Arith(Node)) {
Probably shouldn't just bypass the dedicated function to lower something
I was a little unsure how generic expandFMINNUM_FMAXNUM actually is. The description seemed pretty specific, alongside a use in SIISelLowering, so I wasn't sure if I could just add the splitting code into it.
I've moved the expansion function into TargetLowering and updated the two expandFMINNUM_FMAXNUM / expandFMINIMUM_FMAXIMUM functions to make use of it.
// Try to lower BFloat arithmetic by performing the same operation on operands
// that have been promoted to Float32, the result of which is then truncated.
// If promotion requires non-legal types the operation is split with the
// promotion occurring during a successive call to this function.
SDValue VectorLegalizer::ExpandBF16Arith(SDNode *Node) {
This is just the promote action? It shouldn't be getting used for expand?
Sorry, I misunderstood your original question. I hadn't realised promotion was an option for floating point types. I will investigate further.
I've removed all the promotion code and am instead making use of common code via setOperationPromotedToType.
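For readers unfamiliar with that hook, here is a minimal sketch of the idea; it is illustrative only, and the opcodes and types below are assumptions rather than necessarily what the final patch registers. Inside the AArch64TargetLowering constructor, the unpacked bf16 vector types can be marked as promoted to their f32 counterparts so the common legalizer inserts the extend/operate/round sequence itself:

// Illustrative fragment, assumed to live in the AArch64TargetLowering
// constructor. nxv8bf16 still needs splitting first, because nxv8f32 is not
// a legal SVE type.
for (auto Opc : {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FDIV}) {
  setOperationPromotedToType(Opc, MVT::nxv2bf16, MVT::nxv2f32);
  setOperationPromotedToType(Opc, MVT::nxv4bf16, MVT::nxv4f32);
}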
Force-pushed 26b71a3 to c4b5958.
LGTM
@@ -663,6 +663,13 @@ let Predicates = [HasSVEorSME] in {
   defm FABS_ZPmZ : sve_int_un_pred_arit_1_fp<0b100, "fabs", AArch64fabs_mt>;
   defm FNEG_ZPmZ : sve_int_un_pred_arit_1_fp<0b101, "fneg", AArch64fneg_mt>;

   foreach VT = [nxv2bf16, nxv4bf16, nxv8bf16] in {
nit: May be worth a comment clarifying that you're zeroing or inverting the sign bit.
Done.
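As background for the nit above: bfloat16 keeps its sign in bit 15 (it is the top half of an IEEE-754 float), so fabs and fneg reduce to per-lane bit operations, which is what the AND_ZI/EOR_ZI patterns with the 0x7fff/0x8000 masks encode. A minimal standalone sketch of the same trick on a scalar encoding, not part of the patch:

#include <cstdint>
#include <cstdio>

// Clearing bit 15 of a bf16 encoding implements fabs; toggling it implements
// fneg. The SVE patterns do the same per 16-bit lane.
static uint16_t bf16_fabs(uint16_t Bits) { return Bits & 0x7fffu; }
static uint16_t bf16_fneg(uint16_t Bits) { return Bits ^ 0x8000u; }

int main() {
  const uint16_t MinusOne = 0xbf80; // bf16 encoding of -1.0
  std::printf("fabs(-1.0) -> 0x%04x\n", unsigned(bf16_fabs(MinusOne))); // 0x3f80 (+1.0)
  std::printf("fneg(-1.0) -> 0x%04x\n", unsigned(bf16_fneg(MinusOne))); // 0x3f80 (+1.0)
  return 0;
}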
    return SDValue();

  EVT LoVT, HiVT;
  std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT);
Not a request, more of a style option for you to consider, to match the use of auto elsewhere in the patch.
Suggested change:
-  std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT);
+  auto [LoVT, HiVT] = DAG.GetSplitDestVTs(VT);
That lets you get rid of the separate declaration line.
Oh nice. Done.
  SmallVector<SDValue, 4> LoOps, HiOps;
  for (const SDValue &V : Node->op_values()) {
    SDValue Lo, Hi;
    std::tie(Lo, Hi) = DAG.SplitVector(V, DL, LoVT, HiVT);
Same structured binding style can be applied here.
Done.
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/59/builds/6177. Here is the relevant piece of the build log for reference:
[truncated]
Specifically: fabs, fadd, fceil, fdiv, ffloor, fma, fmax, fmaxnm, fmin, fminnm, fmul, fnearbyint, fneg, frint, fround, froundeven, fsub, fsqrt & ftrunc.