Skip to content

Commit 3dd6750

Browse files
committed
[AArch64] Add more complete support for BF16
We can use a small amount of integer arithmetic to round FP32 to BF16 and to extend BF16 to FP32. While a number of operations still require promotion, this can be reduced for some rather simple operations like abs, copysign, and fneg, but these can be done in a follow-up. A few neat optimizations are implemented:
- round-inexact-to-odd is used for F64 to BF16 rounding.
- quieting signaling NaNs for f32 -> bf16 tries to detect whether a prior operation makes it unnecessary.
1 parent 2435dcd commit 3dd6750

File tree

13 files changed

+842
-197
lines changed

13 files changed

+842
-197
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1573,13 +1573,14 @@ class TargetLoweringBase {
15731573
assert((VT.isInteger() || VT.isFloatingPoint()) &&
15741574
"Cannot autopromote this type, add it with AddPromotedToType.");
15751575

1576+
uint64_t VTBits = VT.getScalarSizeInBits();
15761577
MVT NVT = VT;
15771578
do {
15781579
NVT = (MVT::SimpleValueType)(NVT.SimpleTy+1);
15791580
assert(NVT.isInteger() == VT.isInteger() && NVT != MVT::isVoid &&
15801581
"Didn't find type to promote to!");
1581-
} while (!isTypeLegal(NVT) ||
1582-
getOperationAction(Op, NVT) == Promote);
1582+
} while (VTBits >= NVT.getScalarSizeInBits() || !isTypeLegal(NVT) ||
1583+
getOperationAction(Op, NVT) == Promote);
15831584
return NVT;
15841585
}
15851586

llvm/include/llvm/CodeGen/ValueTypes.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,13 @@ namespace llvm {
107107
return changeExtendedVectorElementType(EltVT);
108108
}
109109

110+
/// Return a VT for a type whose attributes match ourselves with the
111+
/// exception of the element type that is chosen by the caller.
112+
EVT changeElementType(EVT EltVT) const {
113+
EltVT = EltVT.getScalarType();
114+
return isVector() ? changeVectorElementType(EltVT) : EltVT;
115+
}
116+
110117
/// Return the type converted to an equivalently sized integer or vector
111118
/// with integer element type. Similar to changeVectorElementTypeToInteger,
112119
/// but also handles scalars.

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 201 additions & 89 deletions
Large diffs are not rendered by default.

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,9 @@ enum NodeType : unsigned {
249249
FCMLEz,
250250
FCMLTz,
251251

252+
// Round wide FP to narrow FP with inexact results to odd.
253+
FCVTXN,
254+
252255
// Vector across-lanes addition
253256
// Only the lower result lane is defined.
254257
SADDV,

llvm/lib/Target/AArch64/AArch64InstrFormats.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7547,7 +7547,7 @@ class BaseSIMDCmpTwoScalar<bit U, bits<2> size, bits<2> size2, bits<5> opcode,
75477547
let mayRaiseFPException = 1, Uses = [FPCR] in
75487548
class SIMDInexactCvtTwoScalar<bits<5> opcode, string asm>
75497549
: I<(outs FPR32:$Rd), (ins FPR64:$Rn), asm, "\t$Rd, $Rn", "",
7550-
[(set (f32 FPR32:$Rd), (int_aarch64_sisd_fcvtxn (f64 FPR64:$Rn)))]>,
7550+
[(set (f32 FPR32:$Rd), (AArch64fcvtxn (f64 FPR64:$Rn)))]>,
75517551
Sched<[WriteVd]> {
75527552
bits<5> Rd;
75537553
bits<5> Rn;

llvm/lib/Target/AArch64/AArch64InstrInfo.td

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -756,6 +756,11 @@ def AArch64fcmgtz: SDNode<"AArch64ISD::FCMGTz", SDT_AArch64fcmpz>;
756756
def AArch64fcmlez: SDNode<"AArch64ISD::FCMLEz", SDT_AArch64fcmpz>;
757757
def AArch64fcmltz: SDNode<"AArch64ISD::FCMLTz", SDT_AArch64fcmpz>;
758758

759+
def AArch64fcvtxn_n: SDNode<"AArch64ISD::FCVTXN", SDTFPRoundOp>;
760+
def AArch64fcvtxn: PatFrags<(ops node:$Rn),
761+
[(f32 (int_aarch64_sisd_fcvtxn (f64 node:$Rn))),
762+
(f32 (AArch64fcvtxn_n (f64 node:$Rn)))]>;
763+
759764
def AArch64bici: SDNode<"AArch64ISD::BICi", SDT_AArch64vecimm>;
760765
def AArch64orri: SDNode<"AArch64ISD::ORRi", SDT_AArch64vecimm>;
761766

@@ -1276,6 +1281,9 @@ def BFMLALTIdx : SIMDBF16MLALIndex<1, "bfmlalt", int_aarch64_neon_bfmlalt>;
12761281
def BFCVTN : SIMD_BFCVTN;
12771282
def BFCVTN2 : SIMD_BFCVTN2;
12781283

1284+
def : Pat<(v4bf16 (any_fpround (v4f32 V128:$Rn))),
1285+
(EXTRACT_SUBREG (BFCVTN V128:$Rn), dsub)>;
1286+
12791287
// Vector-scalar BFDOT:
12801288
// The second source operand of the 64-bit variant of BF16DOTlane is a 128-bit
12811289
// register (the instruction uses a single 32-bit lane from it), so the pattern
@@ -1296,6 +1304,8 @@ def : Pat<(v2f32 (int_aarch64_neon_bfdot
12961304

12971305
let Predicates = [HasNEONorSME, HasBF16] in {
12981306
def BFCVT : BF16ToSinglePrecision<"bfcvt">;
1307+
// Round FP32 to BF16.
1308+
def : Pat<(bf16 (any_fpround (f32 FPR32:$Rn))), (BFCVT $Rn)>;
12991309
}
13001310

13011311
// ARMv8.6A AArch64 matrix multiplication
@@ -4648,6 +4658,22 @@ let Predicates = [HasFullFP16] in {
46484658
//===----------------------------------------------------------------------===//
46494659

46504660
defm FCVT : FPConversion<"fcvt">;
4661+
// Helper to get bf16 into fp32.
4662+
def cvt_bf16_to_fp32 :
4663+
OutPatFrag<(ops node:$Rn),
4664+
(f32 (COPY_TO_REGCLASS
4665+
(i32 (UBFMWri
4666+
(i32 (COPY_TO_REGCLASS (INSERT_SUBREG (f32 (IMPLICIT_DEF)),
4667+
node:$Rn, hsub), GPR32)),
4668+
(i64 (i32shift_a (i64 16))),
4669+
(i64 (i32shift_b (i64 16))))),
4670+
FPR32))>;
4671+
// Pattern for bf16 -> fp32.
4672+
def : Pat<(f32 (any_fpextend (bf16 FPR16:$Rn))),
4673+
(cvt_bf16_to_fp32 FPR16:$Rn)>;
4674+
// Pattern for bf16 -> fp64.
4675+
def : Pat<(f64 (any_fpextend (bf16 FPR16:$Rn))),
4676+
(FCVTDSr (f32 (cvt_bf16_to_fp32 FPR16:$Rn)))>;
46514677

46524678
//===----------------------------------------------------------------------===//
46534679
// Floating point single operand instructions.
@@ -5002,6 +5028,9 @@ defm FCVTNU : SIMDTwoVectorFPToInt<1,0,0b11010, "fcvtnu",int_aarch64_neon_fcvtnu
50025028
defm FCVTN : SIMDFPNarrowTwoVector<0, 0, 0b10110, "fcvtn">;
50035029
def : Pat<(v4i16 (int_aarch64_neon_vcvtfp2hf (v4f32 V128:$Rn))),
50045030
(FCVTNv4i16 V128:$Rn)>;
5031+
//def : Pat<(concat_vectors V64:$Rd,
5032+
// (v4bf16 (any_fpround (v4f32 V128:$Rn)))),
5033+
// (FCVTNv8bf16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Rd, dsub), V128:$Rn)>;
50055034
def : Pat<(concat_vectors V64:$Rd,
50065035
(v4i16 (int_aarch64_neon_vcvtfp2hf (v4f32 V128:$Rn)))),
50075036
(FCVTNv8i16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Rd, dsub), V128:$Rn)>;
@@ -5686,6 +5715,11 @@ defm USQADD : SIMDTwoScalarBHSDTied< 1, 0b00011, "usqadd",
56865715
def : Pat<(v1i64 (AArch64vashr (v1i64 V64:$Rn), (i32 63))),
56875716
(CMLTv1i64rz V64:$Rn)>;
56885717

5718+
// Round FP64 to BF16.
5719+
let Predicates = [HasNEONorSME, HasBF16] in
5720+
def : Pat<(bf16 (any_fpround (f64 FPR64:$Rn))),
5721+
(BFCVT (FCVTXNv1i64 $Rn))>;
5722+
56895723
def : Pat<(v1i64 (int_aarch64_neon_fcvtas (v1f64 FPR64:$Rn))),
56905724
(FCVTASv1i64 FPR64:$Rn)>;
56915725
def : Pat<(v1i64 (int_aarch64_neon_fcvtau (v1f64 FPR64:$Rn))),
@@ -7698,6 +7732,9 @@ def : Pat<(v4i32 (anyext (v4i16 V64:$Rn))), (USHLLv4i16_shift V64:$Rn, (i32 0))>
76987732
def : Pat<(v2i64 (sext (v2i32 V64:$Rn))), (SSHLLv2i32_shift V64:$Rn, (i32 0))>;
76997733
def : Pat<(v2i64 (zext (v2i32 V64:$Rn))), (USHLLv2i32_shift V64:$Rn, (i32 0))>;
77007734
def : Pat<(v2i64 (anyext (v2i32 V64:$Rn))), (USHLLv2i32_shift V64:$Rn, (i32 0))>;
7735+
// Vector bf16 -> fp32 is implemented morally as a zext + shift.
7736+
def : Pat<(v4f32 (any_fpextend (v4bf16 V64:$Rn))),
7737+
(USHLLv4i16_shift V64:$Rn, (i32 16))>;
77017738
// Also match an extend from the upper half of a 128 bit source register.
77027739
def : Pat<(v8i16 (anyext (v8i8 (extract_high_v16i8 (v16i8 V128:$Rn)) ))),
77037740
(USHLLv16i8_shift V128:$Rn, (i32 0))>;

llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1022,13 +1022,17 @@ void applyLowerVectorFCMP(MachineInstr &MI, MachineRegisterInfo &MRI,
10221022

10231023
bool Invert = false;
10241024
AArch64CC::CondCode CC, CC2 = AArch64CC::AL;
1025-
if (Pred == CmpInst::Predicate::FCMP_ORD && IsZero) {
1025+
if ((Pred == CmpInst::Predicate::FCMP_ORD ||
1026+
Pred == CmpInst::Predicate::FCMP_UNO) &&
1027+
IsZero) {
10261028
// The special case "fcmp ord %a, 0" is the canonical check that LHS isn't
10271029
// NaN, so equivalent to a == a and doesn't need the two comparisons an
10281030
// "ord" normally would.
1031+
// Similarly, "fcmp uno %a, 0" is the canonical check that LHS is NaN and is
1032+
// thus equivalent to a != a.
10291033
RHS = LHS;
10301034
IsZero = false;
1031-
CC = AArch64CC::EQ;
1035+
CC = Pred == CmpInst::Predicate::FCMP_ORD ? AArch64CC::EQ : AArch64CC::NE;
10321036
} else
10331037
changeVectorFCMPPredToAArch64CC(Pred, CC, CC2, Invert);
10341038

llvm/test/Analysis/CostModel/AArch64/reduce-fadd.ll

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ define void @strict_fp_reductions() {
1313
; CHECK-NEXT: Cost Model: Found an estimated cost of 28 for instruction: %fadd_v8f32 = call float @llvm.vector.reduce.fadd.v8f32(float 0.000000e+00, <8 x float> undef)
1414
; CHECK-NEXT: Cost Model: Found an estimated cost of 6 for instruction: %fadd_v2f64 = call double @llvm.vector.reduce.fadd.v2f64(double 0.000000e+00, <2 x double> undef)
1515
; CHECK-NEXT: Cost Model: Found an estimated cost of 12 for instruction: %fadd_v4f64 = call double @llvm.vector.reduce.fadd.v4f64(double 0.000000e+00, <4 x double> undef)
16-
; CHECK-NEXT: Cost Model: Found an estimated cost of 12 for instruction: %fadd_v4f8 = call bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR0000, <4 x bfloat> undef)
16+
; CHECK-NEXT: Cost Model: Found an estimated cost of 18 for instruction: %fadd_v4f8 = call bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR0000, <4 x bfloat> undef)
1717
; CHECK-NEXT: Cost Model: Found an estimated cost of 20 for instruction: %fadd_v4f128 = call fp128 @llvm.vector.reduce.fadd.v4f128(fp128 undef, <4 x fp128> undef)
1818
; CHECK-NEXT: Cost Model: Found an estimated cost of 0 for instruction: ret void
1919
;
@@ -24,7 +24,7 @@ define void @strict_fp_reductions() {
2424
; FP16-NEXT: Cost Model: Found an estimated cost of 28 for instruction: %fadd_v8f32 = call float @llvm.vector.reduce.fadd.v8f32(float 0.000000e+00, <8 x float> undef)
2525
; FP16-NEXT: Cost Model: Found an estimated cost of 6 for instruction: %fadd_v2f64 = call double @llvm.vector.reduce.fadd.v2f64(double 0.000000e+00, <2 x double> undef)
2626
; FP16-NEXT: Cost Model: Found an estimated cost of 12 for instruction: %fadd_v4f64 = call double @llvm.vector.reduce.fadd.v4f64(double 0.000000e+00, <4 x double> undef)
27-
; FP16-NEXT: Cost Model: Found an estimated cost of 12 for instruction: %fadd_v4f8 = call bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR0000, <4 x bfloat> undef)
27+
; FP16-NEXT: Cost Model: Found an estimated cost of 18 for instruction: %fadd_v4f8 = call bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR0000, <4 x bfloat> undef)
2828
; FP16-NEXT: Cost Model: Found an estimated cost of 20 for instruction: %fadd_v4f128 = call fp128 @llvm.vector.reduce.fadd.v4f128(fp128 undef, <4 x fp128> undef)
2929
; FP16-NEXT: Cost Model: Found an estimated cost of 0 for instruction: ret void
3030
;
@@ -72,7 +72,7 @@ define void @fast_fp_reductions() {
7272
; CHECK-NEXT: Cost Model: Found an estimated cost of 5 for instruction: %fadd_v4f64_reassoc = call reassoc double @llvm.vector.reduce.fadd.v4f64(double 0.000000e+00, <4 x double> undef)
7373
; CHECK-NEXT: Cost Model: Found an estimated cost of 9 for instruction: %fadd_v7f64 = call fast double @llvm.vector.reduce.fadd.v7f64(double 0.000000e+00, <7 x double> undef)
7474
; CHECK-NEXT: Cost Model: Found an estimated cost of 15 for instruction: %fadd_v9f64_reassoc = call reassoc double @llvm.vector.reduce.fadd.v9f64(double 0.000000e+00, <9 x double> undef)
75-
; CHECK-NEXT: Cost Model: Found an estimated cost of 6 for instruction: %fadd_v4f8 = call reassoc bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR8000, <4 x bfloat> undef)
75+
; CHECK-NEXT: Cost Model: Found an estimated cost of 10 for instruction: %fadd_v4f8 = call reassoc bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR8000, <4 x bfloat> undef)
7676
; CHECK-NEXT: Cost Model: Found an estimated cost of 12 for instruction: %fadd_v4f128 = call reassoc fp128 @llvm.vector.reduce.fadd.v4f128(fp128 undef, <4 x fp128> undef)
7777
; CHECK-NEXT: Cost Model: Found an estimated cost of 0 for instruction: ret void
7878
;
@@ -95,7 +95,7 @@ define void @fast_fp_reductions() {
9595
; FP16-NEXT: Cost Model: Found an estimated cost of 5 for instruction: %fadd_v4f64_reassoc = call reassoc double @llvm.vector.reduce.fadd.v4f64(double 0.000000e+00, <4 x double> undef)
9696
; FP16-NEXT: Cost Model: Found an estimated cost of 9 for instruction: %fadd_v7f64 = call fast double @llvm.vector.reduce.fadd.v7f64(double 0.000000e+00, <7 x double> undef)
9797
; FP16-NEXT: Cost Model: Found an estimated cost of 15 for instruction: %fadd_v9f64_reassoc = call reassoc double @llvm.vector.reduce.fadd.v9f64(double 0.000000e+00, <9 x double> undef)
98-
; FP16-NEXT: Cost Model: Found an estimated cost of 6 for instruction: %fadd_v4f8 = call reassoc bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR8000, <4 x bfloat> undef)
98+
; FP16-NEXT: Cost Model: Found an estimated cost of 10 for instruction: %fadd_v4f8 = call reassoc bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR8000, <4 x bfloat> undef)
9999
; FP16-NEXT: Cost Model: Found an estimated cost of 12 for instruction: %fadd_v4f128 = call reassoc fp128 @llvm.vector.reduce.fadd.v4f128(fp128 undef, <4 x fp128> undef)
100100
; FP16-NEXT: Cost Model: Found an estimated cost of 0 for instruction: ret void
101101
;

llvm/test/CodeGen/AArch64/GlobalISel/lower-neon-vector-fcmp.mir

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -321,18 +321,15 @@ body: |
321321
bb.0:
322322
liveins: $q0, $q1
323323
324-
; Should be inverted. Needs two compares.
325324
326325
; CHECK-LABEL: name: uno_zero
327326
; CHECK: liveins: $q0, $q1
328327
; CHECK-NEXT: {{ $}}
329328
; CHECK-NEXT: %lhs:_(<2 x s64>) = COPY $q0
330-
; CHECK-NEXT: [[FCMGEZ:%[0-9]+]]:_(<2 x s64>) = G_FCMGEZ %lhs
331-
; CHECK-NEXT: [[FCMLTZ:%[0-9]+]]:_(<2 x s64>) = G_FCMLTZ %lhs
332-
; CHECK-NEXT: [[OR:%[0-9]+]]:_(<2 x s64>) = G_OR [[FCMLTZ]], [[FCMGEZ]]
329+
; CHECK-NEXT: [[FCMEQ:%[0-9]+]]:_(<2 x s64>) = G_FCMEQ %lhs, %lhs(<2 x s64>)
333330
; CHECK-NEXT: [[C:%[0-9]+]]:_(s64) = G_CONSTANT i64 -1
334331
; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<2 x s64>) = G_BUILD_VECTOR [[C]](s64), [[C]](s64)
335-
; CHECK-NEXT: [[XOR:%[0-9]+]]:_(<2 x s64>) = G_XOR [[OR]], [[BUILD_VECTOR]]
332+
; CHECK-NEXT: [[XOR:%[0-9]+]]:_(<2 x s64>) = G_XOR [[FCMEQ]], [[BUILD_VECTOR]]
336333
; CHECK-NEXT: $q0 = COPY [[XOR]](<2 x s64>)
337334
; CHECK-NEXT: RET_ReallyLR implicit $q0
338335
%lhs:_(<2 x s64>) = COPY $q0

llvm/test/CodeGen/AArch64/implicitly-set-zero-high-64-bits.ll

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,10 +187,7 @@ entry:
187187
define <8 x bfloat> @insertzero_v4bf16(<4 x bfloat> %a) {
188188
; CHECK-LABEL: insertzero_v4bf16:
189189
; CHECK: // %bb.0: // %entry
190-
; CHECK-NEXT: movi d4, #0000000000000000
191-
; CHECK-NEXT: movi d5, #0000000000000000
192-
; CHECK-NEXT: movi d6, #0000000000000000
193-
; CHECK-NEXT: movi d7, #0000000000000000
190+
; CHECK-NEXT: fmov d0, d0
194191
; CHECK-NEXT: ret
195192
entry:
196193
%shuffle.i = shufflevector <4 x bfloat> %a, <4 x bfloat> zeroinitializer, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>

0 commit comments

Comments (0)