Skip to content

Commit 0176ac9

Browse files
committed
[AArch64] Optimize SVE bitcasts of unpacked types.
Target-independent code only knows how to spill to the stack; instead, use AArch64ISD::REINTERPRET_CAST.

Differential Revision: https://reviews.llvm.org/D104573
1 parent 430bfc4 commit 0176ac9

File tree

3 files changed

+63
-5
lines changed

3 files changed

+63
-5
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1192,6 +1192,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
11921192
setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
11931193
}
11941194

1195+
// Legalize unpacked bitcasts to REINTERPRET_CAST.
1196+
for (auto VT : {MVT::nxv2i32, MVT::nxv2f32})
1197+
setOperationAction(ISD::BITCAST, VT, Custom);
1198+
11951199
for (auto VT : {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1}) {
11961200
setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
11971201
setOperationAction(ISD::SELECT, VT, Custom);
@@ -3508,17 +3512,30 @@ SDValue AArch64TargetLowering::LowerFSINCOS(SDValue Op,
35083512
return CallResult.first;
35093513
}
35103514

3515+
static MVT getSVEContainerType(EVT ContentTy);
3516+
35113517
SDValue AArch64TargetLowering::LowerBITCAST(SDValue Op,
35123518
SelectionDAG &DAG) const {
35133519
EVT OpVT = Op.getValueType();
3520+
EVT ArgVT = Op.getOperand(0).getValueType();
35143521

35153522
if (useSVEForFixedLengthVectorVT(OpVT))
35163523
return LowerFixedLengthBitcastToSVE(Op, DAG);
35173524

3525+
if (OpVT == MVT::nxv2f32) {
3526+
if (ArgVT.isInteger()) {
3527+
SDValue ExtResult =
3528+
DAG.getNode(ISD::ANY_EXTEND, SDLoc(Op), getSVEContainerType(ArgVT),
3529+
Op.getOperand(0));
3530+
return getSVESafeBitCast(MVT::nxv2f32, ExtResult, DAG);
3531+
}
3532+
return getSVESafeBitCast(MVT::nxv2f32, Op.getOperand(0), DAG);
3533+
}
3534+
35183535
if (OpVT != MVT::f16 && OpVT != MVT::bf16)
35193536
return SDValue();
35203537

3521-
assert(Op.getOperand(0).getValueType() == MVT::i16);
3538+
assert(ArgVT == MVT::i16);
35223539
SDLoc DL(Op);
35233540

35243541
Op = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op.getOperand(0));
@@ -16866,11 +16883,18 @@ bool AArch64TargetLowering::getPostIndexedAddressParts(
1686616883
return true;
1686716884
}
1686816885

16869-
static void ReplaceBITCASTResults(SDNode *N, SmallVectorImpl<SDValue> &Results,
16870-
SelectionDAG &DAG) {
16886+
void AArch64TargetLowering::ReplaceBITCASTResults(
16887+
SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
1687116888
SDLoc DL(N);
1687216889
SDValue Op = N->getOperand(0);
1687316890

16891+
if (N->getValueType(0) == MVT::nxv2i32 &&
16892+
Op.getValueType().isFloatingPoint()) {
16893+
SDValue CastResult = getSVESafeBitCast(MVT::nxv2i64, Op, DAG);
16894+
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::nxv2i32, CastResult));
16895+
return;
16896+
}
16897+
1687416898
if (N->getValueType(0) != MVT::i16 ||
1687516899
(Op.getValueType() != MVT::f16 && Op.getValueType() != MVT::bf16))
1687616900
return;
@@ -18428,8 +18452,6 @@ SDValue AArch64TargetLowering::getSVESafeBitCast(EVT VT, SDValue Op,
1842818452

1842918453
EVT PackedVT = getPackedSVEVectorVT(VT.getVectorElementType());
1843018454
EVT PackedInVT = getPackedSVEVectorVT(InVT.getVectorElementType());
18431-
assert((VT == PackedVT || InVT == PackedInVT) &&
18432-
"Cannot cast between unpacked scalable vector types!");
1843318455

1843418456
// Pack input if required.
1843518457
if (InVT != PackedInVT)

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,6 +1066,8 @@ class AArch64TargetLowering : public TargetLowering {
10661066

10671067
void ReplaceNodeResults(SDNode *N, SmallVectorImpl<SDValue> &Results,
10681068
SelectionDAG &DAG) const override;
1069+
void ReplaceBITCASTResults(SDNode *N, SmallVectorImpl<SDValue> &Results,
1070+
SelectionDAG &DAG) const;
10691071
void ReplaceExtractSubVectorResults(SDNode *N,
10701072
SmallVectorImpl<SDValue> &Results,
10711073
SelectionDAG &DAG) const;

llvm/test/CodeGen/AArch64/sve-bitcast.ll

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,5 +450,39 @@ define <vscale x 8 x bfloat> @bitcast_double_to_bfloat(<vscale x 2 x double> %v)
450450
ret <vscale x 8 x bfloat> %bc
451451
}
452452

453+
define <vscale x 2 x i32> @bitcast_short_float_to_i32(<vscale x 2 x double> %v) #0 {
454+
; CHECK-LABEL: bitcast_short_float_to_i32:
455+
; CHECK: // %bb.0:
456+
; CHECK-NEXT: ptrue p0.d
457+
; CHECK-NEXT: fcvt z0.s, p0/m, z0.d
458+
; CHECK-NEXT: ret
459+
%trunc = fptrunc <vscale x 2 x double> %v to <vscale x 2 x float>
460+
%bitcast = bitcast <vscale x 2 x float> %trunc to <vscale x 2 x i32>
461+
ret <vscale x 2 x i32> %bitcast
462+
}
463+
464+
define <vscale x 2 x double> @bitcast_short_i32_to_float(<vscale x 2 x i64> %v) #0 {
465+
; CHECK-LABEL: bitcast_short_i32_to_float:
466+
; CHECK: // %bb.0:
467+
; CHECK-NEXT: ptrue p0.d
468+
; CHECK-NEXT: fcvt z0.d, p0/m, z0.s
469+
; CHECK-NEXT: ret
470+
%trunc = trunc <vscale x 2 x i64> %v to <vscale x 2 x i32>
471+
%bitcast = bitcast <vscale x 2 x i32> %trunc to <vscale x 2 x float>
472+
%extended = fpext <vscale x 2 x float> %bitcast to <vscale x 2 x double>
473+
ret <vscale x 2 x double> %extended
474+
}
475+
476+
define <vscale x 2 x float> @bitcast_short_half_to_float(<vscale x 4 x half> %v) #0 {
477+
; CHECK-LABEL: bitcast_short_half_to_float:
478+
; CHECK: // %bb.0:
479+
; CHECK-NEXT: ptrue p0.s
480+
; CHECK-NEXT: fadd z0.h, p0/m, z0.h, z0.h
481+
; CHECK-NEXT: ret
482+
%add = fadd <vscale x 4 x half> %v, %v
483+
%bitcast = bitcast <vscale x 4 x half> %add to <vscale x 2 x float>
484+
ret <vscale x 2 x float> %bitcast
485+
}
486+
453487
; +bf16 is required for the bfloat version.
454488
attributes #0 = { "target-features"="+sve,+bf16" }

0 commit comments

Comments
 (0)