Skip to content

Commit 5ecdf69

Browse files
committed
[AArch64] Improve bf16 fp_extend lowering.
A bf16 fp_extend is just a shift into the higher bits. This changes the lowering from using a relatively ugly tablegen pattern, to ISel generating the shift using an extended vector. This is cleaner and should optimize better. StrictFP goes through the same route as it cannot round or set flags.
1 parent 76db473 commit 5ecdf69

11 files changed

+1257
-2042
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
753753
setOperationAction(Op, MVT::v8bf16, Expand);
754754
}
755755

756+
// For bf16, fpextend is custom lowered to optionally expanded into shifts.
757+
setOperationAction(ISD::FP_EXTEND, MVT::f32, Custom);
758+
setOperationAction(ISD::FP_EXTEND, MVT::f64, Custom);
759+
setOperationAction(ISD::FP_EXTEND, MVT::v4f32, Custom);
760+
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f32, Custom);
761+
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f64, Custom);
762+
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::v4f32, Custom);
763+
756764
auto LegalizeNarrowFP = [this](MVT ScalarVT) {
757765
for (auto Op : {
758766
ISD::SETCC,
@@ -893,10 +901,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
893901
setOperationAction(Op, MVT::f16, Legal);
894902
}
895903

896-
// Strict conversion to a larger type is legal
897-
for (auto VT : {MVT::f32, MVT::f64})
898-
setOperationAction(ISD::STRICT_FP_EXTEND, VT, Legal);
899-
900904
setOperationAction(ISD::PREFETCH, MVT::Other, Custom);
901905

902906
setOperationAction(ISD::GET_ROUNDING, MVT::i32, Custom);
@@ -4419,6 +4423,54 @@ SDValue AArch64TargetLowering::LowerFP_EXTEND(SDValue Op,
44194423
if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()))
44204424
return LowerFixedLengthFPExtendToSVE(Op, DAG);
44214425

4426+
bool IsStrict = Op->isStrictFPOpcode();
4427+
SDValue Op0 = Op.getOperand(IsStrict ? 1 : 0);
4428+
EVT Op0VT = Op0.getValueType();
4429+
if (VT == MVT::f64) {
4430+
// FP16->FP32 extends are legal for v32 and v4f32.
4431+
if (Op0VT == MVT::f32 || Op0VT == MVT::f16)
4432+
return Op;
4433+
// Split bf16->f64 extends into two fpextends.
4434+
if (Op0VT == MVT::bf16 && IsStrict) {
4435+
SDValue Ext1 =
4436+
DAG.getNode(ISD::STRICT_FP_EXTEND, SDLoc(Op), {MVT::f32, MVT::Other},
4437+
{Op0, Op.getOperand(0)});
4438+
return DAG.getNode(ISD::STRICT_FP_EXTEND, SDLoc(Op), {VT, MVT::Other},
4439+
{Ext1, Ext1.getValue(1)});
4440+
}
4441+
if (Op0VT == MVT::bf16)
4442+
return DAG.getNode(ISD::FP_EXTEND, SDLoc(Op), VT,
4443+
DAG.getNode(ISD::FP_EXTEND, SDLoc(Op), MVT::f32, Op0));
4444+
return SDValue();
4445+
}
4446+
4447+
if (VT.getScalarType() == MVT::f32) {
4448+
// FP16->FP32 extends are legal for v32 and v4f32.
4449+
if (Op0VT.getScalarType() == MVT::f16)
4450+
return Op;
4451+
if (Op0VT.getScalarType() == MVT::bf16) {
4452+
SDLoc DL(Op);
4453+
EVT IVT = VT.changeTypeToInteger();
4454+
if (!Op0VT.isVector()) {
4455+
Op0 = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4bf16, Op0);
4456+
IVT = MVT::v4i32;
4457+
}
4458+
4459+
EVT Op0IVT = Op0.getValueType().changeTypeToInteger();
4460+
SDValue Ext =
4461+
DAG.getNode(ISD::ANY_EXTEND, DL, IVT, DAG.getBitcast(Op0IVT, Op0));
4462+
SDValue Shift =
4463+
DAG.getNode(ISD::SHL, DL, IVT, Ext, DAG.getConstant(16, DL, IVT));
4464+
if (!Op0VT.isVector())
4465+
Shift = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i32, Shift,
4466+
DAG.getConstant(0, DL, MVT::i64));
4467+
Shift = DAG.getBitcast(VT, Shift);
4468+
return IsStrict ? DAG.getMergeValues({Shift, Op.getOperand(0)}, DL)
4469+
: Shift;
4470+
}
4471+
return SDValue();
4472+
}
4473+
44224474
assert(Op.getValueType() == MVT::f128 && "Unexpected lowering");
44234475
return SDValue();
44244476
}
@@ -7266,6 +7318,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
72667318
case ISD::STRICT_FP_ROUND:
72677319
return LowerFP_ROUND(Op, DAG);
72687320
case ISD::FP_EXTEND:
7321+
case ISD::STRICT_FP_EXTEND:
72697322
return LowerFP_EXTEND(Op, DAG);
72707323
case ISD::FRAMEADDR:
72717324
return LowerFRAMEADDR(Op, DAG);

llvm/lib/Target/AArch64/AArch64InstrInfo.td

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5115,22 +5115,6 @@ let Predicates = [HasFullFP16] in {
51155115
//===----------------------------------------------------------------------===//
51165116

51175117
defm FCVT : FPConversion<"fcvt">;
5118-
// Helper to get bf16 into fp32.
5119-
def cvt_bf16_to_fp32 :
5120-
OutPatFrag<(ops node:$Rn),
5121-
(f32 (COPY_TO_REGCLASS
5122-
(i32 (UBFMWri
5123-
(i32 (COPY_TO_REGCLASS (INSERT_SUBREG (f32 (IMPLICIT_DEF)),
5124-
node:$Rn, hsub), GPR32)),
5125-
(i64 (i32shift_a (i64 16))),
5126-
(i64 (i32shift_b (i64 16))))),
5127-
FPR32))>;
5128-
// Pattern for bf16 -> fp32.
5129-
def : Pat<(f32 (any_fpextend (bf16 FPR16:$Rn))),
5130-
(cvt_bf16_to_fp32 FPR16:$Rn)>;
5131-
// Pattern for bf16 -> fp64.
5132-
def : Pat<(f64 (any_fpextend (bf16 FPR16:$Rn))),
5133-
(FCVTDSr (f32 (cvt_bf16_to_fp32 FPR16:$Rn)))>;
51345118

51355119
//===----------------------------------------------------------------------===//
51365120
// Floating point single operand instructions.
@@ -8343,8 +8327,6 @@ def : Pat<(v4i32 (anyext (v4i16 V64:$Rn))), (USHLLv4i16_shift V64:$Rn, (i32 0))>
83438327
def : Pat<(v2i64 (sext (v2i32 V64:$Rn))), (SSHLLv2i32_shift V64:$Rn, (i32 0))>;
83448328
def : Pat<(v2i64 (zext (v2i32 V64:$Rn))), (USHLLv2i32_shift V64:$Rn, (i32 0))>;
83458329
def : Pat<(v2i64 (anyext (v2i32 V64:$Rn))), (USHLLv2i32_shift V64:$Rn, (i32 0))>;
8346-
// Vector bf16 -> fp32 is implemented morally as a zext + shift.
8347-
def : Pat<(v4f32 (any_fpextend (v4bf16 V64:$Rn))), (SHLLv4i16 V64:$Rn)>;
83488330
// Also match an extend from the upper half of a 128 bit source register.
83498331
def : Pat<(v8i16 (anyext (v8i8 (extract_high_v16i8 (v16i8 V128:$Rn)) ))),
83508332
(USHLLv16i8_shift V128:$Rn, (i32 0))>;

llvm/test/CodeGen/AArch64/arm64-fast-isel-conversion-fallback.ll

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -156,11 +156,10 @@ define i32 @fptosi_bf(bfloat %a) nounwind ssp {
156156
; CHECK-LABEL: fptosi_bf:
157157
; CHECK: // %bb.0: // %entry
158158
; CHECK-NEXT: fmov s1, s0
159-
; CHECK-NEXT: // implicit-def: $s0
159+
; CHECK-NEXT: // implicit-def: $d0
160160
; CHECK-NEXT: fmov s0, s1
161-
; CHECK-NEXT: fmov w8, s0
162-
; CHECK-NEXT: lsl w8, w8, #16
163-
; CHECK-NEXT: fmov s0, w8
161+
; CHECK-NEXT: shll v0.4s, v0.4h, #16
162+
; CHECK-NEXT: // kill: def $s0 killed $s0 killed $q0
164163
; CHECK-NEXT: fcvtzs w0, s0
165164
; CHECK-NEXT: ret
166165
entry:
@@ -173,11 +172,10 @@ define i32 @fptoui_sbf(bfloat %a) nounwind ssp {
173172
; CHECK-LABEL: fptoui_sbf:
174173
; CHECK: // %bb.0: // %entry
175174
; CHECK-NEXT: fmov s1, s0
176-
; CHECK-NEXT: // implicit-def: $s0
175+
; CHECK-NEXT: // implicit-def: $d0
177176
; CHECK-NEXT: fmov s0, s1
178-
; CHECK-NEXT: fmov w8, s0
179-
; CHECK-NEXT: lsl w8, w8, #16
180-
; CHECK-NEXT: fmov s0, w8
177+
; CHECK-NEXT: shll v0.4s, v0.4h, #16
178+
; CHECK-NEXT: // kill: def $s0 killed $s0 killed $q0
181179
; CHECK-NEXT: fcvtzu w0, s0
182180
; CHECK-NEXT: ret
183181
entry:

llvm/test/CodeGen/AArch64/atomicrmw-fadd.ll

Lines changed: 28 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -182,17 +182,14 @@ define half @test_atomicrmw_fadd_f16_seq_cst_align4(ptr %ptr, half %value) #0 {
182182
define bfloat @test_atomicrmw_fadd_bf16_seq_cst_align2(ptr %ptr, bfloat %value) #0 {
183183
; NOLSE-LABEL: test_atomicrmw_fadd_bf16_seq_cst_align2:
184184
; NOLSE: // %bb.0:
185-
; NOLSE-NEXT: // kill: def $h0 killed $h0 def $s0
186-
; NOLSE-NEXT: fmov w9, s0
185+
; NOLSE-NEXT: // kill: def $h0 killed $h0 def $d0
186+
; NOLSE-NEXT: shll v1.4s, v0.4h, #16
187187
; NOLSE-NEXT: mov w8, #32767 // =0x7fff
188-
; NOLSE-NEXT: lsl w9, w9, #16
189-
; NOLSE-NEXT: fmov s1, w9
190188
; NOLSE-NEXT: .LBB2_1: // %atomicrmw.start
191189
; NOLSE-NEXT: // =>This Inner Loop Header: Depth=1
192190
; NOLSE-NEXT: ldaxrh w9, [x0]
193191
; NOLSE-NEXT: fmov s0, w9
194-
; NOLSE-NEXT: lsl w9, w9, #16
195-
; NOLSE-NEXT: fmov s2, w9
192+
; NOLSE-NEXT: shll v2.4s, v0.4h, #16
196193
; NOLSE-NEXT: fadd s2, s2, s1
197194
; NOLSE-NEXT: fmov w9, s2
198195
; NOLSE-NEXT: ubfx w10, w9, #16, #1
@@ -202,36 +199,34 @@ define bfloat @test_atomicrmw_fadd_bf16_seq_cst_align2(ptr %ptr, bfloat %value)
202199
; NOLSE-NEXT: stlxrh w10, w9, [x0]
203200
; NOLSE-NEXT: cbnz w10, .LBB2_1
204201
; NOLSE-NEXT: // %bb.2: // %atomicrmw.end
205-
; NOLSE-NEXT: // kill: def $h0 killed $h0 killed $s0
202+
; NOLSE-NEXT: // kill: def $h0 killed $h0 killed $d0
206203
; NOLSE-NEXT: ret
207204
;
208205
; LSE-LABEL: test_atomicrmw_fadd_bf16_seq_cst_align2:
209206
; LSE: // %bb.0:
210-
; LSE-NEXT: // kill: def $h0 killed $h0 def $s0
211-
; LSE-NEXT: fmov w9, s0
207+
; LSE-NEXT: // kill: def $h0 killed $h0 def $d0
208+
; LSE-NEXT: shll v1.4s, v0.4h, #16
212209
; LSE-NEXT: mov w8, #32767 // =0x7fff
213210
; LSE-NEXT: ldr h0, [x0]
214-
; LSE-NEXT: lsl w9, w9, #16
215-
; LSE-NEXT: fmov s1, w9
216211
; LSE-NEXT: .LBB2_1: // %atomicrmw.start
217212
; LSE-NEXT: // =>This Inner Loop Header: Depth=1
218-
; LSE-NEXT: fmov w9, s0
219-
; LSE-NEXT: lsl w9, w9, #16
220-
; LSE-NEXT: fmov s2, w9
213+
; LSE-NEXT: shll v2.4s, v0.4h, #16
221214
; LSE-NEXT: fadd s2, s2, s1
222215
; LSE-NEXT: fmov w9, s2
223216
; LSE-NEXT: ubfx w10, w9, #16, #1
224217
; LSE-NEXT: add w9, w9, w8
225218
; LSE-NEXT: add w9, w10, w9
226-
; LSE-NEXT: fmov w10, s0
227219
; LSE-NEXT: lsr w9, w9, #16
228-
; LSE-NEXT: mov w11, w10
229-
; LSE-NEXT: casalh w11, w9, [x0]
220+
; LSE-NEXT: fmov s2, w9
221+
; LSE-NEXT: fmov w9, s0
222+
; LSE-NEXT: fmov w10, s2
223+
; LSE-NEXT: mov w11, w9
224+
; LSE-NEXT: casalh w11, w10, [x0]
230225
; LSE-NEXT: fmov s0, w11
231-
; LSE-NEXT: cmp w11, w10, uxth
226+
; LSE-NEXT: cmp w11, w9, uxth
232227
; LSE-NEXT: b.ne .LBB2_1
233228
; LSE-NEXT: // %bb.2: // %atomicrmw.end
234-
; LSE-NEXT: // kill: def $h0 killed $h0 killed $s0
229+
; LSE-NEXT: // kill: def $h0 killed $h0 killed $d0
235230
; LSE-NEXT: ret
236231
;
237232
; SOFTFP-NOLSE-LABEL: test_atomicrmw_fadd_bf16_seq_cst_align2:
@@ -281,17 +276,14 @@ define bfloat @test_atomicrmw_fadd_bf16_seq_cst_align2(ptr %ptr, bfloat %value)
281276
define bfloat @test_atomicrmw_fadd_bf16_seq_cst_align4(ptr %ptr, bfloat %value) #0 {
282277
; NOLSE-LABEL: test_atomicrmw_fadd_bf16_seq_cst_align4:
283278
; NOLSE: // %bb.0:
284-
; NOLSE-NEXT: // kill: def $h0 killed $h0 def $s0
285-
; NOLSE-NEXT: fmov w9, s0
279+
; NOLSE-NEXT: // kill: def $h0 killed $h0 def $d0
280+
; NOLSE-NEXT: shll v1.4s, v0.4h, #16
286281
; NOLSE-NEXT: mov w8, #32767 // =0x7fff
287-
; NOLSE-NEXT: lsl w9, w9, #16
288-
; NOLSE-NEXT: fmov s1, w9
289282
; NOLSE-NEXT: .LBB3_1: // %atomicrmw.start
290283
; NOLSE-NEXT: // =>This Inner Loop Header: Depth=1
291284
; NOLSE-NEXT: ldaxrh w9, [x0]
292285
; NOLSE-NEXT: fmov s0, w9
293-
; NOLSE-NEXT: lsl w9, w9, #16
294-
; NOLSE-NEXT: fmov s2, w9
286+
; NOLSE-NEXT: shll v2.4s, v0.4h, #16
295287
; NOLSE-NEXT: fadd s2, s2, s1
296288
; NOLSE-NEXT: fmov w9, s2
297289
; NOLSE-NEXT: ubfx w10, w9, #16, #1
@@ -301,36 +293,34 @@ define bfloat @test_atomicrmw_fadd_bf16_seq_cst_align4(ptr %ptr, bfloat %value)
301293
; NOLSE-NEXT: stlxrh w10, w9, [x0]
302294
; NOLSE-NEXT: cbnz w10, .LBB3_1
303295
; NOLSE-NEXT: // %bb.2: // %atomicrmw.end
304-
; NOLSE-NEXT: // kill: def $h0 killed $h0 killed $s0
296+
; NOLSE-NEXT: // kill: def $h0 killed $h0 killed $d0
305297
; NOLSE-NEXT: ret
306298
;
307299
; LSE-LABEL: test_atomicrmw_fadd_bf16_seq_cst_align4:
308300
; LSE: // %bb.0:
309-
; LSE-NEXT: // kill: def $h0 killed $h0 def $s0
310-
; LSE-NEXT: fmov w9, s0
301+
; LSE-NEXT: // kill: def $h0 killed $h0 def $d0
302+
; LSE-NEXT: shll v1.4s, v0.4h, #16
311303
; LSE-NEXT: mov w8, #32767 // =0x7fff
312304
; LSE-NEXT: ldr h0, [x0]
313-
; LSE-NEXT: lsl w9, w9, #16
314-
; LSE-NEXT: fmov s1, w9
315305
; LSE-NEXT: .LBB3_1: // %atomicrmw.start
316306
; LSE-NEXT: // =>This Inner Loop Header: Depth=1
317-
; LSE-NEXT: fmov w9, s0
318-
; LSE-NEXT: lsl w9, w9, #16
319-
; LSE-NEXT: fmov s2, w9
307+
; LSE-NEXT: shll v2.4s, v0.4h, #16
320308
; LSE-NEXT: fadd s2, s2, s1
321309
; LSE-NEXT: fmov w9, s2
322310
; LSE-NEXT: ubfx w10, w9, #16, #1
323311
; LSE-NEXT: add w9, w9, w8
324312
; LSE-NEXT: add w9, w10, w9
325-
; LSE-NEXT: fmov w10, s0
326313
; LSE-NEXT: lsr w9, w9, #16
327-
; LSE-NEXT: mov w11, w10
328-
; LSE-NEXT: casalh w11, w9, [x0]
314+
; LSE-NEXT: fmov s2, w9
315+
; LSE-NEXT: fmov w9, s0
316+
; LSE-NEXT: fmov w10, s2
317+
; LSE-NEXT: mov w11, w9
318+
; LSE-NEXT: casalh w11, w10, [x0]
329319
; LSE-NEXT: fmov s0, w11
330-
; LSE-NEXT: cmp w11, w10, uxth
320+
; LSE-NEXT: cmp w11, w9, uxth
331321
; LSE-NEXT: b.ne .LBB3_1
332322
; LSE-NEXT: // %bb.2: // %atomicrmw.end
333-
; LSE-NEXT: // kill: def $h0 killed $h0 killed $s0
323+
; LSE-NEXT: // kill: def $h0 killed $h0 killed $d0
334324
; LSE-NEXT: ret
335325
;
336326
; SOFTFP-NOLSE-LABEL: test_atomicrmw_fadd_bf16_seq_cst_align4:

0 commit comments

Comments
 (0)