Skip to content

Commit be1958f

Browse files
[LLVM][CodeGen][SVE] Implement nxvbf16 fpextend to nxvf32/nxvf64. (#107253)
NOTE: There are no dedicated SVE instructions, but bf16->f32 is just a left shift because the two types share the same exponent range; from there, other convert instructions can be used.
1 parent c2018fa commit be1958f

File tree

3 files changed

+117
-2
lines changed

3 files changed

+117
-2
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1663,6 +1663,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
16631663
for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) {
16641664
setOperationAction(ISD::BITCAST, VT, Custom);
16651665
setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
1666+
setOperationAction(ISD::FP_EXTEND, VT, Custom);
16661667
setOperationAction(ISD::MLOAD, VT, Custom);
16671668
setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
16681669
setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
@@ -4298,8 +4299,28 @@ static SDValue LowerPREFETCH(SDValue Op, SelectionDAG &DAG) {
42984299
SDValue AArch64TargetLowering::LowerFP_EXTEND(SDValue Op,
42994300
SelectionDAG &DAG) const {
43004301
EVT VT = Op.getValueType();
4301-
if (VT.isScalableVector())
4302+
if (VT.isScalableVector()) {
4303+
SDValue SrcVal = Op.getOperand(0);
4304+
4305+
if (SrcVal.getValueType().getScalarType() == MVT::bf16) {
4306+
// bf16 and f32 share the same exponent range so the conversion requires
4307+
// them to be aligned with the new mantissa bits zero'd. This is just a
4308+
// left shift that is best to isel directly.
4309+
if (VT == MVT::nxv2f32 || VT == MVT::nxv4f32)
4310+
return Op;
4311+
4312+
if (VT != MVT::nxv2f64)
4313+
return SDValue();
4314+
4315+
// Break other conversions in two with the first part converting to f32
4316+
// and the second using native f32->VT instructions.
4317+
SDLoc DL(Op);
4318+
return DAG.getNode(ISD::FP_EXTEND, DL, VT,
4319+
DAG.getNode(ISD::FP_EXTEND, DL, MVT::nxv2f32, SrcVal));
4320+
}
4321+
43024322
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_EXTEND_MERGE_PASSTHRU);
4323+
}
43034324

43044325
if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()))
43054326
return LowerFixedLengthFPExtendToSVE(Op, DAG);

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2320,7 +2320,12 @@ let Predicates = [HasSVEorSME] in {
23202320
def : Pat<(nxv2f16 (AArch64fcvtr_mt (nxv2i1 (SVEAllActive:$Pg)), nxv2f32:$Zs, (i64 timm0_1), nxv2f16:$Zd)),
23212321
(FCVT_ZPmZ_StoH_UNDEF ZPR:$Zd, PPR:$Pg, ZPR:$Zs)>;
23222322

2323-
// Signed integer -> Floating-point
2323+
def : Pat<(nxv4f32 (fpextend nxv4bf16:$op)),
2324+
(LSL_ZZI_S $op, (i32 16))>;
2325+
def : Pat<(nxv2f32 (fpextend nxv2bf16:$op)),
2326+
(LSL_ZZI_S $op, (i32 16))>;
2327+
2328+
// Signed integer -> Floating-point
23242329
def : Pat<(nxv2f16 (AArch64scvtf_mt (nxv2i1 (SVEAllActive):$Pg),
23252330
(sext_inreg nxv2i64:$Zs, nxv2i16), nxv2f16:$Zd)),
23262331
(SCVTF_ZPmZ_HtoH_UNDEF ZPR:$Zd, PPR:$Pg, ZPR:$Zs)>;
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc -mattr=+sve < %s | FileCheck %s
3+
; RUN: llc -mattr=+sme -force-streaming < %s | FileCheck %s
4+
5+
target triple = "aarch64-unknown-linux-gnu"
6+
7+
; nxv2bf16 -> nxv2f32: bf16 and f32 share the same exponent layout, so the
; extend is a single LSL #16 that moves each bf16 payload into the top half of
; its 32-bit lane, zero-filling the new low mantissa bits.
define <vscale x 2 x float> @fpext_nxv2bf16_to_nxv2f32(<vscale x 2 x bfloat> %a) {
8+
; CHECK-LABEL: fpext_nxv2bf16_to_nxv2f32:
9+
; CHECK: // %bb.0:
10+
; CHECK-NEXT: lsl z0.s, z0.s, #16
11+
; CHECK-NEXT: ret
12+
%res = fpext <vscale x 2 x bfloat> %a to <vscale x 2 x float>
13+
ret <vscale x 2 x float> %res
14+
}
15+
16+
; nxv4bf16 -> nxv4f32: fully packed .s-lane form; still lowers to a single
; LSL #16 with no conversion instruction needed.
define <vscale x 4 x float> @fpext_nxv4bf16_to_nxv4f32(<vscale x 4 x bfloat> %a) {
17+
; CHECK-LABEL: fpext_nxv4bf16_to_nxv4f32:
18+
; CHECK: // %bb.0:
19+
; CHECK-NEXT: lsl z0.s, z0.s, #16
20+
; CHECK-NEXT: ret
21+
%res = fpext <vscale x 4 x bfloat> %a to <vscale x 4 x float>
22+
ret <vscale x 4 x float> %res
23+
}
24+
25+
; nxv8bf16 -> nxv8f32: the result needs two Z registers, so the .h input is
; unpacked into lo/hi .s halves and each half is shifted left by 16.
define <vscale x 8 x float> @fpext_nxv8bf16_to_nxv8f32(<vscale x 8 x bfloat> %a) {
26+
; CHECK-LABEL: fpext_nxv8bf16_to_nxv8f32:
27+
; CHECK: // %bb.0:
28+
; CHECK-NEXT: uunpklo z1.s, z0.h
29+
; CHECK-NEXT: uunpkhi z2.s, z0.h
30+
; CHECK-NEXT: lsl z0.s, z1.s, #16
31+
; CHECK-NEXT: lsl z1.s, z2.s, #16
32+
; CHECK-NEXT: ret
33+
%res = fpext <vscale x 8 x bfloat> %a to <vscale x 8 x float>
34+
ret <vscale x 8 x float> %res
35+
}
36+
37+
; nxv2bf16 -> nxv2f64: split in two steps — LSL #16 produces the f32 value,
; then a predicated FCVT widens f32 -> f64.
define <vscale x 2 x double> @fpext_nxv2bf16_to_nxv2f64(<vscale x 2 x bfloat> %a) {
38+
; CHECK-LABEL: fpext_nxv2bf16_to_nxv2f64:
39+
; CHECK: // %bb.0:
40+
; CHECK-NEXT: lsl z0.s, z0.s, #16
41+
; CHECK-NEXT: ptrue p0.d
42+
; CHECK-NEXT: fcvt z0.d, p0/m, z0.s
43+
; CHECK-NEXT: ret
44+
%res = fpext <vscale x 2 x bfloat> %a to <vscale x 2 x double>
45+
ret <vscale x 2 x double> %res
46+
}
47+
48+
; nxv4bf16 -> nxv4f64: input is unpacked into two vectors, each shifted left
; by 16 to form f32, then widened to f64 with a predicated FCVT.
define <vscale x 4 x double> @fpext_nxv4bf16_to_nxv4f64(<vscale x 4 x bfloat> %a) {
49+
; CHECK-LABEL: fpext_nxv4bf16_to_nxv4f64:
50+
; CHECK: // %bb.0:
51+
; CHECK-NEXT: uunpklo z1.d, z0.s
52+
; CHECK-NEXT: uunpkhi z0.d, z0.s
53+
; CHECK-NEXT: ptrue p0.d
54+
; CHECK-NEXT: lsl z1.s, z1.s, #16
55+
; CHECK-NEXT: lsl z2.s, z0.s, #16
56+
; CHECK-NEXT: movprfx z0, z1
57+
; CHECK-NEXT: fcvt z0.d, p0/m, z1.s
58+
; CHECK-NEXT: movprfx z1, z2
59+
; CHECK-NEXT: fcvt z1.d, p0/m, z2.s
60+
; CHECK-NEXT: ret
61+
%res = fpext <vscale x 4 x bfloat> %a to <vscale x 4 x double>
62+
ret <vscale x 4 x double> %res
63+
}
64+
65+
; nxv8bf16 -> nxv8f64: two levels of unpacking (.h -> .s -> .d) split the
; input into four vectors; each is shifted left by 16 to form f32 and then
; FCVT-widened to f64.
define <vscale x 8 x double> @fpext_nxv8bf16_to_nxv8f64(<vscale x 8 x bfloat> %a) {
66+
; CHECK-LABEL: fpext_nxv8bf16_to_nxv8f64:
67+
; CHECK: // %bb.0:
68+
; CHECK-NEXT: uunpklo z1.s, z0.h
69+
; CHECK-NEXT: uunpkhi z0.s, z0.h
70+
; CHECK-NEXT: ptrue p0.d
71+
; CHECK-NEXT: uunpklo z2.d, z1.s
72+
; CHECK-NEXT: uunpkhi z1.d, z1.s
73+
; CHECK-NEXT: uunpklo z3.d, z0.s
74+
; CHECK-NEXT: uunpkhi z0.d, z0.s
75+
; CHECK-NEXT: lsl z1.s, z1.s, #16
76+
; CHECK-NEXT: lsl z2.s, z2.s, #16
77+
; CHECK-NEXT: lsl z3.s, z3.s, #16
78+
; CHECK-NEXT: lsl z4.s, z0.s, #16
79+
; CHECK-NEXT: fcvt z1.d, p0/m, z1.s
80+
; CHECK-NEXT: movprfx z0, z2
81+
; CHECK-NEXT: fcvt z0.d, p0/m, z2.s
82+
; CHECK-NEXT: movprfx z2, z3
83+
; CHECK-NEXT: fcvt z2.d, p0/m, z3.s
84+
; CHECK-NEXT: movprfx z3, z4
85+
; CHECK-NEXT: fcvt z3.d, p0/m, z4.s
86+
; CHECK-NEXT: ret
87+
%res = fpext <vscale x 8 x bfloat> %a to <vscale x 8 x double>
88+
ret <vscale x 8 x double> %res
89+
}

0 commit comments

Comments
 (0)