[LLVM][SVE] Implement isel for bfloat fptoi and itofp operations. #129713

Merged
merged 1 commit on Mar 19, 2025

79 changes: 45 additions & 34 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4582,6 +4582,10 @@ SDValue AArch64TargetLowering::LowerFP_ROUND(SDValue Op,
bool Trunc = Op.getConstantOperandVal(IsStrict ? 2 : 1) == 1;

if (VT.isScalableVector()) {
// Let common code split the operation.
if (SrcVT == MVT::nxv8f32)
return Op;

if (VT.getScalarType() != MVT::bf16)
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_ROUND_MERGE_PASSTHRU);

@@ -4724,6 +4728,22 @@ SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op,
assert(!(IsStrict && VT.isScalableVector()) &&
"Unimplemented SVE support for STRICT_FP_to_INT!");

// f16 conversions are promoted to f32 when full fp16 is not supported.
if ((InVT.getVectorElementType() == MVT::f16 && !Subtarget->hasFullFP16()) ||
InVT.getVectorElementType() == MVT::bf16) {
EVT NewVT = VT.changeElementType(MVT::f32);
SDLoc dl(Op);
if (IsStrict) {
SDValue Ext = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {NewVT, MVT::Other},
{Op.getOperand(0), Op.getOperand(1)});
return DAG.getNode(Op.getOpcode(), dl, {VT, MVT::Other},
{Ext.getValue(1), Ext.getValue(0)});
}
return DAG.getNode(
Op.getOpcode(), dl, Op.getValueType(),
DAG.getNode(ISD::FP_EXTEND, dl, NewVT, Op.getOperand(0)));
}

if (VT.isScalableVector()) {
if (VT.getVectorElementType() == MVT::i1) {
SDLoc DL(Op);
@@ -4733,6 +4753,10 @@ SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op,
return DAG.getSetCC(DL, VT, Cvt, Zero, ISD::SETNE);
}

// Let common code split the operation.

Contributor:

I assume the reason you've put this after the i1 case above is because you believe that when converting nxv8bf16 -> nxv8i1 it's better to do:

FP_EXTEND: nxv8bf16 -> nxv8f32
VectorFP_TO_INT: nxv8f32 -> nxv8i32
SETNE: nxv8f32, zero -> nxv8i1

than

FP_EXTEND: nxv8bf16 -> nxv8f32
VectorFP_TO_INT: nxv8f32 -> nxv8i1

Presumably because you think SETNE will do a better job of splitting with an i1 result element type than VectorFP_TO_INT?

Collaborator Author:

Not really. I only put the bail-out code here because it's the next blob of code that definitely doesn't support MVT::nxv8f32. When I move it before the i1 handling, the output changes as follows:

 ; CHECK-NEXT:    lsl z0.s, z0.s, #16
 ; CHECK-NEXT:    fcvtzs z1.s, p0/m, z1.s
 ; CHECK-NEXT:    fcvtzs z0.s, p0/m, z0.s
-; CHECK-NEXT:    ptrue p0.h
-; CHECK-NEXT:    uzp1 z0.h, z0.h, z1.h
-; CHECK-NEXT:    cmpne p0.h, p0/z, z0.h, #0
+; CHECK-NEXT:    cmpne p1.s, p0/z, z1.s, #0
+; CHECK-NEXT:    cmpne p0.s, p0/z, z0.s, #0
+; CHECK-NEXT:    uzp1 p0.h, p0.h, p1.h

Looking at the Neoverse SWOG, the two compares in the new output look like they'll be serialised, so I might have hit the better output by fluke rather than judgement?
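
For reference, the case under discussion is the scalable bf16-to-predicate conversion. A minimal IR sketch that reaches this path (illustrative function name, not the PR's actual test; compiled for AArch64 with SVE enabled):

```llvm
; nxv8bf16 -> nxv8i1: bf16 is promoted to f32 first, so the intermediate
; type is nxv8f32 and the conversion must be split across two Z registers
; before the compare against zero produces the predicate result.
define <vscale x 8 x i1> @fcvtzs_nxv8bf16_to_nxv8i1(<vscale x 8 x bfloat> %a) {
  %c = fptosi <vscale x 8 x bfloat> %a to <vscale x 8 x i1>
  ret <vscale x 8 x i1> %c
}
```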

if (InVT == MVT::nxv8f32)
return Op;

unsigned Opcode = Op.getOpcode() == ISD::FP_TO_UINT
? AArch64ISD::FCVTZU_MERGE_PASSTHRU
: AArch64ISD::FCVTZS_MERGE_PASSTHRU;
@@ -4743,24 +4767,6 @@ SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op,
useSVEForFixedLengthVectorVT(InVT, !Subtarget->isNeonAvailable()))
return LowerFixedLengthFPToIntToSVE(Op, DAG);

unsigned NumElts = InVT.getVectorNumElements();

// f16 conversions are promoted to f32 when full fp16 is not supported.
if ((InVT.getVectorElementType() == MVT::f16 && !Subtarget->hasFullFP16()) ||
InVT.getVectorElementType() == MVT::bf16) {
MVT NewVT = MVT::getVectorVT(MVT::f32, NumElts);
SDLoc dl(Op);
if (IsStrict) {
SDValue Ext = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {NewVT, MVT::Other},
{Op.getOperand(0), Op.getOperand(1)});
return DAG.getNode(Op.getOpcode(), dl, {VT, MVT::Other},
{Ext.getValue(1), Ext.getValue(0)});
}
return DAG.getNode(
Op.getOpcode(), dl, Op.getValueType(),
DAG.getNode(ISD::FP_EXTEND, dl, NewVT, Op.getOperand(0)));
}

uint64_t VTSize = VT.getFixedSizeInBits();
uint64_t InVTSize = InVT.getFixedSizeInBits();
if (VTSize < InVTSize) {
@@ -4795,7 +4801,7 @@ SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op,

// Use a scalar operation for conversions between single-element vectors of
// the same size.
if (NumElts == 1) {
if (InVT.getVectorNumElements() == 1) {
SDLoc dl(Op);
SDValue Extract = DAG.getNode(
ISD::EXTRACT_VECTOR_ELT, dl, InVT.getScalarType(),
@@ -5041,23 +5047,14 @@ SDValue AArch64TargetLowering::LowerVectorINT_TO_FP(SDValue Op,
assert(!(IsStrict && VT.isScalableVector()) &&
"Unimplemented SVE support for ISD:::STRICT_INT_TO_FP!");

if (VT.isScalableVector()) {
if (InVT.getVectorElementType() == MVT::i1) {
SDValue FalseVal = DAG.getConstantFP(0.0, dl, VT);
SDValue TrueVal = IsSigned ? DAG.getConstantFP(-1.0, dl, VT)
: DAG.getConstantFP(1.0, dl, VT);
return DAG.getNode(ISD::VSELECT, dl, VT, In, TrueVal, FalseVal);
}

unsigned Opcode = IsSigned ? AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU
: AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU;
return LowerToPredicatedOp(Op, DAG, Opcode);
// NOTE: i1->bf16 does not require promotion to f32.
if (VT.isScalableVector() && InVT.getVectorElementType() == MVT::i1) {
SDValue FalseVal = DAG.getConstantFP(0.0, dl, VT);
SDValue TrueVal = IsSigned ? DAG.getConstantFP(-1.0, dl, VT)
: DAG.getConstantFP(1.0, dl, VT);
return DAG.getNode(ISD::VSELECT, dl, VT, In, TrueVal, FalseVal);
}

if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()) ||
useSVEForFixedLengthVectorVT(InVT, !Subtarget->isNeonAvailable()))
return LowerFixedLengthIntToFPToSVE(Op, DAG);

// Promote bf16 conversions to f32.
if (VT.getVectorElementType() == MVT::bf16) {
EVT F32 = VT.changeElementType(MVT::f32);
Expand All @@ -5074,6 +5071,20 @@ SDValue AArch64TargetLowering::LowerVectorINT_TO_FP(SDValue Op,
DAG.getIntPtrConstant(0, dl, /*isTarget=*/true));
}

if (VT.isScalableVector()) {
// Let common code split the operation.
if (VT == MVT::nxv8f32)
return Op;

unsigned Opcode = IsSigned ? AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU
: AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU;
return LowerToPredicatedOp(Op, DAG, Opcode);
}

if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()) ||
useSVEForFixedLengthVectorVT(InVT, !Subtarget->isNeonAvailable()))
return LowerFixedLengthIntToFPToSVE(Op, DAG);

uint64_t VTSize = VT.getFixedSizeInBits();
uint64_t InVTSize = InVT.getFixedSizeInBits();
if (VTSize < InVTSize) {
8 changes: 8 additions & 0 deletions llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -5465,6 +5465,14 @@ multiclass sve_int_dup_fpimm_pred<string asm> {
(!cast<Instruction>(NAME # _S) $zd, $pg, fpimm32:$imm8)>;
def : Pat<(nxv2f64 (vselect nxv2i1:$pg, (splat_vector fpimm64:$imm8), nxv2f64:$zd)),
(!cast<Instruction>(NAME # _D) $zd, $pg, fpimm64:$imm8)>;

// Some half precision immediates alias with bfloat (e.g. f16(1.875) == bf16(1.0)).

Contributor:

This comment implies that some don't, so what happens if fpimmbf16 matches a value that the fp16 variant doesn't have? Or should the comment actually be something like "All fpimmbf16 immediates alias with an FP16 immediate"?

Collaborator Author:

Not sure I understand. The comment is saying "some" half precision immediates alias with bfloat, and the way this is achieved is by using the fpimmbf16 complex pattern, which only lets the safe ones through.

"All fpimmbf16 immediates alias with an FP16 immediate" is obvious from its use because that's how isel works. I can remove the comment if you feel it offers no value? I only added it just in case somebody wondered why we have bfloat patterns for an instruction that doesn't really support bfloat.
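
For a concrete instance of the aliasing: bf16 1.0 is encoded as 0x3F80, and that same bit pattern read as an IEEE half is 1.875, so the _H form of the instruction can materialise it provided fpimmbf16 only accepts encodings that are also valid FP16 immediates. A minimal IR sketch of the kind of splat these patterns match (illustrative function name, assuming SVE is enabled):

```llvm
; Predicated splat of bf16 1.0 (bit pattern 0x3F80). Read as an IEEE half,
; the same 16 bits mean 1.875, so the f16-transformed immediate produces
; the correct bf16 result.
define <vscale x 8 x bfloat> @dup_bf16_one(<vscale x 8 x i1> %pg, <vscale x 8 x bfloat> %z) {
  %ins = insertelement <vscale x 8 x bfloat> poison, bfloat 1.0, i64 0
  %splat = shufflevector <vscale x 8 x bfloat> %ins, <vscale x 8 x bfloat> poison, <vscale x 8 x i32> zeroinitializer
  %sel = select <vscale x 8 x i1> %pg, <vscale x 8 x bfloat> %splat, <vscale x 8 x bfloat> %z
  ret <vscale x 8 x bfloat> %sel
}
```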

def : Pat<(nxv8bf16 (vselect nxv8i1:$pg, (splat_vector fpimmbf16:$imm8), nxv8bf16:$zd)),
(!cast<Instruction>(NAME # _H) $zd, $pg, (fpimm16XForm bf16:$imm8))>;
def : Pat<(nxv4bf16 (vselect nxv4i1:$pg, (splat_vector fpimmbf16:$imm8), nxv4bf16:$zd)),
(!cast<Instruction>(NAME # _H) $zd, $pg, (fpimm16XForm bf16:$imm8))>;
def : Pat<(nxv2bf16 (vselect nxv2i1:$pg, (splat_vector fpimmbf16:$imm8), nxv2bf16:$zd)),
(!cast<Instruction>(NAME # _H) $zd, $pg, (fpimm16XForm bf16:$imm8))>;
}

class sve_int_dup_imm_pred<bits<2> sz8_64, bit m, string asm,