[AArch64] Improve bf16 fp_extend lowering. #118966

Merged (1 commit) on Jan 7, 2025
61 changes: 57 additions & 4 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -753,6 +753,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(Op, MVT::v8bf16, Expand);
}

// For bf16, fpextend is custom lowered to be optionally expanded into shifts.
Review comment from a contributor on the line above: Nit: typo I think. lowered to optionally -> lowered to be optionally

setOperationAction(ISD::FP_EXTEND, MVT::f32, Custom);
setOperationAction(ISD::FP_EXTEND, MVT::f64, Custom);
setOperationAction(ISD::FP_EXTEND, MVT::v4f32, Custom);
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f32, Custom);
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f64, Custom);
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::v4f32, Custom);

auto LegalizeNarrowFP = [this](MVT ScalarVT) {
for (auto Op : {
ISD::SETCC,
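
Note on the hunk above (illustration, not part of the patch): marking FP_EXTEND and STRICT_FP_EXTEND as Custom for f32, f64 and v4f32 results lets the target intercept extends whose source is bf16; other sources are returned unchanged. The identity the lowering relies on is that bfloat16 is the upper 16 bits of an IEEE-754 binary32 value, so widening it is an exact 16-bit shift of the raw bits. A minimal host-side C++ sketch of that identity:

```cpp
// Host-side illustration only (not LLVM code): bf16 -> f32 by shifting the raw
// bf16 bits into the upper half of a 32-bit word. The low mantissa bits become
// zero, so the conversion is exact.
#include <cstdint>
#include <cstring>

float bf16_to_f32(uint16_t bf16_bits) {
  uint32_t f32_bits = static_cast<uint32_t>(bf16_bits) << 16;
  float f;
  std::memcpy(&f, &f32_bits, sizeof f); // bit-for-bit reinterpretation
  return f;
}
```

The strict variants are registered as well so that constrained-FP extends reach the same LowerFP_EXTEND path (see the LowerOperation change below).
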
@@ -893,10 +901,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(Op, MVT::f16, Legal);
}

// Strict conversion to a larger type is legal
for (auto VT : {MVT::f32, MVT::f64})
setOperationAction(ISD::STRICT_FP_EXTEND, VT, Legal);

setOperationAction(ISD::PREFETCH, MVT::Other, Custom);

setOperationAction(ISD::GET_ROUNDING, MVT::i32, Custom);
@@ -4498,6 +4502,54 @@ SDValue AArch64TargetLowering::LowerFP_EXTEND(SDValue Op,
if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()))
return LowerFixedLengthFPExtendToSVE(Op, DAG);

bool IsStrict = Op->isStrictFPOpcode();
SDValue Op0 = Op.getOperand(IsStrict ? 1 : 0);
EVT Op0VT = Op0.getValueType();
if (VT == MVT::f64) {
// f32->f64 and f16->f64 extends are legal.
if (Op0VT == MVT::f32 || Op0VT == MVT::f16)
return Op;
// Split bf16->f64 extends into two fpextends.
if (Op0VT == MVT::bf16 && IsStrict) {
SDValue Ext1 =
DAG.getNode(ISD::STRICT_FP_EXTEND, SDLoc(Op), {MVT::f32, MVT::Other},
{Op0, Op.getOperand(0)});
return DAG.getNode(ISD::STRICT_FP_EXTEND, SDLoc(Op), {VT, MVT::Other},
{Ext1, Ext1.getValue(1)});
}
if (Op0VT == MVT::bf16)
return DAG.getNode(ISD::FP_EXTEND, SDLoc(Op), VT,
DAG.getNode(ISD::FP_EXTEND, SDLoc(Op), MVT::f32, Op0));
return SDValue();
}

if (VT.getScalarType() == MVT::f32) {
// FP16->FP32 extends are legal for f32 and v4f32.
if (Op0VT.getScalarType() == MVT::f16)
return Op;
if (Op0VT.getScalarType() == MVT::bf16) {
SDLoc DL(Op);
EVT IVT = VT.changeTypeToInteger();
if (!Op0VT.isVector()) {
Op0 = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4bf16, Op0);
IVT = MVT::v4i32;
}

EVT Op0IVT = Op0.getValueType().changeTypeToInteger();
SDValue Ext =
DAG.getNode(ISD::ANY_EXTEND, DL, IVT, DAG.getBitcast(Op0IVT, Op0));
SDValue Shift =
DAG.getNode(ISD::SHL, DL, IVT, Ext, DAG.getConstant(16, DL, IVT));
if (!Op0VT.isVector())
Shift = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i32, Shift,
DAG.getConstant(0, DL, MVT::i64));
Shift = DAG.getBitcast(VT, Shift);
return IsStrict ? DAG.getMergeValues({Shift, Op.getOperand(0)}, DL)
: Shift;
}
return SDValue();
}

assert(Op.getValueType() == MVT::f128 && "Unexpected lowering");
return SDValue();
}
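
Aside on the bf16 -> f64 path above (a sketch of the value-level behavior, not of the SelectionDAG code): both halves of the split are exact conversions, so replacing one fpextend with two preserves the result, and in the strict case the chain is threaded through both nodes to keep ordering. In host-side C++ terms:

```cpp
// Host-side illustration only: bf16 -> f64 as two exact steps, mirroring the
// split into two FP_EXTEND nodes.
#include <cstdint>
#include <cstring>

double bf16_to_f64(uint16_t bf16_bits) {
  // Step 1: bf16 -> f32 via a 16-bit shift (exact).
  uint32_t f32_bits = static_cast<uint32_t>(bf16_bits) << 16;
  float f;
  std::memcpy(&f, &f32_bits, sizeof f);
  // Step 2: f32 -> f64, also exact (an fcvt in the generated code).
  return static_cast<double>(f);
}
```
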
@@ -7345,6 +7397,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
case ISD::STRICT_FP_ROUND:
return LowerFP_ROUND(Op, DAG);
case ISD::FP_EXTEND:
case ISD::STRICT_FP_EXTEND:
return LowerFP_EXTEND(Op, DAG);
case ISD::FRAMEADDR:
return LowerFRAMEADDR(Op, DAG);
18 changes: 0 additions & 18 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -5123,22 +5123,6 @@ let Predicates = [HasFullFP16] in {
//===----------------------------------------------------------------------===//

defm FCVT : FPConversion<"fcvt">;
// Helper to get bf16 into fp32.
def cvt_bf16_to_fp32 :
OutPatFrag<(ops node:$Rn),
(f32 (COPY_TO_REGCLASS
(i32 (UBFMWri
(i32 (COPY_TO_REGCLASS (INSERT_SUBREG (f32 (IMPLICIT_DEF)),
node:$Rn, hsub), GPR32)),
(i64 (i32shift_a (i64 16))),
(i64 (i32shift_b (i64 16))))),
FPR32))>;
// Pattern for bf16 -> fp32.
def : Pat<(f32 (any_fpextend (bf16 FPR16:$Rn))),
(cvt_bf16_to_fp32 FPR16:$Rn)>;
// Pattern for bf16 -> fp64.
def : Pat<(f64 (any_fpextend (bf16 FPR16:$Rn))),
(FCVTDSr (f32 (cvt_bf16_to_fp32 FPR16:$Rn)))>;

//===----------------------------------------------------------------------===//
// Floating point single operand instructions.
@@ -8333,8 +8317,6 @@ def : Pat<(v4i32 (anyext (v4i16 V64:$Rn))), (USHLLv4i16_shift V64:$Rn, (i32 0))>
def : Pat<(v2i64 (sext (v2i32 V64:$Rn))), (SSHLLv2i32_shift V64:$Rn, (i32 0))>;
def : Pat<(v2i64 (zext (v2i32 V64:$Rn))), (USHLLv2i32_shift V64:$Rn, (i32 0))>;
def : Pat<(v2i64 (anyext (v2i32 V64:$Rn))), (USHLLv2i32_shift V64:$Rn, (i32 0))>;
// Vector bf16 -> fp32 is implemented morally as a zext + shift.
def : Pat<(v4f32 (any_fpextend (v4bf16 V64:$Rn))), (SHLLv4i16 V64:$Rn)>;
// Also match an extend from the upper half of a 128 bit source register.
def : Pat<(v8i16 (anyext (v8i8 (extract_high_v16i8 (v16i8 V128:$Rn)) ))),
(USHLLv16i8_shift V128:$Rn, (i32 0))>;
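
Note on the removed patterns above (illustration, not part of the patch): the old cvt_bf16_to_fp32 fragment performed the same 16-bit widening through a GPR (the UBFM is the left shift), and the vector pattern mapped v4bf16 -> v4f32 onto SHLL, which widens each 16-bit lane and shifts it left by 16. With the lowering now done in C++ in AArch64ISelLowering.cpp, the scalar case can also be selected as a vector shll, which is what the test updates below reflect. Lane-wise, the operation is simply:

```cpp
// Host-side illustration only: per-lane bf16 -> f32, the behavior that
// SHLL v.4s, v.4h, #16 implements in a single instruction.
#include <array>
#include <cstdint>
#include <cstring>

std::array<float, 4> v4bf16_to_v4f32(const std::array<uint16_t, 4> &lanes) {
  std::array<float, 4> out{};
  for (int i = 0; i < 4; ++i) {
    uint32_t bits = static_cast<uint32_t>(lanes[i]) << 16; // zero-extend, then shift, like SHLL
    std::memcpy(&out[i], &bits, sizeof(float));
  }
  return out;
}
```
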
14 changes: 6 additions & 8 deletions llvm/test/CodeGen/AArch64/arm64-fast-isel-conversion-fallback.ll
@@ -156,11 +156,10 @@ define i32 @fptosi_bf(bfloat %a) nounwind ssp {
; CHECK-LABEL: fptosi_bf:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: fmov s1, s0
; CHECK-NEXT: // implicit-def: $s0
; CHECK-NEXT: // implicit-def: $d0
; CHECK-NEXT: fmov s0, s1
; CHECK-NEXT: fmov w8, s0
; CHECK-NEXT: lsl w8, w8, #16
; CHECK-NEXT: fmov s0, w8
; CHECK-NEXT: shll v0.4s, v0.4h, #16
; CHECK-NEXT: // kill: def $s0 killed $s0 killed $q0
; CHECK-NEXT: fcvtzs w0, s0
; CHECK-NEXT: ret
entry:
@@ -173,11 +172,10 @@ define i32 @fptoui_sbf(bfloat %a) nounwind ssp {
; CHECK-LABEL: fptoui_sbf:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: fmov s1, s0
; CHECK-NEXT: // implicit-def: $s0
; CHECK-NEXT: // implicit-def: $d0
; CHECK-NEXT: fmov s0, s1
; CHECK-NEXT: fmov w8, s0
; CHECK-NEXT: lsl w8, w8, #16
; CHECK-NEXT: fmov s0, w8
; CHECK-NEXT: shll v0.4s, v0.4h, #16
; CHECK-NEXT: // kill: def $s0 killed $s0 killed $q0
; CHECK-NEXT: fcvtzu w0, s0
; CHECK-NEXT: ret
entry:
66 changes: 28 additions & 38 deletions llvm/test/CodeGen/AArch64/atomicrmw-fadd.ll
@@ -182,17 +182,14 @@ define half @test_atomicrmw_fadd_f16_seq_cst_align4(ptr %ptr, half %value) #0 {
define bfloat @test_atomicrmw_fadd_bf16_seq_cst_align2(ptr %ptr, bfloat %value) #0 {
; NOLSE-LABEL: test_atomicrmw_fadd_bf16_seq_cst_align2:
; NOLSE: // %bb.0:
; NOLSE-NEXT: // kill: def $h0 killed $h0 def $s0
; NOLSE-NEXT: fmov w9, s0
; NOLSE-NEXT: // kill: def $h0 killed $h0 def $d0
; NOLSE-NEXT: shll v1.4s, v0.4h, #16
; NOLSE-NEXT: mov w8, #32767 // =0x7fff
; NOLSE-NEXT: lsl w9, w9, #16
; NOLSE-NEXT: fmov s1, w9
; NOLSE-NEXT: .LBB2_1: // %atomicrmw.start
; NOLSE-NEXT: // =>This Inner Loop Header: Depth=1
; NOLSE-NEXT: ldaxrh w9, [x0]
; NOLSE-NEXT: fmov s0, w9
; NOLSE-NEXT: lsl w9, w9, #16
; NOLSE-NEXT: fmov s2, w9
; NOLSE-NEXT: shll v2.4s, v0.4h, #16
; NOLSE-NEXT: fadd s2, s2, s1
; NOLSE-NEXT: fmov w9, s2
; NOLSE-NEXT: ubfx w10, w9, #16, #1
@@ -202,36 +199,34 @@ define bfloat @test_atomicrmw_fadd_bf16_seq_cst_align2(ptr %ptr, bfloat %value)
; NOLSE-NEXT: stlxrh w10, w9, [x0]
; NOLSE-NEXT: cbnz w10, .LBB2_1
; NOLSE-NEXT: // %bb.2: // %atomicrmw.end
; NOLSE-NEXT: // kill: def $h0 killed $h0 killed $s0
; NOLSE-NEXT: // kill: def $h0 killed $h0 killed $d0
; NOLSE-NEXT: ret
;
; LSE-LABEL: test_atomicrmw_fadd_bf16_seq_cst_align2:
; LSE: // %bb.0:
; LSE-NEXT: // kill: def $h0 killed $h0 def $s0
; LSE-NEXT: fmov w9, s0
; LSE-NEXT: // kill: def $h0 killed $h0 def $d0
; LSE-NEXT: shll v1.4s, v0.4h, #16
; LSE-NEXT: mov w8, #32767 // =0x7fff
; LSE-NEXT: ldr h0, [x0]
; LSE-NEXT: lsl w9, w9, #16
; LSE-NEXT: fmov s1, w9
; LSE-NEXT: .LBB2_1: // %atomicrmw.start
; LSE-NEXT: // =>This Inner Loop Header: Depth=1
; LSE-NEXT: fmov w9, s0
; LSE-NEXT: lsl w9, w9, #16
; LSE-NEXT: fmov s2, w9
; LSE-NEXT: shll v2.4s, v0.4h, #16
; LSE-NEXT: fadd s2, s2, s1
; LSE-NEXT: fmov w9, s2
; LSE-NEXT: ubfx w10, w9, #16, #1
; LSE-NEXT: add w9, w9, w8
; LSE-NEXT: add w9, w10, w9
; LSE-NEXT: fmov w10, s0
; LSE-NEXT: lsr w9, w9, #16
; LSE-NEXT: mov w11, w10
; LSE-NEXT: casalh w11, w9, [x0]
; LSE-NEXT: fmov s2, w9
; LSE-NEXT: fmov w9, s0
; LSE-NEXT: fmov w10, s2
; LSE-NEXT: mov w11, w9
; LSE-NEXT: casalh w11, w10, [x0]
; LSE-NEXT: fmov s0, w11
; LSE-NEXT: cmp w11, w10, uxth
; LSE-NEXT: cmp w11, w9, uxth
; LSE-NEXT: b.ne .LBB2_1
; LSE-NEXT: // %bb.2: // %atomicrmw.end
; LSE-NEXT: // kill: def $h0 killed $h0 killed $s0
; LSE-NEXT: // kill: def $h0 killed $h0 killed $d0
; LSE-NEXT: ret
;
; SOFTFP-NOLSE-LABEL: test_atomicrmw_fadd_bf16_seq_cst_align2:
@@ -281,17 +276,14 @@ define bfloat @test_atomicrmw_fadd_bf16_seq_cst_align2(ptr %ptr, bfloat %value)
define bfloat @test_atomicrmw_fadd_bf16_seq_cst_align4(ptr %ptr, bfloat %value) #0 {
; NOLSE-LABEL: test_atomicrmw_fadd_bf16_seq_cst_align4:
; NOLSE: // %bb.0:
; NOLSE-NEXT: // kill: def $h0 killed $h0 def $s0
; NOLSE-NEXT: fmov w9, s0
; NOLSE-NEXT: // kill: def $h0 killed $h0 def $d0
; NOLSE-NEXT: shll v1.4s, v0.4h, #16
; NOLSE-NEXT: mov w8, #32767 // =0x7fff
; NOLSE-NEXT: lsl w9, w9, #16
; NOLSE-NEXT: fmov s1, w9
; NOLSE-NEXT: .LBB3_1: // %atomicrmw.start
; NOLSE-NEXT: // =>This Inner Loop Header: Depth=1
; NOLSE-NEXT: ldaxrh w9, [x0]
; NOLSE-NEXT: fmov s0, w9
; NOLSE-NEXT: lsl w9, w9, #16
; NOLSE-NEXT: fmov s2, w9
; NOLSE-NEXT: shll v2.4s, v0.4h, #16
; NOLSE-NEXT: fadd s2, s2, s1
; NOLSE-NEXT: fmov w9, s2
; NOLSE-NEXT: ubfx w10, w9, #16, #1
@@ -301,36 +293,34 @@ define bfloat @test_atomicrmw_fadd_bf16_seq_cst_align4(ptr %ptr, bfloat %value)
; NOLSE-NEXT: stlxrh w10, w9, [x0]
; NOLSE-NEXT: cbnz w10, .LBB3_1
; NOLSE-NEXT: // %bb.2: // %atomicrmw.end
; NOLSE-NEXT: // kill: def $h0 killed $h0 killed $s0
; NOLSE-NEXT: // kill: def $h0 killed $h0 killed $d0
; NOLSE-NEXT: ret
;
; LSE-LABEL: test_atomicrmw_fadd_bf16_seq_cst_align4:
; LSE: // %bb.0:
; LSE-NEXT: // kill: def $h0 killed $h0 def $s0
; LSE-NEXT: fmov w9, s0
; LSE-NEXT: // kill: def $h0 killed $h0 def $d0
; LSE-NEXT: shll v1.4s, v0.4h, #16
; LSE-NEXT: mov w8, #32767 // =0x7fff
; LSE-NEXT: ldr h0, [x0]
; LSE-NEXT: lsl w9, w9, #16
; LSE-NEXT: fmov s1, w9
; LSE-NEXT: .LBB3_1: // %atomicrmw.start
; LSE-NEXT: // =>This Inner Loop Header: Depth=1
; LSE-NEXT: fmov w9, s0
; LSE-NEXT: lsl w9, w9, #16
; LSE-NEXT: fmov s2, w9
; LSE-NEXT: shll v2.4s, v0.4h, #16
; LSE-NEXT: fadd s2, s2, s1
; LSE-NEXT: fmov w9, s2
; LSE-NEXT: ubfx w10, w9, #16, #1
; LSE-NEXT: add w9, w9, w8
; LSE-NEXT: add w9, w10, w9
; LSE-NEXT: fmov w10, s0
; LSE-NEXT: lsr w9, w9, #16
; LSE-NEXT: mov w11, w10
; LSE-NEXT: casalh w11, w9, [x0]
; LSE-NEXT: fmov s2, w9
; LSE-NEXT: fmov w9, s0
; LSE-NEXT: fmov w10, s2
; LSE-NEXT: mov w11, w9
; LSE-NEXT: casalh w11, w10, [x0]
; LSE-NEXT: fmov s0, w11
; LSE-NEXT: cmp w11, w10, uxth
; LSE-NEXT: cmp w11, w9, uxth
; LSE-NEXT: b.ne .LBB3_1
; LSE-NEXT: // %bb.2: // %atomicrmw.end
; LSE-NEXT: // kill: def $h0 killed $h0 killed $s0
; LSE-NEXT: // kill: def $h0 killed $h0 killed $d0
; LSE-NEXT: ret
;
; SOFTFP-NOLSE-LABEL: test_atomicrmw_fadd_bf16_seq_cst_align4: