Skip to content

Commit 1729e6e

Browse files
authored
[AArch64] Improve bf16 fp_extend lowering. (#118966)
A bf16 fp_extend is just a shift into the higher bits. This changes the lowering from using a relatively ugly tablegen pattern, to ISel generating the shift using an extended vector. This is cleaner and should optimize better. StrictFP goes through the same route as it cannot round or set flags.
1 parent 5a7dfb4 commit 1729e6e

11 files changed

+1257
-2042
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
753753
setOperationAction(Op, MVT::v8bf16, Expand);
754754
}
755755

756+
// For bf16, fpextend is custom lowered to be optionally expanded into shifts.
757+
setOperationAction(ISD::FP_EXTEND, MVT::f32, Custom);
758+
setOperationAction(ISD::FP_EXTEND, MVT::f64, Custom);
759+
setOperationAction(ISD::FP_EXTEND, MVT::v4f32, Custom);
760+
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f32, Custom);
761+
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f64, Custom);
762+
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::v4f32, Custom);
763+
756764
auto LegalizeNarrowFP = [this](MVT ScalarVT) {
757765
for (auto Op : {
758766
ISD::SETCC,
@@ -893,10 +901,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
893901
setOperationAction(Op, MVT::f16, Legal);
894902
}
895903

896-
// Strict conversion to a larger type is legal
897-
for (auto VT : {MVT::f32, MVT::f64})
898-
setOperationAction(ISD::STRICT_FP_EXTEND, VT, Legal);
899-
900904
setOperationAction(ISD::PREFETCH, MVT::Other, Custom);
901905

902906
setOperationAction(ISD::GET_ROUNDING, MVT::i32, Custom);
@@ -4498,6 +4502,54 @@ SDValue AArch64TargetLowering::LowerFP_EXTEND(SDValue Op,
44984502
if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()))
44994503
return LowerFixedLengthFPExtendToSVE(Op, DAG);
45004504

4505+
bool IsStrict = Op->isStrictFPOpcode();
4506+
SDValue Op0 = Op.getOperand(IsStrict ? 1 : 0);
4507+
EVT Op0VT = Op0.getValueType();
4508+
if (VT == MVT::f64) {
4509+
// FP16->FP32 extends are legal for v32 and v4f32.
4510+
if (Op0VT == MVT::f32 || Op0VT == MVT::f16)
4511+
return Op;
4512+
// Split bf16->f64 extends into two fpextends.
4513+
if (Op0VT == MVT::bf16 && IsStrict) {
4514+
SDValue Ext1 =
4515+
DAG.getNode(ISD::STRICT_FP_EXTEND, SDLoc(Op), {MVT::f32, MVT::Other},
4516+
{Op0, Op.getOperand(0)});
4517+
return DAG.getNode(ISD::STRICT_FP_EXTEND, SDLoc(Op), {VT, MVT::Other},
4518+
{Ext1, Ext1.getValue(1)});
4519+
}
4520+
if (Op0VT == MVT::bf16)
4521+
return DAG.getNode(ISD::FP_EXTEND, SDLoc(Op), VT,
4522+
DAG.getNode(ISD::FP_EXTEND, SDLoc(Op), MVT::f32, Op0));
4523+
return SDValue();
4524+
}
4525+
4526+
if (VT.getScalarType() == MVT::f32) {
4527+
// FP16->FP32 extends are legal for v32 and v4f32.
4528+
if (Op0VT.getScalarType() == MVT::f16)
4529+
return Op;
4530+
if (Op0VT.getScalarType() == MVT::bf16) {
4531+
SDLoc DL(Op);
4532+
EVT IVT = VT.changeTypeToInteger();
4533+
if (!Op0VT.isVector()) {
4534+
Op0 = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4bf16, Op0);
4535+
IVT = MVT::v4i32;
4536+
}
4537+
4538+
EVT Op0IVT = Op0.getValueType().changeTypeToInteger();
4539+
SDValue Ext =
4540+
DAG.getNode(ISD::ANY_EXTEND, DL, IVT, DAG.getBitcast(Op0IVT, Op0));
4541+
SDValue Shift =
4542+
DAG.getNode(ISD::SHL, DL, IVT, Ext, DAG.getConstant(16, DL, IVT));
4543+
if (!Op0VT.isVector())
4544+
Shift = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i32, Shift,
4545+
DAG.getConstant(0, DL, MVT::i64));
4546+
Shift = DAG.getBitcast(VT, Shift);
4547+
return IsStrict ? DAG.getMergeValues({Shift, Op.getOperand(0)}, DL)
4548+
: Shift;
4549+
}
4550+
return SDValue();
4551+
}
4552+
45014553
assert(Op.getValueType() == MVT::f128 && "Unexpected lowering");
45024554
return SDValue();
45034555
}
@@ -7345,6 +7397,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
73457397
case ISD::STRICT_FP_ROUND:
73467398
return LowerFP_ROUND(Op, DAG);
73477399
case ISD::FP_EXTEND:
7400+
case ISD::STRICT_FP_EXTEND:
73487401
return LowerFP_EXTEND(Op, DAG);
73497402
case ISD::FRAMEADDR:
73507403
return LowerFRAMEADDR(Op, DAG);

llvm/lib/Target/AArch64/AArch64InstrInfo.td

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5123,22 +5123,6 @@ let Predicates = [HasFullFP16] in {
51235123
//===----------------------------------------------------------------------===//
51245124

51255125
defm FCVT : FPConversion<"fcvt">;
5126-
// Helper to get bf16 into fp32.
5127-
def cvt_bf16_to_fp32 :
5128-
OutPatFrag<(ops node:$Rn),
5129-
(f32 (COPY_TO_REGCLASS
5130-
(i32 (UBFMWri
5131-
(i32 (COPY_TO_REGCLASS (INSERT_SUBREG (f32 (IMPLICIT_DEF)),
5132-
node:$Rn, hsub), GPR32)),
5133-
(i64 (i32shift_a (i64 16))),
5134-
(i64 (i32shift_b (i64 16))))),
5135-
FPR32))>;
5136-
// Pattern for bf16 -> fp32.
5137-
def : Pat<(f32 (any_fpextend (bf16 FPR16:$Rn))),
5138-
(cvt_bf16_to_fp32 FPR16:$Rn)>;
5139-
// Pattern for bf16 -> fp64.
5140-
def : Pat<(f64 (any_fpextend (bf16 FPR16:$Rn))),
5141-
(FCVTDSr (f32 (cvt_bf16_to_fp32 FPR16:$Rn)))>;
51425126

51435127
//===----------------------------------------------------------------------===//
51445128
// Floating point single operand instructions.
@@ -8333,8 +8317,6 @@ def : Pat<(v4i32 (anyext (v4i16 V64:$Rn))), (USHLLv4i16_shift V64:$Rn, (i32 0))>
83338317
def : Pat<(v2i64 (sext (v2i32 V64:$Rn))), (SSHLLv2i32_shift V64:$Rn, (i32 0))>;
83348318
def : Pat<(v2i64 (zext (v2i32 V64:$Rn))), (USHLLv2i32_shift V64:$Rn, (i32 0))>;
83358319
def : Pat<(v2i64 (anyext (v2i32 V64:$Rn))), (USHLLv2i32_shift V64:$Rn, (i32 0))>;
8336-
// Vector bf16 -> fp32 is implemented morally as a zext + shift.
8337-
def : Pat<(v4f32 (any_fpextend (v4bf16 V64:$Rn))), (SHLLv4i16 V64:$Rn)>;
83388320
// Also match an extend from the upper half of a 128 bit source register.
83398321
def : Pat<(v8i16 (anyext (v8i8 (extract_high_v16i8 (v16i8 V128:$Rn)) ))),
83408322
(USHLLv16i8_shift V128:$Rn, (i32 0))>;

llvm/test/CodeGen/AArch64/arm64-fast-isel-conversion-fallback.ll

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -156,11 +156,10 @@ define i32 @fptosi_bf(bfloat %a) nounwind ssp {
156156
; CHECK-LABEL: fptosi_bf:
157157
; CHECK: // %bb.0: // %entry
158158
; CHECK-NEXT: fmov s1, s0
159-
; CHECK-NEXT: // implicit-def: $s0
159+
; CHECK-NEXT: // implicit-def: $d0
160160
; CHECK-NEXT: fmov s0, s1
161-
; CHECK-NEXT: fmov w8, s0
162-
; CHECK-NEXT: lsl w8, w8, #16
163-
; CHECK-NEXT: fmov s0, w8
161+
; CHECK-NEXT: shll v0.4s, v0.4h, #16
162+
; CHECK-NEXT: // kill: def $s0 killed $s0 killed $q0
164163
; CHECK-NEXT: fcvtzs w0, s0
165164
; CHECK-NEXT: ret
166165
entry:
@@ -173,11 +172,10 @@ define i32 @fptoui_sbf(bfloat %a) nounwind ssp {
173172
; CHECK-LABEL: fptoui_sbf:
174173
; CHECK: // %bb.0: // %entry
175174
; CHECK-NEXT: fmov s1, s0
176-
; CHECK-NEXT: // implicit-def: $s0
175+
; CHECK-NEXT: // implicit-def: $d0
177176
; CHECK-NEXT: fmov s0, s1
178-
; CHECK-NEXT: fmov w8, s0
179-
; CHECK-NEXT: lsl w8, w8, #16
180-
; CHECK-NEXT: fmov s0, w8
177+
; CHECK-NEXT: shll v0.4s, v0.4h, #16
178+
; CHECK-NEXT: // kill: def $s0 killed $s0 killed $q0
181179
; CHECK-NEXT: fcvtzu w0, s0
182180
; CHECK-NEXT: ret
183181
entry:

llvm/test/CodeGen/AArch64/atomicrmw-fadd.ll

Lines changed: 28 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -182,17 +182,14 @@ define half @test_atomicrmw_fadd_f16_seq_cst_align4(ptr %ptr, half %value) #0 {
182182
define bfloat @test_atomicrmw_fadd_bf16_seq_cst_align2(ptr %ptr, bfloat %value) #0 {
183183
; NOLSE-LABEL: test_atomicrmw_fadd_bf16_seq_cst_align2:
184184
; NOLSE: // %bb.0:
185-
; NOLSE-NEXT: // kill: def $h0 killed $h0 def $s0
186-
; NOLSE-NEXT: fmov w9, s0
185+
; NOLSE-NEXT: // kill: def $h0 killed $h0 def $d0
186+
; NOLSE-NEXT: shll v1.4s, v0.4h, #16
187187
; NOLSE-NEXT: mov w8, #32767 // =0x7fff
188-
; NOLSE-NEXT: lsl w9, w9, #16
189-
; NOLSE-NEXT: fmov s1, w9
190188
; NOLSE-NEXT: .LBB2_1: // %atomicrmw.start
191189
; NOLSE-NEXT: // =>This Inner Loop Header: Depth=1
192190
; NOLSE-NEXT: ldaxrh w9, [x0]
193191
; NOLSE-NEXT: fmov s0, w9
194-
; NOLSE-NEXT: lsl w9, w9, #16
195-
; NOLSE-NEXT: fmov s2, w9
192+
; NOLSE-NEXT: shll v2.4s, v0.4h, #16
196193
; NOLSE-NEXT: fadd s2, s2, s1
197194
; NOLSE-NEXT: fmov w9, s2
198195
; NOLSE-NEXT: ubfx w10, w9, #16, #1
@@ -202,36 +199,34 @@ define bfloat @test_atomicrmw_fadd_bf16_seq_cst_align2(ptr %ptr, bfloat %value)
202199
; NOLSE-NEXT: stlxrh w10, w9, [x0]
203200
; NOLSE-NEXT: cbnz w10, .LBB2_1
204201
; NOLSE-NEXT: // %bb.2: // %atomicrmw.end
205-
; NOLSE-NEXT: // kill: def $h0 killed $h0 killed $s0
202+
; NOLSE-NEXT: // kill: def $h0 killed $h0 killed $d0
206203
; NOLSE-NEXT: ret
207204
;
208205
; LSE-LABEL: test_atomicrmw_fadd_bf16_seq_cst_align2:
209206
; LSE: // %bb.0:
210-
; LSE-NEXT: // kill: def $h0 killed $h0 def $s0
211-
; LSE-NEXT: fmov w9, s0
207+
; LSE-NEXT: // kill: def $h0 killed $h0 def $d0
208+
; LSE-NEXT: shll v1.4s, v0.4h, #16
212209
; LSE-NEXT: mov w8, #32767 // =0x7fff
213210
; LSE-NEXT: ldr h0, [x0]
214-
; LSE-NEXT: lsl w9, w9, #16
215-
; LSE-NEXT: fmov s1, w9
216211
; LSE-NEXT: .LBB2_1: // %atomicrmw.start
217212
; LSE-NEXT: // =>This Inner Loop Header: Depth=1
218-
; LSE-NEXT: fmov w9, s0
219-
; LSE-NEXT: lsl w9, w9, #16
220-
; LSE-NEXT: fmov s2, w9
213+
; LSE-NEXT: shll v2.4s, v0.4h, #16
221214
; LSE-NEXT: fadd s2, s2, s1
222215
; LSE-NEXT: fmov w9, s2
223216
; LSE-NEXT: ubfx w10, w9, #16, #1
224217
; LSE-NEXT: add w9, w9, w8
225218
; LSE-NEXT: add w9, w10, w9
226-
; LSE-NEXT: fmov w10, s0
227219
; LSE-NEXT: lsr w9, w9, #16
228-
; LSE-NEXT: mov w11, w10
229-
; LSE-NEXT: casalh w11, w9, [x0]
220+
; LSE-NEXT: fmov s2, w9
221+
; LSE-NEXT: fmov w9, s0
222+
; LSE-NEXT: fmov w10, s2
223+
; LSE-NEXT: mov w11, w9
224+
; LSE-NEXT: casalh w11, w10, [x0]
230225
; LSE-NEXT: fmov s0, w11
231-
; LSE-NEXT: cmp w11, w10, uxth
226+
; LSE-NEXT: cmp w11, w9, uxth
232227
; LSE-NEXT: b.ne .LBB2_1
233228
; LSE-NEXT: // %bb.2: // %atomicrmw.end
234-
; LSE-NEXT: // kill: def $h0 killed $h0 killed $s0
229+
; LSE-NEXT: // kill: def $h0 killed $h0 killed $d0
235230
; LSE-NEXT: ret
236231
;
237232
; SOFTFP-NOLSE-LABEL: test_atomicrmw_fadd_bf16_seq_cst_align2:
@@ -281,17 +276,14 @@ define bfloat @test_atomicrmw_fadd_bf16_seq_cst_align2(ptr %ptr, bfloat %value)
281276
define bfloat @test_atomicrmw_fadd_bf16_seq_cst_align4(ptr %ptr, bfloat %value) #0 {
282277
; NOLSE-LABEL: test_atomicrmw_fadd_bf16_seq_cst_align4:
283278
; NOLSE: // %bb.0:
284-
; NOLSE-NEXT: // kill: def $h0 killed $h0 def $s0
285-
; NOLSE-NEXT: fmov w9, s0
279+
; NOLSE-NEXT: // kill: def $h0 killed $h0 def $d0
280+
; NOLSE-NEXT: shll v1.4s, v0.4h, #16
286281
; NOLSE-NEXT: mov w8, #32767 // =0x7fff
287-
; NOLSE-NEXT: lsl w9, w9, #16
288-
; NOLSE-NEXT: fmov s1, w9
289282
; NOLSE-NEXT: .LBB3_1: // %atomicrmw.start
290283
; NOLSE-NEXT: // =>This Inner Loop Header: Depth=1
291284
; NOLSE-NEXT: ldaxrh w9, [x0]
292285
; NOLSE-NEXT: fmov s0, w9
293-
; NOLSE-NEXT: lsl w9, w9, #16
294-
; NOLSE-NEXT: fmov s2, w9
286+
; NOLSE-NEXT: shll v2.4s, v0.4h, #16
295287
; NOLSE-NEXT: fadd s2, s2, s1
296288
; NOLSE-NEXT: fmov w9, s2
297289
; NOLSE-NEXT: ubfx w10, w9, #16, #1
@@ -301,36 +293,34 @@ define bfloat @test_atomicrmw_fadd_bf16_seq_cst_align4(ptr %ptr, bfloat %value)
301293
; NOLSE-NEXT: stlxrh w10, w9, [x0]
302294
; NOLSE-NEXT: cbnz w10, .LBB3_1
303295
; NOLSE-NEXT: // %bb.2: // %atomicrmw.end
304-
; NOLSE-NEXT: // kill: def $h0 killed $h0 killed $s0
296+
; NOLSE-NEXT: // kill: def $h0 killed $h0 killed $d0
305297
; NOLSE-NEXT: ret
306298
;
307299
; LSE-LABEL: test_atomicrmw_fadd_bf16_seq_cst_align4:
308300
; LSE: // %bb.0:
309-
; LSE-NEXT: // kill: def $h0 killed $h0 def $s0
310-
; LSE-NEXT: fmov w9, s0
301+
; LSE-NEXT: // kill: def $h0 killed $h0 def $d0
302+
; LSE-NEXT: shll v1.4s, v0.4h, #16
311303
; LSE-NEXT: mov w8, #32767 // =0x7fff
312304
; LSE-NEXT: ldr h0, [x0]
313-
; LSE-NEXT: lsl w9, w9, #16
314-
; LSE-NEXT: fmov s1, w9
315305
; LSE-NEXT: .LBB3_1: // %atomicrmw.start
316306
; LSE-NEXT: // =>This Inner Loop Header: Depth=1
317-
; LSE-NEXT: fmov w9, s0
318-
; LSE-NEXT: lsl w9, w9, #16
319-
; LSE-NEXT: fmov s2, w9
307+
; LSE-NEXT: shll v2.4s, v0.4h, #16
320308
; LSE-NEXT: fadd s2, s2, s1
321309
; LSE-NEXT: fmov w9, s2
322310
; LSE-NEXT: ubfx w10, w9, #16, #1
323311
; LSE-NEXT: add w9, w9, w8
324312
; LSE-NEXT: add w9, w10, w9
325-
; LSE-NEXT: fmov w10, s0
326313
; LSE-NEXT: lsr w9, w9, #16
327-
; LSE-NEXT: mov w11, w10
328-
; LSE-NEXT: casalh w11, w9, [x0]
314+
; LSE-NEXT: fmov s2, w9
315+
; LSE-NEXT: fmov w9, s0
316+
; LSE-NEXT: fmov w10, s2
317+
; LSE-NEXT: mov w11, w9
318+
; LSE-NEXT: casalh w11, w10, [x0]
329319
; LSE-NEXT: fmov s0, w11
330-
; LSE-NEXT: cmp w11, w10, uxth
320+
; LSE-NEXT: cmp w11, w9, uxth
331321
; LSE-NEXT: b.ne .LBB3_1
332322
; LSE-NEXT: // %bb.2: // %atomicrmw.end
333-
; LSE-NEXT: // kill: def $h0 killed $h0 killed $s0
323+
; LSE-NEXT: // kill: def $h0 killed $h0 killed $d0
334324
; LSE-NEXT: ret
335325
;
336326
; SOFTFP-NOLSE-LABEL: test_atomicrmw_fadd_bf16_seq_cst_align4:

0 commit comments

Comments
 (0)