[LLVM][AArch64] Improve big endian code generation for SVE BITCASTs. #104769
Conversation
For the most part I've tried to maintain the use of ISD::BITCAST wherever possible. I'm assuming this will keep access to more DAG combines, but perhaps it's more likely to just encourage the proliferation of invalid combines than if I ensure only AArch64ISD::NVCAST/REINTERPRET_CAST survives lowering?
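For background, the constraint behind the big endian changes is the one the removed FIXME in AArch64SVEInstrInfo.td describes: a bitcast must produce the same bytes as storing the value as one type and reloading it as another. Below is a minimal standalone sketch of that constraint, in plain C++ with made-up values rather than anything from the patch, showing why a big endian cast between different element sizes cannot be a plain register copy and instead needs a per-element byte reverse such as REVB.

#include <cstdint>
#include <cstdio>

// Illustrative sketch only (not LLVM code): SVE keeps element 0 in the low
// bits of a Z register regardless of endianness, but a big-endian store
// writes each element MSB first, so reinterpreting i16 lanes as i8 lanes is
// not the same as a store/reload round trip.
int main() {
  uint16_t v16[2] = {0x1122, 0x3344}; // two i16 elements

  // Byte image produced by a big-endian store of the i16 elements.
  uint8_t mem[4];
  for (int i = 0; i < 2; ++i) {
    mem[2 * i]     = v16[i] >> 8;     // MSB stored first
    mem[2 * i + 1] = v16[i] & 0xFF;
  }
  // Reloading as i8 elements gives {0x11, 0x22, 0x33, 0x44}.
  printf("via memory   : %02x %02x %02x %02x\n", mem[0], mem[1], mem[2], mem[3]);

  // Plain lane renumbering (element 0 = low bits of the containing i16
  // element) gives {0x22, 0x11, 0x44, 0x33} -- off by a per-halfword byte
  // reverse, which is what REVB z.h provides.
  uint8_t lanes[4];
  for (int i = 0; i < 2; ++i) {
    lanes[2 * i]     = v16[i] & 0xFF;
    lanes[2 * i + 1] = v16[i] >> 8;
  }
  printf("lane renumber: %02x %02x %02x %02x\n",
         lanes[0], lanes[1], lanes[2], lanes[3]);
  return 0;
}

This is why the new big endian lowering emits revb-style element reverses (visible in the updated sve-bitcast.ll checks) rather than treating every same-size cast as a register copy.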
@llvm/pr-subscribers-llvm-selectiondag @llvm/pr-subscribers-backend-aarch64

Author: Paul Walker (paulwalker-arm)

Changes

For the most part I've tried to maintain the use of ISD::BITCAST wherever possible. I'm assuming this will keep access to more DAG combines, but perhaps it's more likely to just encourage the proliferation of invalid combines than if I ensure only AArch64ISD::NVCAST/REINTERPRET_CAST survives lowering?

Patch is 98.21 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/104769.diff

4 Files Affected:
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index ab12c3b0e728a8..8d06206755d4f8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -6140,6 +6140,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
break;
case ISD::BSWAP:
assert(VT.isInteger() && VT == N1.getValueType() && "Invalid BSWAP!");
+ if (VT.getScalarSizeInBits() == 8)
+ return N1;
assert((VT.getScalarSizeInBits() % 16 == 0) &&
"BSWAP types must be a multiple of 16 bits!");
if (OpOpcode == ISD::UNDEF)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 97fb2c5f552731..4d5034d67c5ed8 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1496,7 +1496,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::AVGCEILU, VT, Custom);
if (!Subtarget->isLittleEndian())
- setOperationAction(ISD::BITCAST, VT, Expand);
+ setOperationAction(ISD::BITCAST, VT, Custom);
if (Subtarget->hasSVE2() ||
(Subtarget->hasSME() && Subtarget->isStreaming()))
@@ -1510,9 +1510,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
}
- // Legalize unpacked bitcasts to REINTERPRET_CAST.
- for (auto VT : {MVT::nxv2i16, MVT::nxv4i16, MVT::nxv2i32, MVT::nxv2bf16,
- MVT::nxv4bf16, MVT::nxv2f16, MVT::nxv4f16, MVT::nxv2f32})
+ // Type legalize unpacked bitcasts.
+ for (auto VT : {MVT::nxv2i16, MVT::nxv4i16, MVT::nxv2i32})
setOperationAction(ISD::BITCAST, VT, Custom);
for (auto VT :
@@ -1587,6 +1586,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
for (auto VT : {MVT::nxv2f16, MVT::nxv4f16, MVT::nxv8f16, MVT::nxv2f32,
MVT::nxv4f32, MVT::nxv2f64}) {
+ setOperationAction(ISD::BITCAST, VT, Custom);
setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
setOperationAction(ISD::MLOAD, VT, Custom);
@@ -1658,20 +1658,15 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setCondCodeAction(ISD::SETUGT, VT, Expand);
setCondCodeAction(ISD::SETUEQ, VT, Expand);
setCondCodeAction(ISD::SETONE, VT, Expand);
-
- if (!Subtarget->isLittleEndian())
- setOperationAction(ISD::BITCAST, VT, Expand);
}
for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) {
+ setOperationAction(ISD::BITCAST, VT, Custom);
setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
setOperationAction(ISD::MLOAD, VT, Custom);
setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
setOperationAction(ISD::VECTOR_SPLICE, VT, Custom);
-
- if (!Subtarget->isLittleEndian())
- setOperationAction(ISD::BITCAST, VT, Expand);
}
setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
@@ -4960,22 +4955,35 @@ SDValue AArch64TargetLowering::LowerBITCAST(SDValue Op,
return LowerFixedLengthBitcastToSVE(Op, DAG);
if (OpVT.isScalableVector()) {
- // Bitcasting between unpacked vector types of different element counts is
- // not a NOP because the live elements are laid out differently.
- // 01234567
- // e.g. nxv2i32 = XX??XX??
- // nxv4f16 = X?X?X?X?
- if (OpVT.getVectorElementCount() != ArgVT.getVectorElementCount())
- return SDValue();
+ assert(isTypeLegal(OpVT) && "Unexpected result type!");
- if (isTypeLegal(OpVT) && !isTypeLegal(ArgVT)) {
+ // Handle type legalisation first.
+ if (!isTypeLegal(ArgVT)) {
assert(OpVT.isFloatingPoint() && !ArgVT.isFloatingPoint() &&
"Expected int->fp bitcast!");
+
+ // Bitcasting between unpacked vector types of different element counts is
+ // not a NOP because the live elements are laid out differently.
+ // 01234567
+ // e.g. nxv2i32 = XX??XX??
+ // nxv4f16 = X?X?X?X?
+ if (OpVT.getVectorElementCount() != ArgVT.getVectorElementCount())
+ return SDValue();
+
SDValue ExtResult =
DAG.getNode(ISD::ANY_EXTEND, SDLoc(Op), getSVEContainerType(ArgVT),
Op.getOperand(0));
return getSVESafeBitCast(OpVT, ExtResult, DAG);
}
+
+ // Bitcasts between legal types with the same element count are legal.
+ if (OpVT.getVectorElementCount() == ArgVT.getVectorElementCount())
+ return Op;
+
+ // getSVESafeBitCast does not support casting between unpacked types.
+ if (!isPackedVectorType(OpVT, DAG))
+ return SDValue();
+
return getSVESafeBitCast(OpVT, Op.getOperand(0), DAG);
}
@@ -28877,7 +28885,20 @@ SDValue AArch64TargetLowering::getSVESafeBitCast(EVT VT, SDValue Op,
if (InVT != PackedInVT)
Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, PackedInVT, Op);
- Op = DAG.getNode(ISD::BITCAST, DL, PackedVT, Op);
+ if (Subtarget->isLittleEndian() ||
+ PackedVT.getScalarSizeInBits() == PackedInVT.getScalarSizeInBits())
+ Op = DAG.getNode(ISD::BITCAST, DL, PackedVT, Op);
+ else {
+ EVT PackedVTAsInt = PackedVT.changeTypeToInteger();
+ EVT PackedInVTAsInt = PackedInVT.changeTypeToInteger();
+
+ // Simulate the effect of casting through memory.
+ Op = DAG.getNode(ISD::BITCAST, DL, PackedInVTAsInt, Op);
+ Op = DAG.getNode(ISD::BSWAP, DL, PackedInVTAsInt, Op);
+ Op = DAG.getNode(AArch64ISD::NVCAST, DL, PackedVTAsInt, Op);
+ Op = DAG.getNode(ISD::BSWAP, DL, PackedVTAsInt, Op);
+ Op = DAG.getNode(ISD::BITCAST, DL, PackedVT, Op);
+ }
// Unpack result if required.
if (VT != PackedVT)
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index d9a70b5ef02fcb..35035aae05ecb6 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -2650,113 +2650,62 @@ let Predicates = [HasSVEorSME] in {
sub_32)>;
}
- // FIXME: BigEndian requires an additional REV instruction to satisfy the
- // constraint that none of the bits change when stored to memory as one
- // type, and reloaded as another type.
- let Predicates = [IsLE] in {
- def : Pat<(nxv16i8 (bitconvert nxv8i16:$src)), (nxv16i8 ZPR:$src)>;
- def : Pat<(nxv16i8 (bitconvert nxv4i32:$src)), (nxv16i8 ZPR:$src)>;
- def : Pat<(nxv16i8 (bitconvert nxv2i64:$src)), (nxv16i8 ZPR:$src)>;
- def : Pat<(nxv16i8 (bitconvert nxv8f16:$src)), (nxv16i8 ZPR:$src)>;
- def : Pat<(nxv16i8 (bitconvert nxv4f32:$src)), (nxv16i8 ZPR:$src)>;
- def : Pat<(nxv16i8 (bitconvert nxv2f64:$src)), (nxv16i8 ZPR:$src)>;
-
- def : Pat<(nxv8i16 (bitconvert nxv16i8:$src)), (nxv8i16 ZPR:$src)>;
- def : Pat<(nxv8i16 (bitconvert nxv4i32:$src)), (nxv8i16 ZPR:$src)>;
- def : Pat<(nxv8i16 (bitconvert nxv2i64:$src)), (nxv8i16 ZPR:$src)>;
- def : Pat<(nxv8i16 (bitconvert nxv8f16:$src)), (nxv8i16 ZPR:$src)>;
- def : Pat<(nxv8i16 (bitconvert nxv4f32:$src)), (nxv8i16 ZPR:$src)>;
- def : Pat<(nxv8i16 (bitconvert nxv2f64:$src)), (nxv8i16 ZPR:$src)>;
-
- def : Pat<(nxv4i32 (bitconvert nxv16i8:$src)), (nxv4i32 ZPR:$src)>;
- def : Pat<(nxv4i32 (bitconvert nxv8i16:$src)), (nxv4i32 ZPR:$src)>;
- def : Pat<(nxv4i32 (bitconvert nxv2i64:$src)), (nxv4i32 ZPR:$src)>;
- def : Pat<(nxv4i32 (bitconvert nxv8f16:$src)), (nxv4i32 ZPR:$src)>;
- def : Pat<(nxv4i32 (bitconvert nxv4f32:$src)), (nxv4i32 ZPR:$src)>;
- def : Pat<(nxv4i32 (bitconvert nxv2f64:$src)), (nxv4i32 ZPR:$src)>;
-
- def : Pat<(nxv2i64 (bitconvert nxv16i8:$src)), (nxv2i64 ZPR:$src)>;
- def : Pat<(nxv2i64 (bitconvert nxv8i16:$src)), (nxv2i64 ZPR:$src)>;
- def : Pat<(nxv2i64 (bitconvert nxv4i32:$src)), (nxv2i64 ZPR:$src)>;
- def : Pat<(nxv2i64 (bitconvert nxv8f16:$src)), (nxv2i64 ZPR:$src)>;
- def : Pat<(nxv2i64 (bitconvert nxv4f32:$src)), (nxv2i64 ZPR:$src)>;
- def : Pat<(nxv2i64 (bitconvert nxv2f64:$src)), (nxv2i64 ZPR:$src)>;
-
- def : Pat<(nxv8f16 (bitconvert nxv16i8:$src)), (nxv8f16 ZPR:$src)>;
- def : Pat<(nxv8f16 (bitconvert nxv8i16:$src)), (nxv8f16 ZPR:$src)>;
- def : Pat<(nxv8f16 (bitconvert nxv4i32:$src)), (nxv8f16 ZPR:$src)>;
- def : Pat<(nxv8f16 (bitconvert nxv2i64:$src)), (nxv8f16 ZPR:$src)>;
- def : Pat<(nxv8f16 (bitconvert nxv4f32:$src)), (nxv8f16 ZPR:$src)>;
- def : Pat<(nxv8f16 (bitconvert nxv2f64:$src)), (nxv8f16 ZPR:$src)>;
-
- def : Pat<(nxv4f32 (bitconvert nxv16i8:$src)), (nxv4f32 ZPR:$src)>;
- def : Pat<(nxv4f32 (bitconvert nxv8i16:$src)), (nxv4f32 ZPR:$src)>;
- def : Pat<(nxv4f32 (bitconvert nxv4i32:$src)), (nxv4f32 ZPR:$src)>;
- def : Pat<(nxv4f32 (bitconvert nxv2i64:$src)), (nxv4f32 ZPR:$src)>;
- def : Pat<(nxv4f32 (bitconvert nxv8f16:$src)), (nxv4f32 ZPR:$src)>;
- def : Pat<(nxv4f32 (bitconvert nxv2f64:$src)), (nxv4f32 ZPR:$src)>;
-
- def : Pat<(nxv2f64 (bitconvert nxv16i8:$src)), (nxv2f64 ZPR:$src)>;
- def : Pat<(nxv2f64 (bitconvert nxv8i16:$src)), (nxv2f64 ZPR:$src)>;
- def : Pat<(nxv2f64 (bitconvert nxv4i32:$src)), (nxv2f64 ZPR:$src)>;
- def : Pat<(nxv2f64 (bitconvert nxv2i64:$src)), (nxv2f64 ZPR:$src)>;
- def : Pat<(nxv2f64 (bitconvert nxv8f16:$src)), (nxv2f64 ZPR:$src)>;
- def : Pat<(nxv2f64 (bitconvert nxv4f32:$src)), (nxv2f64 ZPR:$src)>;
-
- def : Pat<(nxv8bf16 (bitconvert nxv16i8:$src)), (nxv8bf16 ZPR:$src)>;
- def : Pat<(nxv8bf16 (bitconvert nxv8i16:$src)), (nxv8bf16 ZPR:$src)>;
- def : Pat<(nxv8bf16 (bitconvert nxv4i32:$src)), (nxv8bf16 ZPR:$src)>;
- def : Pat<(nxv8bf16 (bitconvert nxv2i64:$src)), (nxv8bf16 ZPR:$src)>;
- def : Pat<(nxv8bf16 (bitconvert nxv8f16:$src)), (nxv8bf16 ZPR:$src)>;
- def : Pat<(nxv8bf16 (bitconvert nxv4f32:$src)), (nxv8bf16 ZPR:$src)>;
- def : Pat<(nxv8bf16 (bitconvert nxv2f64:$src)), (nxv8bf16 ZPR:$src)>;
-
- def : Pat<(nxv16i8 (bitconvert nxv8bf16:$src)), (nxv16i8 ZPR:$src)>;
- def : Pat<(nxv8i16 (bitconvert nxv8bf16:$src)), (nxv8i16 ZPR:$src)>;
- def : Pat<(nxv4i32 (bitconvert nxv8bf16:$src)), (nxv4i32 ZPR:$src)>;
- def : Pat<(nxv2i64 (bitconvert nxv8bf16:$src)), (nxv2i64 ZPR:$src)>;
- def : Pat<(nxv8f16 (bitconvert nxv8bf16:$src)), (nxv8f16 ZPR:$src)>;
- def : Pat<(nxv4f32 (bitconvert nxv8bf16:$src)), (nxv4f32 ZPR:$src)>;
- def : Pat<(nxv2f64 (bitconvert nxv8bf16:$src)), (nxv2f64 ZPR:$src)>;
-
- def : Pat<(nxv16i1 (bitconvert aarch64svcount:$src)), (nxv16i1 PPR:$src)>;
- def : Pat<(aarch64svcount (bitconvert nxv16i1:$src)), (aarch64svcount PNR:$src)>;
- }
-
- // These allow casting from/to unpacked predicate types.
- def : Pat<(nxv16i1 (reinterpret_cast nxv16i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
- def : Pat<(nxv16i1 (reinterpret_cast nxv8i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
- def : Pat<(nxv16i1 (reinterpret_cast nxv4i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
- def : Pat<(nxv16i1 (reinterpret_cast nxv2i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
- def : Pat<(nxv16i1 (reinterpret_cast nxv1i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
- def : Pat<(nxv8i1 (reinterpret_cast nxv16i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
- def : Pat<(nxv8i1 (reinterpret_cast nxv4i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
- def : Pat<(nxv8i1 (reinterpret_cast nxv2i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
- def : Pat<(nxv8i1 (reinterpret_cast nxv1i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
- def : Pat<(nxv4i1 (reinterpret_cast nxv16i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
- def : Pat<(nxv4i1 (reinterpret_cast nxv8i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
- def : Pat<(nxv4i1 (reinterpret_cast nxv2i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
- def : Pat<(nxv4i1 (reinterpret_cast nxv1i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
- def : Pat<(nxv2i1 (reinterpret_cast nxv16i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
- def : Pat<(nxv2i1 (reinterpret_cast nxv8i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
- def : Pat<(nxv2i1 (reinterpret_cast nxv4i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
- def : Pat<(nxv2i1 (reinterpret_cast nxv1i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
- def : Pat<(nxv1i1 (reinterpret_cast nxv16i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
- def : Pat<(nxv1i1 (reinterpret_cast nxv8i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
- def : Pat<(nxv1i1 (reinterpret_cast nxv4i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
- def : Pat<(nxv1i1 (reinterpret_cast nxv2i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-
- // These allow casting from/to unpacked floating-point types.
- def : Pat<(nxv2f16 (reinterpret_cast nxv8f16:$src)), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
- def : Pat<(nxv8f16 (reinterpret_cast nxv2f16:$src)), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
- def : Pat<(nxv4f16 (reinterpret_cast nxv8f16:$src)), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
- def : Pat<(nxv8f16 (reinterpret_cast nxv4f16:$src)), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
- def : Pat<(nxv2f32 (reinterpret_cast nxv4f32:$src)), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
- def : Pat<(nxv4f32 (reinterpret_cast nxv2f32:$src)), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
- def : Pat<(nxv2bf16 (reinterpret_cast nxv8bf16:$src)), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
- def : Pat<(nxv8bf16 (reinterpret_cast nxv2bf16:$src)), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
- def : Pat<(nxv4bf16 (reinterpret_cast nxv8bf16:$src)), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
- def : Pat<(nxv8bf16 (reinterpret_cast nxv4bf16:$src)), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+ // For big endian, only BITCASTs involving same sized vector types with same
+ // size vector elements can be isel'd directly.
+ let Predicates = [IsLE] in
+ foreach VT = [ nxv16i8, nxv8i16, nxv4i32, nxv2i64, nxv8f16, nxv4f32, nxv2f64, nxv8bf16 ] in
+ foreach VT2 = [ nxv16i8, nxv8i16, nxv4i32, nxv2i64, nxv8f16, nxv4f32, nxv2f64, nxv8bf16 ] in
+ if !ne(VT,VT2) then
+ def : Pat<(VT (bitconvert (VT2 ZPR:$src))), (VT ZPR:$src)>;
+
+ def : Pat<(nxv8i16 (bitconvert (nxv8f16 ZPR:$src))), (nxv8i16 ZPR:$src)>;
+ def : Pat<(nxv8f16 (bitconvert (nxv8i16 ZPR:$src))), (nxv8f16 ZPR:$src)>;
+
+ def : Pat<(nxv4i32 (bitconvert (nxv4f32 ZPR:$src))), (nxv4i32 ZPR:$src)>;
+ def : Pat<(nxv4f32 (bitconvert (nxv4i32 ZPR:$src))), (nxv4f32 ZPR:$src)>;
+
+ def : Pat<(nxv2i64 (bitconvert (nxv2f64 ZPR:$src))), (nxv2i64 ZPR:$src)>;
+ def : Pat<(nxv2f64 (bitconvert (nxv2i64 ZPR:$src))), (nxv2f64 ZPR:$src)>;
+
+ def : Pat<(nxv8i16 (bitconvert (nxv8bf16 ZPR:$src))), (nxv8i16 ZPR:$src)>;
+ def : Pat<(nxv8bf16 (bitconvert (nxv8i16 ZPR:$src))), (nxv8bf16 ZPR:$src)>;
+
+ def : Pat<(nxv8bf16 (bitconvert (nxv8f16 ZPR:$src))), (nxv8bf16 ZPR:$src)>;
+ def : Pat<(nxv8f16 (bitconvert (nxv8bf16 ZPR:$src))), (nxv8f16 ZPR:$src)>;
+
+ def : Pat<(nxv4bf16 (bitconvert (nxv4f16 ZPR:$src))), (nxv4bf16 ZPR:$src)>;
+ def : Pat<(nxv4f16 (bitconvert (nxv4bf16 ZPR:$src))), (nxv4f16 ZPR:$src)>;
+
+ def : Pat<(nxv2bf16 (bitconvert (nxv2f16 ZPR:$src))), (nxv2bf16 ZPR:$src)>;
+ def : Pat<(nxv2f16 (bitconvert (nxv2bf16 ZPR:$src))), (nxv2f16 ZPR:$src)>;
+
+ def : Pat<(nxv16i1 (bitconvert (aarch64svcount PNR:$src))), (nxv16i1 PPR:$src)>;
+ def : Pat<(aarch64svcount (bitconvert (nxv16i1 PPR:$src))), (aarch64svcount PNR:$src)>;
+
+ // These allow nop casting between predicate vector types.
+ foreach VT = [ nxv16i1, nxv8i1, nxv4i1, nxv2i1, nxv1i1 ] in
+ foreach VT2 = [ nxv16i1, nxv8i1, nxv4i1, nxv2i1, nxv1i1 ] in
+ def : Pat<(VT (reinterpret_cast (VT2 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
+
+ // These allow nop casting between half vector types.
+ foreach VT = [ nxv2f16, nxv4f16, nxv8f16 ] in
+ foreach VT2 = [ nxv2f16, nxv4f16, nxv8f16 ] in
+ def : Pat<(VT (reinterpret_cast (VT2 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+
+ // These allow nop casting between float vector types.
+ foreach VT = [ nxv2f32, nxv4f32 ] in
+ foreach VT2 = [ nxv2f32, nxv4f32 ] in
+ def : Pat<(VT (reinterpret_cast (VT2 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+
+ // These allow nop casting between bfloat vector types.
+ foreach VT = [ nxv2bf16, nxv4bf16, nxv8bf16 ] in
+ foreach VT2 = [ nxv2bf16, nxv4bf16, nxv8bf16 ] in
+ def : Pat<(VT (reinterpret_cast (VT2 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+
+ // These allow nop casting between all packed vector types.
+ foreach VT = [ nxv16i8, nxv8i16, nxv4i32, nxv2i64, nxv8f16, nxv4f32, nxv2f64, nxv8bf16 ] in
+ foreach VT2 = [ nxv16i8, nxv8i16, nxv4i32, nxv2i64, nxv8f16, nxv4f32, nxv2f64, nxv8bf16 ] in
+ def : Pat<(VT (AArch64NvCast (VT2 ZPR:$src))), (VT ZPR:$src)>;
def : Pat<(nxv16i1 (and PPR:$Ps1, PPR:$Ps2)),
(AND_PPzPP (PTRUE_B 31), PPR:$Ps1, PPR:$Ps2)>;
diff --git a/llvm/test/CodeGen/AArch64/sve-bitcast.ll b/llvm/test/CodeGen/AArch64/sve-bitcast.ll
index 95f43ba5126323..5d12d41ac3332f 100644
--- a/llvm/test/CodeGen/AArch64/sve-bitcast.ll
+++ b/llvm/test/CodeGen/AArch64/sve-bitcast.ll
@@ -13,14 +13,8 @@ define <vscale x 16 x i8> @bitcast_nxv8i16_to_nxv16i8(<vscale x 8 x i16> %v) #0
;
; CHECK_BE-LABEL: bitcast_nxv8i16_to_nxv16i8:
; CHECK_BE: // %bb.0:
-; CHECK_BE-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
-; CHECK_BE-NEXT: addvl sp, sp, #-1
; CHECK_BE-NEXT: ptrue p0.h
-; CHECK_BE-NEXT: ptrue p1.b
-; CHECK_BE-NEXT: st1h { z0.h }, p0, [sp]
-; CHECK_BE-NEXT: ld1b { z0.b }, p1/z, [sp]
-; CHECK_BE-NEXT: addvl sp, sp, #1
-; CHECK_BE-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK_BE-NEXT: revb z0.h, p0/m, z0.h
; CHECK_BE-NEXT: ret
%bc = bitcast <vscale x 8 x i16> %v to <vscale x 16 x i8>
ret <vscale x 16 x i8> %bc
@@ -33,14 +27,8 @@ define <vscale x 16 x i8> @bitcast_nxv4i32_to_nxv16i8(<vscale x 4 x i32> %v) #0
;
; CHECK_BE-LABEL: bitcast_nxv4i32_to_nxv16i8:
; CHECK_BE: // %bb.0:
-; CHECK_BE-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
-; CHECK_BE-NEXT: addvl sp, sp, #-1
; CHECK_BE-NEXT: ptrue p0.s
-; CHECK_BE-NEXT: ptrue p1.b
-; CHECK_BE-NEXT: st1w { z0.s }, p0, [sp]
-; CHECK_BE-NEXT: ld1b { z0.b }, p1/z, [sp]
-; CHECK_BE-NEXT: addvl sp, sp, #1
-; CHECK_BE-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK_BE-NEXT: revb z0.s, p0/m, z0.s
; CHECK_BE-NEXT: ret
%bc = bitcast <vscale x 4 x i32> %v to <vscale x 16 x i8>
ret <vscale x 16 x i8> %bc
@@ -53,14 +41,8 @@ define <vscale x 16 x i8> @bitcast_nxv2i64_to_nxv16i8(<vscale x 2 x i64> %v) #0
;
; CHECK_BE-LABEL: bitcast_nxv2i64_to_nxv16i8:
; CHECK_BE: // %bb.0:
-; CHECK_BE-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
-; CHECK_BE-NEXT: addvl sp, sp, #-1
; CHECK_BE-NEXT: ptrue p0.d
-; CHECK_BE-NEXT: ptrue p1.b
-; CHECK_BE-NEXT: st1d { z0.d }, p0, [sp]
-; CHECK_BE-NEXT: ld1b { z0.b }, p1/z, [sp]
-; CHECK_BE-NEXT: addvl sp, sp, #1
-; CHECK_BE-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK_BE-NEXT: revb z0.d, p0/m, z0.d
; CHECK_BE-NEXT: ret
%bc = bitcast <vscale x 2 x i64> %v to <vscale x 16 x i8>
ret <vscale x 16 x i8> %bc
@@ -73,14 +55,8 @@ define <vscale x 16 x i8> @bitcast_nxv8f16_to_nxv16i8(<vscale x 8 x half> %v) #0
;
; CHECK_BE-LABEL: bitcast_nxv8f16_to_nxv16i8:
; CHECK_BE: // %bb.0:
-; CHECK_BE-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
-; CHECK_BE-NEXT: addvl sp, sp, #-1
; CHECK_BE-NEXT: ptrue p0.h
-; CHECK_BE-NEXT: ptrue p1.b
-; CHECK_BE-NEXT: st1h { z0.h }, p0, [sp]
-; CHECK_BE-NEXT: ld1b { z0.b }, p1/z, [sp]
-; CHECK_BE-NEXT: addvl sp, sp, #1
-; CHECK_BE-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK_BE-NEXT: revb z0.h, p0/m, z0.h
; CHECK_BE-NEXT: ret
%bc = bitcast <vscale x 8 x half> %v to <vscale x 16 x i8>
ret <vscale x 16 x i8> %bc
@@ -93,14 +69,8 @@ define <vscale x 16 x i8> @bitcast_nxv4f32_to_nxv16i8(<vscale x 4 x float> %v) #
;
; CHECK_BE-LABEL: bitcast_nxv4f32_to_nxv16i8:
; CHECK_BE: // %bb.0:
-; CHECK_BE-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
-; CHECK_BE-NEXT: addvl sp, sp, #-1
; CHECK_BE-NEXT: ptrue p0.s
-; CHECK_BE-NEXT: ptrue p1.b
-; CHECK_BE-NEXT: st1w { z0.s }, p0, [sp]
-; CHECK_BE-NEXT: ...
[truncated]
// Simulate the effect of casting through memory.
Op = DAG.getNode(ISD::BITCAST, DL, PackedInVTAsInt, Op);
Op = DAG.getNode(ISD::BSWAP, DL, PackedInVTAsInt, Op);
In most cases, there's some alternative to two bswaps that's more efficient (like some kind of rotate); are you planning to look at that as a followup?
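As one example of such a folding, here's a minimal standalone sketch in plain C++ (the helper names are hypothetical and model a single 64-bit chunk; nothing here is from the patch): byte-swapping each 16-bit lane and then byte-swapping the containing 64-bit lane is equivalent to a single halfword reverse within the 64-bit lane, i.e. a REVB z.h / REVB z.d pair could in principle fold to REVH z.d.

#include <cassert>
#include <cstdint>
#include <cstdio>

// BSWAP with 16-bit elements (REVB z.h): reverse bytes within each halfword.
static uint64_t bswapPerU16(uint64_t V) {
  uint64_t R = 0;
  for (int I = 0; I < 4; ++I) {
    uint64_t Lane = (V >> (16 * I)) & 0xFFFF;
    Lane = ((Lane & 0xFF) << 8) | (Lane >> 8);
    R |= Lane << (16 * I);
  }
  return R;
}

// BSWAP with 64-bit elements (REVB z.d): reverse all eight bytes.
static uint64_t bswapU64(uint64_t V) {
  uint64_t R = 0;
  for (int I = 0; I < 8; ++I)
    R |= ((V >> (8 * I)) & 0xFF) << (8 * (7 - I));
  return R;
}

// Reverse the halfwords within the doubleword (REVH z.d).
static uint64_t revHInD(uint64_t V) {
  uint64_t R = 0;
  for (int I = 0; I < 4; ++I)
    R |= ((V >> (16 * I)) & 0xFFFF) << (16 * (3 - I));
  return R;
}

int main() {
  uint64_t V = 0x0011223344556677ULL;
  // Two byte swaps at different granularities collapse to one halfword swap.
  assert(bswapU64(bswapPerU16(V)) == revHInD(V));
  printf("bswap.d(bswap.h(x)) == revh.d(x)\n");
  return 0;
}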
I wasn't planning to. Optimal big endian code generation is not the rationale for this work. I'm fixing bugs in how we use the uzp, uunpk{lo,hi} and zip instructions, and those fixes keep tripping over getSVESafeBitCast and the current big endian support, so it's easier to move that support along a little to clear the path for what I really want to fix.
For context, 245d3c6 contains the fixes I'm referring to above. You'll see it still affects some of the big endian bitcast tests, but the impact is more contained (albeit I've not yet verified the changes are all correct). That PR is still a work in progress.
Okay, that makes sense.
LGTM