@@ -1663,12 +1663,32 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) {
       setOperationAction(ISD::BITCAST, VT, Custom);
       setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
+      setOperationAction(ISD::FCEIL, VT, Custom);
+      setOperationAction(ISD::FDIV, VT, Custom);
+      setOperationAction(ISD::FFLOOR, VT, Custom);
+      setOperationAction(ISD::FMA, VT, Custom);
+      setOperationAction(ISD::FMAXIMUM, VT, Custom);
+      setOperationAction(ISD::FMAXNUM, VT, Custom);
+      setOperationAction(ISD::FMINIMUM, VT, Custom);
+      setOperationAction(ISD::FMINNUM, VT, Custom);
+      setOperationAction(ISD::FNEARBYINT, VT, Custom);
       setOperationAction(ISD::FP_EXTEND, VT, Custom);
       setOperationAction(ISD::FP_ROUND, VT, Custom);
+      setOperationAction(ISD::FRINT, VT, Custom);
+      setOperationAction(ISD::FROUND, VT, Custom);
+      setOperationAction(ISD::FROUNDEVEN, VT, Custom);
+      setOperationAction(ISD::FSQRT, VT, Custom);
+      setOperationAction(ISD::FTRUNC, VT, Custom);
       setOperationAction(ISD::MLOAD, VT, Custom);
       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
       setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
       setOperationAction(ISD::VECTOR_SPLICE, VT, Custom);
+
+      if (!Subtarget->hasSVEB16B16()) {
+        setOperationAction(ISD::FADD, VT, Custom);
+        setOperationAction(ISD::FMUL, VT, Custom);
+        setOperationAction(ISD::FSUB, VT, Custom);
+      }
     }
 
     setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
@@ -7051,32 +7071,58 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
   case ISD::UMULO:
     return LowerXALUO(Op, DAG);
   case ISD::FADD:
+    if (Op.getScalarValueType() == MVT::bf16 && !Subtarget->hasSVEB16B16())
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FADD_PRED);
   case ISD::FSUB:
+    if (Op.getScalarValueType() == MVT::bf16 && !Subtarget->hasSVEB16B16())
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSUB_PRED);
   case ISD::FMUL:
+    if (Op.getScalarValueType() == MVT::bf16 && !Subtarget->hasSVEB16B16())
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
   case ISD::FMA:
+    if (Op.getScalarValueType() == MVT::bf16 && !Subtarget->hasSVEB16B16())
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
   case ISD::FDIV:
+    if (Op.getScalarValueType() == MVT::bf16)
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FDIV_PRED);
   case ISD::FNEG:
    return LowerToPredicatedOp(Op, DAG, AArch64ISD::FNEG_MERGE_PASSTHRU);
   case ISD::FCEIL:
+    if (Op.getScalarValueType() == MVT::bf16)
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FCEIL_MERGE_PASSTHRU);
   case ISD::FFLOOR:
+    if (Op.getScalarValueType() == MVT::bf16)
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FFLOOR_MERGE_PASSTHRU);
   case ISD::FNEARBYINT:
+    if (Op.getScalarValueType() == MVT::bf16)
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FNEARBYINT_MERGE_PASSTHRU);
   case ISD::FRINT:
+    if (Op.getScalarValueType() == MVT::bf16)
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FRINT_MERGE_PASSTHRU);
   case ISD::FROUND:
+    if (Op.getScalarValueType() == MVT::bf16)
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FROUND_MERGE_PASSTHRU);
   case ISD::FROUNDEVEN:
+    if (Op.getScalarValueType() == MVT::bf16)
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU);
   case ISD::FTRUNC:
+    if (Op.getScalarValueType() == MVT::bf16)
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FTRUNC_MERGE_PASSTHRU);
   case ISD::FSQRT:
+    if (Op.getScalarValueType() == MVT::bf16)
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSQRT_MERGE_PASSTHRU);
   case ISD::FABS:
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FABS_MERGE_PASSTHRU);
@@ -7242,12 +7288,20 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
   case ISD::SUB:
     return LowerToScalableOp(Op, DAG);
   case ISD::FMAXIMUM:
+    if (Op.getScalarValueType() == MVT::bf16 && !Subtarget->hasSVEB16B16())
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMAX_PRED);
   case ISD::FMAXNUM:
+    if (Op.getScalarValueType() == MVT::bf16 && !Subtarget->hasSVEB16B16())
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMAXNM_PRED);
   case ISD::FMINIMUM:
+    if (Op.getScalarValueType() == MVT::bf16 && !Subtarget->hasSVEB16B16())
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMIN_PRED);
   case ISD::FMINNUM:
+    if (Op.getScalarValueType() == MVT::bf16 && !Subtarget->hasSVEB16B16())
+      return LowerBFloatOp(Op, DAG);
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMINNM_PRED);
   case ISD::VSELECT:
     return LowerFixedLengthVectorSelectToSVE(Op, DAG);
@@ -28466,6 +28520,40 @@ SDValue AArch64TargetLowering::LowerFixedLengthInsertVectorElt(
   return convertFromScalableVector(DAG, VT, ScalableRes);
 }
 
+// Lower bfloat16 operations by upcasting to float32, performing the operation
+// and then downcasting the result back to bfloat16.
+SDValue AArch64TargetLowering::LowerBFloatOp(SDValue Op,
+                                             SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  EVT VT = Op.getValueType();
+  assert(isTypeLegal(VT) && VT.isScalableVector() && "Unexpected type!");
+
+  // Split the vector and try again.
+  if (VT == MVT::nxv8bf16) {
+    SmallVector<SDValue, 4> LoOps, HiOps;
+    for (const SDValue &V : Op->op_values()) {
+      LoOps.push_back(DAG.getExtractSubvector(DL, MVT::nxv4bf16, V, 0));
+      HiOps.push_back(DAG.getExtractSubvector(DL, MVT::nxv4bf16, V, 4));
+    }
+
+    unsigned Opc = Op.getOpcode();
+    SDValue SplitOpLo = DAG.getNode(Opc, DL, MVT::nxv4bf16, LoOps);
+    SDValue SplitOpHi = DAG.getNode(Opc, DL, MVT::nxv4bf16, HiOps);
+    return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, SplitOpLo, SplitOpHi);
+  }
+
+  // Promote to float and try again.
+  EVT PromoteVT = VT.changeVectorElementType(MVT::f32);
+
+  SmallVector<SDValue, 4> Ops;
+  for (const SDValue &V : Op->op_values())
+    Ops.push_back(DAG.getNode(ISD::FP_EXTEND, DL, PromoteVT, V));
+
+  SDValue PromotedOp = DAG.getNode(Op.getOpcode(), DL, PromoteVT, Ops);
+  return DAG.getNode(ISD::FP_ROUND, DL, VT, PromotedOp,
+                     DAG.getIntPtrConstant(0, DL, true));
+}
+
 // Convert vector operation 'Op' to an equivalent predicated operation whereby
 // the original operation's type is used to construct a suitable predicate.
 // NOTE: The results for inactive lanes are undefined.
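Note: the promote-and-round strategy used by the new LowerBFloatOp (FP_EXTEND each bf16 operand to f32, run the operation at f32, FP_ROUND the result back to bf16) can be illustrated with a small scalar sketch. The helpers below are written for this note only and are not part of the patch; they assume bf16 is stored as the high 16 bits of an IEEE-754 binary32 value and round back with round-to-nearest-even, ignoring NaN payload details for brevity.

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Hypothetical helper: reinterpret a bf16 payload as the f32 it denotes.
static float bf16_to_f32(uint16_t B) {
  uint32_t Bits = static_cast<uint32_t>(B) << 16;
  float F;
  std::memcpy(&F, &Bits, sizeof(F));
  return F;
}

// Hypothetical helper: round an f32 back to bf16, ties to even
// (NaN handling omitted for brevity).
static uint16_t f32_to_bf16(float F) {
  uint32_t Bits;
  std::memcpy(&Bits, &F, sizeof(Bits));
  Bits += 0x7FFF + ((Bits >> 16) & 1); // round the discarded low 16 bits
  return static_cast<uint16_t>(Bits >> 16);
}

// One lane of a bf16 FADD lowered this way: extend both operands to f32,
// add at f32 precision, round the single result back to bf16, mirroring
// the FP_EXTEND -> <op> -> FP_ROUND chain that LowerBFloatOp builds.
static uint16_t bf16_add(uint16_t A, uint16_t B) {
  return f32_to_bf16(bf16_to_f32(A) + bf16_to_f32(B));
}

int main() {
  uint16_t A = 0x3FC0; // 1.5 in bf16
  uint16_t B = 0x4010; // 2.25 in bf16
  uint16_t S = bf16_add(A, B);
  std::printf("0x%04X -> %g\n", S, bf16_to_f32(S)); // expect 0x4070 -> 3.75
  return 0;
}
```

Only the final result is rounded to bf16, so intermediate arithmetic keeps f32 precision; the nxv8bf16 case in the patch first splits into two nxv4bf16 halves so the promoted nxv4f32 values stay within a single SVE register.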