Skip to content

Commit 6c773a8

Browse files
[LLVM][SVE] Implement isel for bfloat fptoi and itofp operations. (#129713)
NOTE: This PR only considers scalable vectors because SVE VLS does not support bfloat (see useSVEForFixedLengthVectorVT()).
1 parent 449cdfa commit 6c773a8

File tree

3 files changed

+869
-34
lines changed

3 files changed

+869
-34
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 45 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4600,6 +4600,10 @@ SDValue AArch64TargetLowering::LowerFP_ROUND(SDValue Op,
46004600
bool Trunc = Op.getConstantOperandVal(IsStrict ? 2 : 1) == 1;
46014601

46024602
if (VT.isScalableVector()) {
4603+
// Let common code split the operation.
4604+
if (SrcVT == MVT::nxv8f32)
4605+
return Op;
4606+
46034607
if (VT.getScalarType() != MVT::bf16)
46044608
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_ROUND_MERGE_PASSTHRU);
46054609

@@ -4742,6 +4746,22 @@ SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op,
47424746
assert(!(IsStrict && VT.isScalableVector()) &&
47434747
"Unimplemented SVE support for STRICT_FP_to_INT!");
47444748

4749+
// f16 conversions are promoted to f32 when full fp16 is not supported.
4750+
if ((InVT.getVectorElementType() == MVT::f16 && !Subtarget->hasFullFP16()) ||
4751+
InVT.getVectorElementType() == MVT::bf16) {
4752+
EVT NewVT = VT.changeElementType(MVT::f32);
4753+
SDLoc dl(Op);
4754+
if (IsStrict) {
4755+
SDValue Ext = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {NewVT, MVT::Other},
4756+
{Op.getOperand(0), Op.getOperand(1)});
4757+
return DAG.getNode(Op.getOpcode(), dl, {VT, MVT::Other},
4758+
{Ext.getValue(1), Ext.getValue(0)});
4759+
}
4760+
return DAG.getNode(
4761+
Op.getOpcode(), dl, Op.getValueType(),
4762+
DAG.getNode(ISD::FP_EXTEND, dl, NewVT, Op.getOperand(0)));
4763+
}
4764+
47454765
if (VT.isScalableVector()) {
47464766
if (VT.getVectorElementType() == MVT::i1) {
47474767
SDLoc DL(Op);
@@ -4751,6 +4771,10 @@ SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op,
47514771
return DAG.getSetCC(DL, VT, Cvt, Zero, ISD::SETNE);
47524772
}
47534773

4774+
// Let common code split the operation.
4775+
if (InVT == MVT::nxv8f32)
4776+
return Op;
4777+
47544778
unsigned Opcode = Op.getOpcode() == ISD::FP_TO_UINT
47554779
? AArch64ISD::FCVTZU_MERGE_PASSTHRU
47564780
: AArch64ISD::FCVTZS_MERGE_PASSTHRU;
@@ -4761,24 +4785,6 @@ SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op,
47614785
useSVEForFixedLengthVectorVT(InVT, !Subtarget->isNeonAvailable()))
47624786
return LowerFixedLengthFPToIntToSVE(Op, DAG);
47634787

4764-
unsigned NumElts = InVT.getVectorNumElements();
4765-
4766-
// f16 conversions are promoted to f32 when full fp16 is not supported.
4767-
if ((InVT.getVectorElementType() == MVT::f16 && !Subtarget->hasFullFP16()) ||
4768-
InVT.getVectorElementType() == MVT::bf16) {
4769-
MVT NewVT = MVT::getVectorVT(MVT::f32, NumElts);
4770-
SDLoc dl(Op);
4771-
if (IsStrict) {
4772-
SDValue Ext = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {NewVT, MVT::Other},
4773-
{Op.getOperand(0), Op.getOperand(1)});
4774-
return DAG.getNode(Op.getOpcode(), dl, {VT, MVT::Other},
4775-
{Ext.getValue(1), Ext.getValue(0)});
4776-
}
4777-
return DAG.getNode(
4778-
Op.getOpcode(), dl, Op.getValueType(),
4779-
DAG.getNode(ISD::FP_EXTEND, dl, NewVT, Op.getOperand(0)));
4780-
}
4781-
47824788
uint64_t VTSize = VT.getFixedSizeInBits();
47834789
uint64_t InVTSize = InVT.getFixedSizeInBits();
47844790
if (VTSize < InVTSize) {
@@ -4813,7 +4819,7 @@ SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op,
48134819

48144820
// Use a scalar operation for conversions between single-element vectors of
48154821
// the same size.
4816-
if (NumElts == 1) {
4822+
if (InVT.getVectorNumElements() == 1) {
48174823
SDLoc dl(Op);
48184824
SDValue Extract = DAG.getNode(
48194825
ISD::EXTRACT_VECTOR_ELT, dl, InVT.getScalarType(),
@@ -5059,23 +5065,14 @@ SDValue AArch64TargetLowering::LowerVectorINT_TO_FP(SDValue Op,
50595065
assert(!(IsStrict && VT.isScalableVector()) &&
50605066
"Unimplemented SVE support for ISD:::STRICT_INT_TO_FP!");
50615067

5062-
if (VT.isScalableVector()) {
5063-
if (InVT.getVectorElementType() == MVT::i1) {
5064-
SDValue FalseVal = DAG.getConstantFP(0.0, dl, VT);
5065-
SDValue TrueVal = IsSigned ? DAG.getConstantFP(-1.0, dl, VT)
5066-
: DAG.getConstantFP(1.0, dl, VT);
5067-
return DAG.getNode(ISD::VSELECT, dl, VT, In, TrueVal, FalseVal);
5068-
}
5069-
5070-
unsigned Opcode = IsSigned ? AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU
5071-
: AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU;
5072-
return LowerToPredicatedOp(Op, DAG, Opcode);
5068+
// NOTE: i1->bf16 does not require promotion to f32.
5069+
if (VT.isScalableVector() && InVT.getVectorElementType() == MVT::i1) {
5070+
SDValue FalseVal = DAG.getConstantFP(0.0, dl, VT);
5071+
SDValue TrueVal = IsSigned ? DAG.getConstantFP(-1.0, dl, VT)
5072+
: DAG.getConstantFP(1.0, dl, VT);
5073+
return DAG.getNode(ISD::VSELECT, dl, VT, In, TrueVal, FalseVal);
50735074
}
50745075

5075-
if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()) ||
5076-
useSVEForFixedLengthVectorVT(InVT, !Subtarget->isNeonAvailable()))
5077-
return LowerFixedLengthIntToFPToSVE(Op, DAG);
5078-
50795076
// Promote bf16 conversions to f32.
50805077
if (VT.getVectorElementType() == MVT::bf16) {
50815078
EVT F32 = VT.changeElementType(MVT::f32);
@@ -5092,6 +5089,20 @@ SDValue AArch64TargetLowering::LowerVectorINT_TO_FP(SDValue Op,
50925089
DAG.getIntPtrConstant(0, dl, /*isTarget=*/true));
50935090
}
50945091

5092+
if (VT.isScalableVector()) {
5093+
// Let common code split the operation.
5094+
if (VT == MVT::nxv8f32)
5095+
return Op;
5096+
5097+
unsigned Opcode = IsSigned ? AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU
5098+
: AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU;
5099+
return LowerToPredicatedOp(Op, DAG, Opcode);
5100+
}
5101+
5102+
if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()) ||
5103+
useSVEForFixedLengthVectorVT(InVT, !Subtarget->isNeonAvailable()))
5104+
return LowerFixedLengthIntToFPToSVE(Op, DAG);
5105+
50955106
uint64_t VTSize = VT.getFixedSizeInBits();
50965107
uint64_t InVTSize = InVT.getFixedSizeInBits();
50975108
if (VTSize < InVTSize) {

llvm/lib/Target/AArch64/SVEInstrFormats.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5465,6 +5465,14 @@ multiclass sve_int_dup_fpimm_pred<string asm> {
54655465
(!cast<Instruction>(NAME # _S) $zd, $pg, fpimm32:$imm8)>;
54665466
def : Pat<(nxv2f64 (vselect nxv2i1:$pg, (splat_vector fpimm64:$imm8), nxv2f64:$zd)),
54675467
(!cast<Instruction>(NAME # _D) $zd, $pg, fpimm64:$imm8)>;
5468+
5469+
// Some half precision immediates alias with bfloat (e.g. f16(1.875) == bf16(1.0)).
5470+
def : Pat<(nxv8bf16 (vselect nxv8i1:$pg, (splat_vector fpimmbf16:$imm8), nxv8bf16:$zd)),
5471+
(!cast<Instruction>(NAME # _H) $zd, $pg, (fpimm16XForm bf16:$imm8))>;
5472+
def : Pat<(nxv4bf16 (vselect nxv4i1:$pg, (splat_vector fpimmbf16:$imm8), nxv4bf16:$zd)),
5473+
(!cast<Instruction>(NAME # _H) $zd, $pg, (fpimm16XForm bf16:$imm8))>;
5474+
def : Pat<(nxv2bf16 (vselect nxv2i1:$pg, (splat_vector fpimmbf16:$imm8), nxv2bf16:$zd)),
5475+
(!cast<Instruction>(NAME # _H) $zd, $pg, (fpimm16XForm bf16:$imm8))>;
54685476
}
54695477

54705478
class sve_int_dup_imm_pred<bits<2> sz8_64, bit m, string asm,

0 commit comments

Comments
 (0)