Commit 3e3780e

[LLVM][CodeGen][SVE] Implement nxvf32 fpround to nxvbf16. (#107420)
1 parent c1826ae commit 3e3780e

4 files changed: +180 −7 lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 47 additions & 3 deletions

@@ -1664,6 +1664,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::BITCAST, VT, Custom);
       setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
       setOperationAction(ISD::FP_EXTEND, VT, Custom);
+      setOperationAction(ISD::FP_ROUND, VT, Custom);
       setOperationAction(ISD::MLOAD, VT, Custom);
       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
       setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
@@ -4334,14 +4335,57 @@ SDValue AArch64TargetLowering::LowerFP_EXTEND(SDValue Op,
 SDValue AArch64TargetLowering::LowerFP_ROUND(SDValue Op,
                                              SelectionDAG &DAG) const {
   EVT VT = Op.getValueType();
-  if (VT.isScalableVector())
-    return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_ROUND_MERGE_PASSTHRU);
-
   bool IsStrict = Op->isStrictFPOpcode();
   SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0);
   EVT SrcVT = SrcVal.getValueType();
   bool Trunc = Op.getConstantOperandVal(IsStrict ? 2 : 1) == 1;
 
+  if (VT.isScalableVector()) {
+    if (VT.getScalarType() != MVT::bf16)
+      return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_ROUND_MERGE_PASSTHRU);
+
+    SDLoc DL(Op);
+    constexpr EVT I32 = MVT::nxv4i32;
+    auto ImmV = [&](int I) -> SDValue { return DAG.getConstant(I, DL, I32); };
+
+    SDValue NaN;
+    SDValue Narrow;
+
+    if (SrcVT == MVT::nxv2f32 || SrcVT == MVT::nxv4f32) {
+      if (Subtarget->hasBF16())
+        return LowerToPredicatedOp(Op, DAG,
+                                   AArch64ISD::FP_ROUND_MERGE_PASSTHRU);
+
+      Narrow = getSVESafeBitCast(I32, SrcVal, DAG);
+
+      // Set the quiet bit.
+      if (!DAG.isKnownNeverSNaN(SrcVal))
+        NaN = DAG.getNode(ISD::OR, DL, I32, Narrow, ImmV(0x400000));
+    } else
+      return SDValue();
+
+    if (!Trunc) {
+      SDValue Lsb = DAG.getNode(ISD::SRL, DL, I32, Narrow, ImmV(16));
+      Lsb = DAG.getNode(ISD::AND, DL, I32, Lsb, ImmV(1));
+      SDValue RoundingBias = DAG.getNode(ISD::ADD, DL, I32, Lsb, ImmV(0x7fff));
+      Narrow = DAG.getNode(ISD::ADD, DL, I32, Narrow, RoundingBias);
+    }
+
+    // Don't round if we had a NaN, we don't want to turn 0x7fffffff into
+    // 0x80000000.
+    if (NaN) {
+      EVT I1 = I32.changeElementType(MVT::i1);
+      EVT CondVT = VT.changeElementType(MVT::i1);
+      SDValue IsNaN = DAG.getSetCC(DL, CondVT, SrcVal, SrcVal, ISD::SETUO);
+      IsNaN = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, I1, IsNaN);
+      Narrow = DAG.getSelect(DL, I32, IsNaN, NaN, Narrow);
+    }
+
+    // Now that we have rounded, shift the bits into position.
+    Narrow = DAG.getNode(ISD::SRL, DL, I32, Narrow, ImmV(16));
+    return getSVESafeBitCast(VT, Narrow, DAG);
+  }
+
   if (useSVEForFixedLengthVectorVT(SrcVT, !Subtarget->isNeonAvailable()))
     return LowerFixedLengthFPRoundToSVE(Op, DAG);
 
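The expansion above is the classic software f32-to-bf16 narrowing: bitcast each lane to i32, add a round-to-nearest-even bias, and keep the high 16 bits, quieting NaNs rather than rounding them. A minimal scalar sketch of the same per-lane arithmetic (illustrative only; the helper name is mine, not from the patch):

#include <cstdint>
#include <cstring>

// Scalar model of the per-lane expansion: narrow an f32 to bf16, i.e. keep
// the top 16 bits of the f32 encoding, rounding to nearest with ties to
// even, and quieting NaNs instead of rounding them.
uint16_t FloatToBF16(float F) {
  uint32_t Bits;
  std::memcpy(&Bits, &F, sizeof(Bits));
  if (F != F)                        // NaN: set the quiet bit (0x400000) and
    return (Bits | 0x400000) >> 16;  // skip the bias, whose carry could
                                     // overflow into the exponent.
  uint32_t Lsb = (Bits >> 16) & 1;   // lowest bit that survives the shift
  Bits += 0x7fff + Lsb;              // carries iff the discarded half is
                                     // > 0x8000, or == 0x8000 with Lsb set
  return Bits >> 16;
}

When the subtarget has +bf16, the block is skipped and the node is lowered to the predicated AArch64ISD::FP_ROUND_MERGE_PASSTHRU operation, which the TableGen changes below pattern-match to BFCVT.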

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 1 addition & 1 deletion

@@ -2425,7 +2425,7 @@ let Predicates = [HasBF16, HasSVEorSME] in {
   defm BFMLALT_ZZZ : sve2_fp_mla_long<0b101, "bfmlalt", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlalt>;
   defm BFMLALB_ZZZI : sve2_fp_mla_long_by_indexed_elem<0b100, "bfmlalb", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlalb_lane_v2>;
   defm BFMLALT_ZZZI : sve2_fp_mla_long_by_indexed_elem<0b101, "bfmlalt", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlalt_lane_v2>;
-  defm BFCVT_ZPmZ : sve_bfloat_convert<0b1, "bfcvt", int_aarch64_sve_fcvt_bf16f32>;
+  defm BFCVT_ZPmZ : sve_bfloat_convert<0b1, "bfcvt", int_aarch64_sve_fcvt_bf16f32, AArch64fcvtr_mt>;
   defm BFCVTNT_ZPmZ : sve_bfloat_convert<0b0, "bfcvtnt", int_aarch64_sve_fcvtnt_bf16f32>;
 } // End HasBF16, HasSVEorSME
 

llvm/lib/Target/AArch64/SVEInstrFormats.td

Lines changed: 5 additions & 1 deletion

@@ -8807,9 +8807,13 @@ class sve_bfloat_convert<bit N, string asm>
   let mayRaiseFPException = 1;
 }
 
-multiclass sve_bfloat_convert<bit N, string asm, SDPatternOperator op> {
+multiclass sve_bfloat_convert<bit N, string asm, SDPatternOperator op,
+                              SDPatternOperator ir_op = null_frag> {
   def NAME : sve_bfloat_convert<N, asm>;
+
   def : SVE_3_Op_Pat<nxv8bf16, op, nxv8bf16, nxv8i1, nxv4f32, !cast<Instruction>(NAME)>;
+  def : SVE_1_Op_Passthru_Round_Pat<nxv4bf16, ir_op, nxv4i1, nxv4f32, !cast<Instruction>(NAME)>;
+  def : SVE_1_Op_Passthru_Round_Pat<nxv2bf16, ir_op, nxv2i1, nxv2f32, !cast<Instruction>(NAME)>;
 }
 
 //===----------------------------------------------------------------------===//

llvm/test/CodeGen/AArch64/sve-bf16-converts.ll

Lines changed: 127 additions & 2 deletions

@@ -1,9 +1,15 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mattr=+sve < %s | FileCheck %s
-; RUN: llc -mattr=+sme -force-streaming < %s | FileCheck %s
+; RUN: llc -mattr=+sve < %s | FileCheck %s --check-prefixes=CHECK,NOBF16
+; RUN: llc -mattr=+sve --enable-no-nans-fp-math < %s | FileCheck %s --check-prefixes=CHECK,NOBF16NNAN
+; RUN: llc -mattr=+sve,+bf16 < %s | FileCheck %s --check-prefixes=CHECK,BF16
+; RUN: llc -mattr=+sme -force-streaming < %s | FileCheck %s --check-prefixes=CHECK,BF16
 
 target triple = "aarch64-unknown-linux-gnu"
 
+; NOTE: "fptrunc <# x double> to <# x bfloat>" is not supported because SVE
+; lacks a down convert that rounds to odd. Such IR will trigger the usual
+; failure (crash) when attempting to unroll a scalable vector.
+
 define <vscale x 2 x float> @fpext_nxv2bf16_to_nxv2f32(<vscale x 2 x bfloat> %a) {
 ; CHECK-LABEL: fpext_nxv2bf16_to_nxv2f32:
 ; CHECK: // %bb.0:
@@ -87,3 +93,122 @@ define <vscale x 8 x double> @fpext_nxv8bf16_to_nxv8f64(<vscale x 8 x bfloat> %a
   %res = fpext <vscale x 8 x bfloat> %a to <vscale x 8 x double>
   ret <vscale x 8 x double> %res
 }
+
+define <vscale x 2 x bfloat> @fptrunc_nxv2f32_to_nxv2bf16(<vscale x 2 x float> %a) {
+; NOBF16-LABEL: fptrunc_nxv2f32_to_nxv2bf16:
+; NOBF16: // %bb.0:
+; NOBF16-NEXT: mov z1.s, #32767 // =0x7fff
+; NOBF16-NEXT: lsr z2.s, z0.s, #16
+; NOBF16-NEXT: ptrue p0.d
+; NOBF16-NEXT: fcmuo p0.s, p0/z, z0.s, z0.s
+; NOBF16-NEXT: and z2.s, z2.s, #0x1
+; NOBF16-NEXT: add z1.s, z0.s, z1.s
+; NOBF16-NEXT: orr z0.s, z0.s, #0x400000
+; NOBF16-NEXT: add z1.s, z2.s, z1.s
+; NOBF16-NEXT: sel z0.s, p0, z0.s, z1.s
+; NOBF16-NEXT: lsr z0.s, z0.s, #16
+; NOBF16-NEXT: ret
+;
+; NOBF16NNAN-LABEL: fptrunc_nxv2f32_to_nxv2bf16:
+; NOBF16NNAN: // %bb.0:
+; NOBF16NNAN-NEXT: mov z1.s, #32767 // =0x7fff
+; NOBF16NNAN-NEXT: lsr z2.s, z0.s, #16
+; NOBF16NNAN-NEXT: and z2.s, z2.s, #0x1
+; NOBF16NNAN-NEXT: add z0.s, z0.s, z1.s
+; NOBF16NNAN-NEXT: add z0.s, z2.s, z0.s
+; NOBF16NNAN-NEXT: lsr z0.s, z0.s, #16
+; NOBF16NNAN-NEXT: ret
+;
+; BF16-LABEL: fptrunc_nxv2f32_to_nxv2bf16:
+; BF16: // %bb.0:
+; BF16-NEXT: ptrue p0.d
+; BF16-NEXT: bfcvt z0.h, p0/m, z0.s
+; BF16-NEXT: ret
+  %res = fptrunc <vscale x 2 x float> %a to <vscale x 2 x bfloat>
+  ret <vscale x 2 x bfloat> %res
+}
+
+define <vscale x 4 x bfloat> @fptrunc_nxv4f32_to_nxv4bf16(<vscale x 4 x float> %a) {
+; NOBF16-LABEL: fptrunc_nxv4f32_to_nxv4bf16:
+; NOBF16: // %bb.0:
+; NOBF16-NEXT: mov z1.s, #32767 // =0x7fff
+; NOBF16-NEXT: lsr z2.s, z0.s, #16
+; NOBF16-NEXT: ptrue p0.s
+; NOBF16-NEXT: fcmuo p0.s, p0/z, z0.s, z0.s
+; NOBF16-NEXT: and z2.s, z2.s, #0x1
+; NOBF16-NEXT: add z1.s, z0.s, z1.s
+; NOBF16-NEXT: orr z0.s, z0.s, #0x400000
+; NOBF16-NEXT: add z1.s, z2.s, z1.s
+; NOBF16-NEXT: sel z0.s, p0, z0.s, z1.s
+; NOBF16-NEXT: lsr z0.s, z0.s, #16
+; NOBF16-NEXT: ret
+;
+; NOBF16NNAN-LABEL: fptrunc_nxv4f32_to_nxv4bf16:
+; NOBF16NNAN: // %bb.0:
+; NOBF16NNAN-NEXT: mov z1.s, #32767 // =0x7fff
+; NOBF16NNAN-NEXT: lsr z2.s, z0.s, #16
+; NOBF16NNAN-NEXT: and z2.s, z2.s, #0x1
+; NOBF16NNAN-NEXT: add z0.s, z0.s, z1.s
+; NOBF16NNAN-NEXT: add z0.s, z2.s, z0.s
+; NOBF16NNAN-NEXT: lsr z0.s, z0.s, #16
+; NOBF16NNAN-NEXT: ret
+;
+; BF16-LABEL: fptrunc_nxv4f32_to_nxv4bf16:
+; BF16: // %bb.0:
+; BF16-NEXT: ptrue p0.s
+; BF16-NEXT: bfcvt z0.h, p0/m, z0.s
+; BF16-NEXT: ret
+  %res = fptrunc <vscale x 4 x float> %a to <vscale x 4 x bfloat>
+  ret <vscale x 4 x bfloat> %res
+}
+
+define <vscale x 8 x bfloat> @fptrunc_nxv8f32_to_nxv8bf16(<vscale x 8 x float> %a) {
+; NOBF16-LABEL: fptrunc_nxv8f32_to_nxv8bf16:
+; NOBF16: // %bb.0:
+; NOBF16-NEXT: mov z2.s, #32767 // =0x7fff
+; NOBF16-NEXT: lsr z3.s, z1.s, #16
+; NOBF16-NEXT: lsr z4.s, z0.s, #16
+; NOBF16-NEXT: ptrue p0.s
+; NOBF16-NEXT: and z3.s, z3.s, #0x1
+; NOBF16-NEXT: and z4.s, z4.s, #0x1
+; NOBF16-NEXT: fcmuo p1.s, p0/z, z1.s, z1.s
+; NOBF16-NEXT: add z5.s, z1.s, z2.s
+; NOBF16-NEXT: add z2.s, z0.s, z2.s
+; NOBF16-NEXT: fcmuo p0.s, p0/z, z0.s, z0.s
+; NOBF16-NEXT: orr z1.s, z1.s, #0x400000
+; NOBF16-NEXT: orr z0.s, z0.s, #0x400000
+; NOBF16-NEXT: add z3.s, z3.s, z5.s
+; NOBF16-NEXT: add z2.s, z4.s, z2.s
+; NOBF16-NEXT: sel z1.s, p1, z1.s, z3.s
+; NOBF16-NEXT: sel z0.s, p0, z0.s, z2.s
+; NOBF16-NEXT: lsr z1.s, z1.s, #16
+; NOBF16-NEXT: lsr z0.s, z0.s, #16
+; NOBF16-NEXT: uzp1 z0.h, z0.h, z1.h
+; NOBF16-NEXT: ret
+;
+; NOBF16NNAN-LABEL: fptrunc_nxv8f32_to_nxv8bf16:
+; NOBF16NNAN: // %bb.0:
+; NOBF16NNAN-NEXT: mov z2.s, #32767 // =0x7fff
+; NOBF16NNAN-NEXT: lsr z3.s, z1.s, #16
+; NOBF16NNAN-NEXT: lsr z4.s, z0.s, #16
+; NOBF16NNAN-NEXT: and z3.s, z3.s, #0x1
+; NOBF16NNAN-NEXT: and z4.s, z4.s, #0x1
+; NOBF16NNAN-NEXT: add z1.s, z1.s, z2.s
+; NOBF16NNAN-NEXT: add z0.s, z0.s, z2.s
+; NOBF16NNAN-NEXT: add z1.s, z3.s, z1.s
+; NOBF16NNAN-NEXT: add z0.s, z4.s, z0.s
+; NOBF16NNAN-NEXT: lsr z1.s, z1.s, #16
+; NOBF16NNAN-NEXT: lsr z0.s, z0.s, #16
+; NOBF16NNAN-NEXT: uzp1 z0.h, z0.h, z1.h
+; NOBF16NNAN-NEXT: ret
+;
+; BF16-LABEL: fptrunc_nxv8f32_to_nxv8bf16:
+; BF16: // %bb.0:
+; BF16-NEXT: ptrue p0.s
+; BF16-NEXT: bfcvt z1.h, p0/m, z1.s
+; BF16-NEXT: bfcvt z0.h, p0/m, z0.s
+; BF16-NEXT: uzp1 z0.h, z0.h, z1.h
+; BF16-NEXT: ret
+  %res = fptrunc <vscale x 8 x float> %a to <vscale x 8 x bfloat>
+  ret <vscale x 8 x bfloat> %res
+}
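A word on the round-to-odd NOTE at the top of the test file: narrowing f64 to bf16 through f32 with round-to-nearest at both steps can double-round, because the first step may land exactly on a bf16 tie point and the second step then resolves the tie the wrong way. An intermediate narrowing that rounds to odd would prevent this, but SVE has no such down convert. A small standalone demonstration of the hazard (my own example, not from the patch):

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Model of the f32 -> bf16 rounding used by the lowering: add the
// round-to-nearest-even bias, then keep the high 16 bits (returned here as
// the f32 value those bits encode). NaN handling omitted for brevity.
static float RoundF32ToBF16(float F) {
  uint32_t Bits;
  std::memcpy(&Bits, &F, sizeof(Bits));
  Bits += 0x7fff + ((Bits >> 16) & 1);
  Bits &= 0xffff0000u;
  std::memcpy(&F, &Bits, sizeof(F));
  return F;
}

int main() {
  // X = 1 + 2^-8 + 2^-25 lies strictly above the bf16 tie point 1 + 2^-8,
  // so a correct single-step rounding gives 1 + 2^-7 = 1.0078125.
  double X = 1.0 + std::ldexp(1.0, -8) + std::ldexp(1.0, -25);

  // Two-step narrowing: f64 -> f32 rounds X to exactly 1 + 2^-8 (the tie),
  // and the f32 -> bf16 step then resolves the tie to even, giving 1.0.
  float TwoStep = RoundF32ToBF16(static_cast<float>(X));

  std::printf("two-step: %g, correct single-step: %g\n", TwoStep, 1.0078125);
  return 0;
}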
