Skip to content

Commit 0ed5d9a

Browse files
authored
[LoongArch][BF16] Add support for the __bf16 type (#142548)
The LoongArch psABI recently added support for the __bf16 type, so we can now enable this type in Clang. Currently, bf16 operations are supported automatically by promoting to float. This patch adds bf16 support by ensuring that load-extension / truncate-store operations are properly expanded, and it implements bf16 truncate/extend on hard-FP targets. The extend operation is implemented with a shift, just as in the standard legalization. This requires custom lowering of the truncate libcall on hard-float ABIs (the normal libcall code path is used on soft ABIs).
1 parent 90beda2 commit 0ed5d9a

File tree

8 files changed

+1822
-4
lines changed

8 files changed

+1822
-4
lines changed

clang/docs/LanguageExtensions.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,6 +1009,7 @@ to ``float``; see below for more information on this emulation.
10091009
* 64-bit ARM (AArch64)
10101010
* RISC-V
10111011
* X86 (when SSE2 is available)
1012+
* LoongArch
10121013

10131014
(For X86, SSE2 is available on 64-bit and all recent 32-bit processors.)
10141015

clang/lib/Basic/Targets/LoongArch.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ class LLVM_LIBRARY_VISIBILITY LoongArchTargetInfo : public TargetInfo {
4949
HasFeatureLD_SEQ_SA = false;
5050
HasFeatureDiv32 = false;
5151
HasFeatureSCQ = false;
52+
BFloat16Width = 16;
53+
BFloat16Align = 16;
54+
BFloat16Format = &llvm::APFloat::BFloat();
5255
LongDoubleWidth = 128;
5356
LongDoubleAlign = 128;
5457
LongDoubleFormat = &llvm::APFloat::IEEEquad();
@@ -99,6 +102,8 @@ class LLVM_LIBRARY_VISIBILITY LoongArchTargetInfo : public TargetInfo {
99102

100103
bool hasBitIntType() const override { return true; }
101104

105+
bool hasBFloat16Type() const override { return true; }
106+
102107
bool useFP16ConversionIntrinsics() const override { return false; }
103108

104109
bool handleTargetFeatures(std::vector<std::string> &Features,

clang/test/CodeGen/LoongArch/bfloat-abi.c

Lines changed: 532 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 2
2+
// RUN: %clang_cc1 -triple loongarch64 -emit-llvm -o - %s | FileCheck %s
3+
// RUN: %clang_cc1 -triple loongarch32 -emit-llvm -o - %s | FileCheck %s
4+
5+
// CHECK-LABEL: define dso_local void @_Z3fooDF16b
6+
// CHECK-SAME: (bfloat noundef [[B:%.*]]) #[[ATTR0:[0-9]+]] {
7+
// CHECK-NEXT: entry:
8+
// CHECK-NEXT: [[B_ADDR:%.*]] = alloca bfloat, align 2
9+
// CHECK-NEXT: store bfloat [[B]], ptr [[B_ADDR]], align 2
10+
// CHECK-NEXT: ret void
11+
//
12+
void foo(__bf16 b) {}

llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,8 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
182182
if (Subtarget.hasBasicF()) {
183183
setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
184184
setTruncStoreAction(MVT::f32, MVT::f16, Expand);
185+
setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand);
186+
setTruncStoreAction(MVT::f32, MVT::bf16, Expand);
185187
setCondCodeAction(FPCCToExpand, MVT::f32, Expand);
186188

187189
setOperationAction(ISD::SELECT_CC, MVT::f32, Expand);
@@ -203,6 +205,9 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
203205
Subtarget.isSoftFPABI() ? LibCall : Custom);
204206
setOperationAction(ISD::FP_TO_FP16, MVT::f32,
205207
Subtarget.isSoftFPABI() ? LibCall : Custom);
208+
setOperationAction(ISD::BF16_TO_FP, MVT::f32, Custom);
209+
setOperationAction(ISD::FP_TO_BF16, MVT::f32,
210+
Subtarget.isSoftFPABI() ? LibCall : Custom);
206211

207212
if (Subtarget.is64Bit())
208213
setOperationAction(ISD::FRINT, MVT::f32, Legal);
@@ -221,6 +226,8 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
221226
if (Subtarget.hasBasicD()) {
222227
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand);
223228
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);
229+
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand);
230+
setTruncStoreAction(MVT::f64, MVT::bf16, Expand);
224231
setTruncStoreAction(MVT::f64, MVT::f16, Expand);
225232
setTruncStoreAction(MVT::f64, MVT::f32, Expand);
226233
setCondCodeAction(FPCCToExpand, MVT::f64, Expand);
@@ -243,6 +250,9 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
243250
setOperationAction(ISD::FP16_TO_FP, MVT::f64, Expand);
244251
setOperationAction(ISD::FP_TO_FP16, MVT::f64,
245252
Subtarget.isSoftFPABI() ? LibCall : Custom);
253+
setOperationAction(ISD::BF16_TO_FP, MVT::f64, Custom);
254+
setOperationAction(ISD::FP_TO_BF16, MVT::f64,
255+
Subtarget.isSoftFPABI() ? LibCall : Custom);
246256

247257
if (Subtarget.is64Bit())
248258
setOperationAction(ISD::FRINT, MVT::f64, Legal);
@@ -499,6 +509,10 @@ SDValue LoongArchTargetLowering::LowerOperation(SDValue Op,
499509
return lowerFP_TO_FP16(Op, DAG);
500510
case ISD::FP16_TO_FP:
501511
return lowerFP16_TO_FP(Op, DAG);
512+
case ISD::FP_TO_BF16:
513+
return lowerFP_TO_BF16(Op, DAG);
514+
case ISD::BF16_TO_FP:
515+
return lowerBF16_TO_FP(Op, DAG);
502516
}
503517
return SDValue();
504518
}
@@ -2333,6 +2347,36 @@ SDValue LoongArchTargetLowering::lowerFP16_TO_FP(SDValue Op,
23332347
return Res;
23342348
}
23352349

2350+
SDValue LoongArchTargetLowering::lowerFP_TO_BF16(SDValue Op,
2351+
SelectionDAG &DAG) const {
2352+
assert(Subtarget.hasBasicF() && "Unexpected custom legalization");
2353+
SDLoc DL(Op);
2354+
MakeLibCallOptions CallOptions;
2355+
RTLIB::Libcall LC =
2356+
RTLIB::getFPROUND(Op.getOperand(0).getValueType(), MVT::bf16);
2357+
SDValue Res =
2358+
makeLibCall(DAG, LC, MVT::f32, Op.getOperand(0), CallOptions, DL).first;
2359+
if (Subtarget.is64Bit())
2360+
return DAG.getNode(LoongArchISD::MOVFR2GR_S_LA64, DL, MVT::i64, Res);
2361+
return DAG.getBitcast(MVT::i32, Res);
2362+
}
2363+
2364+
SDValue LoongArchTargetLowering::lowerBF16_TO_FP(SDValue Op,
2365+
SelectionDAG &DAG) const {
2366+
assert(Subtarget.hasBasicF() && "Unexpected custom legalization");
2367+
MVT VT = Op.getSimpleValueType();
2368+
SDLoc DL(Op);
2369+
Op = DAG.getNode(
2370+
ISD::SHL, DL, Op.getOperand(0).getValueType(), Op.getOperand(0),
2371+
DAG.getShiftAmountConstant(16, Op.getOperand(0).getValueType(), DL));
2372+
SDValue Res = Subtarget.is64Bit() ? DAG.getNode(LoongArchISD::MOVGR2FR_W_LA64,
2373+
DL, MVT::f32, Op)
2374+
: DAG.getBitcast(MVT::f32, Op);
2375+
if (VT != MVT::f32)
2376+
return DAG.getNode(ISD::FP_EXTEND, DL, VT, Res);
2377+
return Res;
2378+
}
2379+
23362380
static bool isConstantOrUndef(const SDValue Op) {
23372381
if (Op->isUndef())
23382382
return true;
@@ -7993,8 +8037,9 @@ bool LoongArchTargetLowering::splitValueIntoRegisterParts(
79938037
bool IsABIRegCopy = CC.has_value();
79948038
EVT ValueVT = Val.getValueType();
79958039

7996-
if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32) {
7997-
// Cast the f16 to i16, extend to i32, pad with ones to make a float
8040+
if (IsABIRegCopy && (ValueVT == MVT::f16 || ValueVT == MVT::bf16) &&
8041+
PartVT == MVT::f32) {
8042+
// Cast the [b]f16 to i16, extend to i32, pad with ones to make a float
79988043
// nan, and cast to f32.
79998044
Val = DAG.getNode(ISD::BITCAST, DL, MVT::i16, Val);
80008045
Val = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Val);
@@ -8013,10 +8058,11 @@ SDValue LoongArchTargetLowering::joinRegisterPartsIntoValue(
80138058
MVT PartVT, EVT ValueVT, std::optional<CallingConv::ID> CC) const {
80148059
bool IsABIRegCopy = CC.has_value();
80158060

8016-
if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32) {
8061+
if (IsABIRegCopy && (ValueVT == MVT::f16 || ValueVT == MVT::bf16) &&
8062+
PartVT == MVT::f32) {
80178063
SDValue Val = Parts[0];
80188064

8019-
// Cast the f32 to i32, truncate to i16, and cast back to f16.
8065+
// Cast the f32 to i32, truncate to i16, and cast back to [b]f16.
80208066
Val = DAG.getNode(ISD::BITCAST, DL, MVT::i32, Val);
80218067
Val = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, Val);
80228068
Val = DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);

llvm/lib/Target/LoongArch/LoongArchISelLowering.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,8 @@ class LoongArchTargetLowering : public TargetLowering {
373373
SDValue lowerSELECT(SDValue Op, SelectionDAG &DAG) const;
374374
SDValue lowerFP_TO_FP16(SDValue Op, SelectionDAG &DAG) const;
375375
SDValue lowerFP16_TO_FP(SDValue Op, SelectionDAG &DAG) const;
376+
SDValue lowerFP_TO_BF16(SDValue Op, SelectionDAG &DAG) const;
377+
SDValue lowerBF16_TO_FP(SDValue Op, SelectionDAG &DAG) const;
376378

377379
bool isFPImmLegal(const APFloat &Imm, EVT VT,
378380
bool ForCodeSize) const override;
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
2+
; RUN: llc -mtriple=loongarch64 -mattr=+d -target-abi=lp64d < %s | FileCheck --check-prefixes=CHECK,LA64 %s
3+
; RUN: llc -mtriple=loongarch32 -mattr=+d -target-abi=ilp32d < %s | FileCheck --check-prefixes=CHECK,LA32 %s
4+
5+
define void @test_load_store(ptr %p, ptr %q) nounwind {
6+
; CHECK-LABEL: test_load_store:
7+
; CHECK: # %bb.0:
8+
; CHECK-NEXT: ld.h $a0, $a0, 0
9+
; CHECK-NEXT: st.h $a0, $a1, 0
10+
; CHECK-NEXT: ret
11+
%a = load bfloat, ptr %p
12+
store bfloat %a, ptr %q
13+
ret void
14+
}
15+
16+
define float @test_fpextend_float(ptr %p) nounwind {
17+
; LA64-LABEL: test_fpextend_float:
18+
; LA64: # %bb.0:
19+
; LA64-NEXT: ld.hu $a0, $a0, 0
20+
; LA64-NEXT: slli.d $a0, $a0, 16
21+
; LA64-NEXT: movgr2fr.w $fa0, $a0
22+
; LA64-NEXT: ret
23+
;
24+
; LA32-LABEL: test_fpextend_float:
25+
; LA32: # %bb.0:
26+
; LA32-NEXT: ld.hu $a0, $a0, 0
27+
; LA32-NEXT: slli.w $a0, $a0, 16
28+
; LA32-NEXT: movgr2fr.w $fa0, $a0
29+
; LA32-NEXT: ret
30+
%a = load bfloat, ptr %p
31+
%r = fpext bfloat %a to float
32+
ret float %r
33+
}
34+
35+
define double @test_fpextend_double(ptr %p) nounwind {
36+
; LA64-LABEL: test_fpextend_double:
37+
; LA64: # %bb.0:
38+
; LA64-NEXT: ld.hu $a0, $a0, 0
39+
; LA64-NEXT: slli.d $a0, $a0, 16
40+
; LA64-NEXT: movgr2fr.w $fa0, $a0
41+
; LA64-NEXT: fcvt.d.s $fa0, $fa0
42+
; LA64-NEXT: ret
43+
;
44+
; LA32-LABEL: test_fpextend_double:
45+
; LA32: # %bb.0:
46+
; LA32-NEXT: ld.hu $a0, $a0, 0
47+
; LA32-NEXT: slli.w $a0, $a0, 16
48+
; LA32-NEXT: movgr2fr.w $fa0, $a0
49+
; LA32-NEXT: fcvt.d.s $fa0, $fa0
50+
; LA32-NEXT: ret
51+
%a = load bfloat, ptr %p
52+
%r = fpext bfloat %a to double
53+
ret double %r
54+
}
55+
56+
define void @test_fptrunc_float(float %f, ptr %p) nounwind {
57+
; LA64-LABEL: test_fptrunc_float:
58+
; LA64: # %bb.0:
59+
; LA64-NEXT: addi.d $sp, $sp, -16
60+
; LA64-NEXT: st.d $ra, $sp, 8 # 8-byte Folded Spill
61+
; LA64-NEXT: st.d $fp, $sp, 0 # 8-byte Folded Spill
62+
; LA64-NEXT: move $fp, $a0
63+
; LA64-NEXT: pcaddu18i $ra, %call36(__truncsfbf2)
64+
; LA64-NEXT: jirl $ra, $ra, 0
65+
; LA64-NEXT: movfr2gr.s $a0, $fa0
66+
; LA64-NEXT: st.h $a0, $fp, 0
67+
; LA64-NEXT: ld.d $fp, $sp, 0 # 8-byte Folded Reload
68+
; LA64-NEXT: ld.d $ra, $sp, 8 # 8-byte Folded Reload
69+
; LA64-NEXT: addi.d $sp, $sp, 16
70+
; LA64-NEXT: ret
71+
;
72+
; LA32-LABEL: test_fptrunc_float:
73+
; LA32: # %bb.0:
74+
; LA32-NEXT: addi.w $sp, $sp, -16
75+
; LA32-NEXT: st.w $ra, $sp, 12 # 4-byte Folded Spill
76+
; LA32-NEXT: st.w $fp, $sp, 8 # 4-byte Folded Spill
77+
; LA32-NEXT: move $fp, $a0
78+
; LA32-NEXT: bl __truncsfbf2
79+
; LA32-NEXT: movfr2gr.s $a0, $fa0
80+
; LA32-NEXT: st.h $a0, $fp, 0
81+
; LA32-NEXT: ld.w $fp, $sp, 8 # 4-byte Folded Reload
82+
; LA32-NEXT: ld.w $ra, $sp, 12 # 4-byte Folded Reload
83+
; LA32-NEXT: addi.w $sp, $sp, 16
84+
; LA32-NEXT: ret
85+
%a = fptrunc float %f to bfloat
86+
store bfloat %a, ptr %p
87+
ret void
88+
}
89+
90+
define void @test_fptrunc_double(double %d, ptr %p) nounwind {
91+
; LA64-LABEL: test_fptrunc_double:
92+
; LA64: # %bb.0:
93+
; LA64-NEXT: addi.d $sp, $sp, -16
94+
; LA64-NEXT: st.d $ra, $sp, 8 # 8-byte Folded Spill
95+
; LA64-NEXT: st.d $fp, $sp, 0 # 8-byte Folded Spill
96+
; LA64-NEXT: move $fp, $a0
97+
; LA64-NEXT: pcaddu18i $ra, %call36(__truncdfbf2)
98+
; LA64-NEXT: jirl $ra, $ra, 0
99+
; LA64-NEXT: movfr2gr.s $a0, $fa0
100+
; LA64-NEXT: st.h $a0, $fp, 0
101+
; LA64-NEXT: ld.d $fp, $sp, 0 # 8-byte Folded Reload
102+
; LA64-NEXT: ld.d $ra, $sp, 8 # 8-byte Folded Reload
103+
; LA64-NEXT: addi.d $sp, $sp, 16
104+
; LA64-NEXT: ret
105+
;
106+
; LA32-LABEL: test_fptrunc_double:
107+
; LA32: # %bb.0:
108+
; LA32-NEXT: addi.w $sp, $sp, -16
109+
; LA32-NEXT: st.w $ra, $sp, 12 # 4-byte Folded Spill
110+
; LA32-NEXT: st.w $fp, $sp, 8 # 4-byte Folded Spill
111+
; LA32-NEXT: move $fp, $a0
112+
; LA32-NEXT: bl __truncdfbf2
113+
; LA32-NEXT: movfr2gr.s $a0, $fa0
114+
; LA32-NEXT: st.h $a0, $fp, 0
115+
; LA32-NEXT: ld.w $fp, $sp, 8 # 4-byte Folded Reload
116+
; LA32-NEXT: ld.w $ra, $sp, 12 # 4-byte Folded Reload
117+
; LA32-NEXT: addi.w $sp, $sp, 16
118+
; LA32-NEXT: ret
119+
%a = fptrunc double %d to bfloat
120+
store bfloat %a, ptr %p
121+
ret void
122+
}
123+
124+
define void @test_fadd(ptr %p, ptr %q) nounwind {
125+
; LA64-LABEL: test_fadd:
126+
; LA64: # %bb.0:
127+
; LA64-NEXT: addi.d $sp, $sp, -16
128+
; LA64-NEXT: st.d $ra, $sp, 8 # 8-byte Folded Spill
129+
; LA64-NEXT: st.d $fp, $sp, 0 # 8-byte Folded Spill
130+
; LA64-NEXT: ld.hu $a1, $a1, 0
131+
; LA64-NEXT: move $fp, $a0
132+
; LA64-NEXT: ld.hu $a0, $a0, 0
133+
; LA64-NEXT: slli.d $a1, $a1, 16
134+
; LA64-NEXT: movgr2fr.w $fa0, $a1
135+
; LA64-NEXT: slli.d $a0, $a0, 16
136+
; LA64-NEXT: movgr2fr.w $fa1, $a0
137+
; LA64-NEXT: fadd.s $fa0, $fa1, $fa0
138+
; LA64-NEXT: pcaddu18i $ra, %call36(__truncsfbf2)
139+
; LA64-NEXT: jirl $ra, $ra, 0
140+
; LA64-NEXT: movfr2gr.s $a0, $fa0
141+
; LA64-NEXT: st.h $a0, $fp, 0
142+
; LA64-NEXT: ld.d $fp, $sp, 0 # 8-byte Folded Reload
143+
; LA64-NEXT: ld.d $ra, $sp, 8 # 8-byte Folded Reload
144+
; LA64-NEXT: addi.d $sp, $sp, 16
145+
; LA64-NEXT: ret
146+
;
147+
; LA32-LABEL: test_fadd:
148+
; LA32: # %bb.0:
149+
; LA32-NEXT: addi.w $sp, $sp, -16
150+
; LA32-NEXT: st.w $ra, $sp, 12 # 4-byte Folded Spill
151+
; LA32-NEXT: st.w $fp, $sp, 8 # 4-byte Folded Spill
152+
; LA32-NEXT: ld.hu $a1, $a1, 0
153+
; LA32-NEXT: move $fp, $a0
154+
; LA32-NEXT: ld.hu $a0, $a0, 0
155+
; LA32-NEXT: slli.w $a1, $a1, 16
156+
; LA32-NEXT: movgr2fr.w $fa0, $a1
157+
; LA32-NEXT: slli.w $a0, $a0, 16
158+
; LA32-NEXT: movgr2fr.w $fa1, $a0
159+
; LA32-NEXT: fadd.s $fa0, $fa1, $fa0
160+
; LA32-NEXT: bl __truncsfbf2
161+
; LA32-NEXT: movfr2gr.s $a0, $fa0
162+
; LA32-NEXT: st.h $a0, $fp, 0
163+
; LA32-NEXT: ld.w $fp, $sp, 8 # 4-byte Folded Reload
164+
; LA32-NEXT: ld.w $ra, $sp, 12 # 4-byte Folded Reload
165+
; LA32-NEXT: addi.w $sp, $sp, 16
166+
; LA32-NEXT: ret
167+
%a = load bfloat, ptr %p
168+
%b = load bfloat, ptr %q
169+
%r = fadd bfloat %a, %b
170+
store bfloat %r, ptr %p
171+
ret void
172+
}

0 commit comments

Comments
 (0)