[AArch64][SVE2] Generate XAR #77160

Merged: 5 commits, Jan 11, 2024

54 changes: 53 additions & 1 deletion llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -4275,6 +4275,58 @@ bool AArch64DAGToDAGISel::trySelectXAR(SDNode *N) {

SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
EVT VT = N->getValueType(0);

// Essentially: rotr (xor(x, y), imm) -> xar (x, y, imm)
// Rotate by a constant is a funnel shift in IR, which is expanded to
// an OR with shifted operands.
// We do the following transform:
// OR N0, N1 -> xar (x, y, imm)
// Where:
// N1 = SRL_PRED true, V, splat(imm) --> rotr amount
// N0 = SHL_PRED true, V, splat(bits-imm)
// V = (xor x, y)
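// For example, a rotate-right of V by 4 in a <vscale x 2 x i64> vector
// appears as N1 = SRL_PRED true, V, splat(4) and
// N0 = SHL_PRED true, V, splat(60).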
if (VT.isScalableVector() && Subtarget->hasSVE2orSME()) {
if (N0.getOpcode() != AArch64ISD::SHL_PRED ||
N1.getOpcode() != AArch64ISD::SRL_PRED)
std::swap(N0, N1);
if (N0.getOpcode() != AArch64ISD::SHL_PRED ||
N1.getOpcode() != AArch64ISD::SRL_PRED)
return false;

auto *TLI = static_cast<const AArch64TargetLowering *>(getTargetLowering());
if (!TLI->isAllActivePredicate(*CurDAG, N0.getOperand(0)) ||
!TLI->isAllActivePredicate(*CurDAG, N1.getOperand(0)))
return false;

SDValue XOR = N0.getOperand(1);
if (XOR.getOpcode() != ISD::XOR || XOR != N1.getOperand(1))
return false;

APInt ShlAmt, ShrAmt;
if (!ISD::isConstantSplatVector(N0.getOperand(2).getNode(), ShlAmt) ||
!ISD::isConstantSplatVector(N1.getOperand(2).getNode(), ShrAmt))
return false;

if (ShlAmt + ShrAmt != VT.getScalarSizeInBits())
return false;

SDLoc DL(N);
SDValue Imm =
CurDAG->getTargetConstant(ShrAmt.getZExtValue(), DL, MVT::i32);

SDValue Ops[] = {XOR.getOperand(0), XOR.getOperand(1), Imm};
if (auto Opc = SelectOpcodeFromVT<SelectTypeKind::Int>(
VT, {AArch64::XAR_ZZZI_B, AArch64::XAR_ZZZI_H, AArch64::XAR_ZZZI_S,
AArch64::XAR_ZZZI_D})) {
CurDAG->SelectNodeTo(N, Opc, VT, Ops);
return true;
}
return false;
}

if (!Subtarget->hasSHA3())
return false;

if (N0->getOpcode() != AArch64ISD::VSHL ||
N1->getOpcode() != AArch64ISD::VLSHR)
@@ -4367,7 +4419,7 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) {
case ISD::OR:
if (tryBitfieldInsertOp(Node))
return;
if (Subtarget->hasSHA3() && trySelectXAR(Node))
if (trySelectXAR(Node))
return;
break;

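As an aside, the identity this selector relies on is easy to check with a minimal scalar sketch (plain C++, not LLVM code; rotr64 and xar64 are illustrative names): the OR-of-shifts form that trySelectXAR matches, with v = x ^ y, is exactly a rotate-right of x ^ y, which is what a single XAR computes per element.

#include <cassert>
#include <cstdint>

// Rotate right; imm is assumed to be in [1, 63] so both shifts stay defined.
static uint64_t rotr64(uint64_t v, unsigned imm) {
  return (v >> imm) | (v << (64 - imm));
}

// Per-element semantics of SVE2 XAR: rotate the XOR of the two inputs.
static uint64_t xar64(uint64_t x, uint64_t y, unsigned imm) {
  return rotr64(x ^ y, imm);
}

int main() {
  uint64_t x = 0x0123456789abcdefULL, y = 0xfedcba9876543210ULL;
  unsigned imm = 4;
  uint64_t v = x ^ y;
  // The expanded funnel-shift form matched above: SHL by (bits - imm)
  // ORed with SRL by imm.
  uint64_t expanded = (v << (64 - imm)) | (v >> imm);
  assert(expanded == xar64(x, y, imm));
  return 0;
}
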
1 change: 1 addition & 0 deletions llvm/lib/Target/AArch64/AArch64Subtarget.h
@@ -394,6 +394,7 @@ class AArch64Subtarget final : public AArch64GenSubtargetInfo {
void mirFileLoaded(MachineFunction &MF) const override;

bool hasSVEorSME() const { return hasSVE() || hasSME(); }
bool hasSVE2orSME() const { return hasSVE2() || hasSME(); }

// Return the known range for the bit length of SVE data registers. A value
// of 0 means nothing is known about that particular limit beyond what's
240 changes: 240 additions & 0 deletions llvm/test/CodeGen/AArch64/sve2-xar.ll
@@ -0,0 +1,240 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
; RUN: llc -mtriple=aarch64 -mattr=+sve < %s -o - | FileCheck --check-prefixes=CHECK,SVE %s
; RUN: llc -mtriple=aarch64 -mattr=+sve2 < %s -o - | FileCheck --check-prefixes=CHECK,SVE2 %s

define <vscale x 2 x i64> @xar_nxv2i64_l(<vscale x 2 x i64> %x, <vscale x 2 x i64> %y) {
; SVE-LABEL: xar_nxv2i64_l:
; SVE: // %bb.0:
; SVE-NEXT: eor z0.d, z0.d, z1.d
; SVE-NEXT: lsr z1.d, z0.d, #4
; SVE-NEXT: lsl z0.d, z0.d, #60
; SVE-NEXT: orr z0.d, z0.d, z1.d
; SVE-NEXT: ret
;
; SVE2-LABEL: xar_nxv2i64_l:
; SVE2: // %bb.0:
; SVE2-NEXT: xar z0.d, z0.d, z1.d, #4
; SVE2-NEXT: ret
%a = xor <vscale x 2 x i64> %x, %y
%b = call <vscale x 2 x i64> @llvm.fshl.nxv2i64(<vscale x 2 x i64> %a, <vscale x 2 x i64> %a, <vscale x 2 x i64> splat (i64 60))
ret <vscale x 2 x i64> %b
}

define <vscale x 2 x i64> @xar_nxv2i64_r(<vscale x 2 x i64> %x, <vscale x 2 x i64> %y) {
; SVE-LABEL: xar_nxv2i64_r:
; SVE: // %bb.0:
; SVE-NEXT: eor z0.d, z0.d, z1.d
; SVE-NEXT: lsl z1.d, z0.d, #60
; SVE-NEXT: lsr z0.d, z0.d, #4
; SVE-NEXT: orr z0.d, z0.d, z1.d
; SVE-NEXT: ret
;
; SVE2-LABEL: xar_nxv2i64_r:
; SVE2: // %bb.0:
; SVE2-NEXT: xar z0.d, z0.d, z1.d, #4
; SVE2-NEXT: ret
%a = xor <vscale x 2 x i64> %x, %y
%b = call <vscale x 2 x i64> @llvm.fshr.nxv2i64(<vscale x 2 x i64> %a, <vscale x 2 x i64> %a, <vscale x 2 x i64> splat (i64 4))
ret <vscale x 2 x i64> %b
}


define <vscale x 4 x i32> @xar_nxv4i32_l(<vscale x 4 x i32> %x, <vscale x 4 x i32> %y) {
; SVE-LABEL: xar_nxv4i32_l:
; SVE: // %bb.0:
; SVE-NEXT: eor z0.d, z0.d, z1.d
; SVE-NEXT: lsr z1.s, z0.s, #4
; SVE-NEXT: lsl z0.s, z0.s, #28
; SVE-NEXT: orr z0.d, z0.d, z1.d
; SVE-NEXT: ret
;
; SVE2-LABEL: xar_nxv4i32_l:
; SVE2: // %bb.0:
; SVE2-NEXT: xar z0.s, z0.s, z1.s, #4
; SVE2-NEXT: ret
%a = xor <vscale x 4 x i32> %x, %y
%b = call <vscale x 4 x i32> @llvm.fshl.nxv4i32(<vscale x 4 x i32> %a, <vscale x 4 x i32> %a, <vscale x 4 x i32> splat (i32 28))
ret <vscale x 4 x i32> %b
}

define <vscale x 4 x i32> @xar_nxv4i32_r(<vscale x 4 x i32> %x, <vscale x 4 x i32> %y) {
; SVE-LABEL: xar_nxv4i32_r:
; SVE: // %bb.0:
; SVE-NEXT: eor z0.d, z0.d, z1.d
; SVE-NEXT: lsl z1.s, z0.s, #28
; SVE-NEXT: lsr z0.s, z0.s, #4
; SVE-NEXT: orr z0.d, z0.d, z1.d
; SVE-NEXT: ret
;
; SVE2-LABEL: xar_nxv4i32_r:
; SVE2: // %bb.0:
; SVE2-NEXT: xar z0.s, z0.s, z1.s, #4
; SVE2-NEXT: ret
%a = xor <vscale x 4 x i32> %x, %y
%b = call <vscale x 4 x i32> @llvm.fshr.nxv4i32(<vscale x 4 x i32> %a, <vscale x 4 x i32> %a, <vscale x 4 x i32> splat (i32 4))
ret <vscale x 4 x i32> %b
}

define <vscale x 8 x i16> @xar_nxv8i16_l(<vscale x 8 x i16> %x, <vscale x 8 x i16> %y) {
; SVE-LABEL: xar_nxv8i16_l:
; SVE: // %bb.0:
; SVE-NEXT: eor z0.d, z0.d, z1.d
; SVE-NEXT: lsr z1.h, z0.h, #4
; SVE-NEXT: lsl z0.h, z0.h, #12
; SVE-NEXT: orr z0.d, z0.d, z1.d
; SVE-NEXT: ret
;
; SVE2-LABEL: xar_nxv8i16_l:
; SVE2: // %bb.0:
; SVE2-NEXT: xar z0.h, z0.h, z1.h, #4
; SVE2-NEXT: ret
%a = xor <vscale x 8 x i16> %x, %y
%b = call <vscale x 8 x i16> @llvm.fshl.nxv8i16(<vscale x 8 x i16> %a, <vscale x 8 x i16> %a, <vscale x 8 x i16> splat (i16 12))
ret <vscale x 8 x i16> %b
}

define <vscale x 8 x i16> @xar_nxv8i16_r(<vscale x 8 x i16> %x, <vscale x 8 x i16> %y) {
; SVE-LABEL: xar_nxv8i16_r:
; SVE: // %bb.0:
; SVE-NEXT: eor z0.d, z0.d, z1.d
; SVE-NEXT: lsl z1.h, z0.h, #12
; SVE-NEXT: lsr z0.h, z0.h, #4
; SVE-NEXT: orr z0.d, z0.d, z1.d
; SVE-NEXT: ret
;
; SVE2-LABEL: xar_nxv8i16_r:
; SVE2: // %bb.0:
; SVE2-NEXT: xar z0.h, z0.h, z1.h, #4
; SVE2-NEXT: ret
%a = xor <vscale x 8 x i16> %x, %y
%b = call <vscale x 8 x i16> @llvm.fshr.nxv8i16(<vscale x 8 x i16> %a, <vscale x 8 x i16> %a, <vscale x 8 x i16> splat (i16 4))
ret <vscale x 8 x i16> %b
}

define <vscale x 16 x i8> @xar_nxv16i8_l(<vscale x 16 x i8> %x, <vscale x 16 x i8> %y) {
; SVE-LABEL: xar_nxv16i8_l:
; SVE: // %bb.0:
; SVE-NEXT: eor z0.d, z0.d, z1.d
; SVE-NEXT: lsr z1.b, z0.b, #4
; SVE-NEXT: lsl z0.b, z0.b, #4
; SVE-NEXT: orr z0.d, z0.d, z1.d
; SVE-NEXT: ret
;
; SVE2-LABEL: xar_nxv16i8_l:
; SVE2: // %bb.0:
; SVE2-NEXT: xar z0.b, z0.b, z1.b, #4
; SVE2-NEXT: ret
%a = xor <vscale x 16 x i8> %x, %y
%b = call <vscale x 16 x i8> @llvm.fshl.nxv16i8(<vscale x 16 x i8> %a, <vscale x 16 x i8> %a, <vscale x 16 x i8> splat (i8 4))
ret <vscale x 16 x i8> %b
}

define <vscale x 16 x i8> @xar_nxv16i8_r(<vscale x 16 x i8> %x, <vscale x 16 x i8> %y) {
; SVE-LABEL: xar_nxv16i8_r:
; SVE: // %bb.0:
; SVE-NEXT: eor z0.d, z0.d, z1.d
; SVE-NEXT: lsl z1.b, z0.b, #4
; SVE-NEXT: lsr z0.b, z0.b, #4
; SVE-NEXT: orr z0.d, z0.d, z1.d
; SVE-NEXT: ret
;
; SVE2-LABEL: xar_nxv16i8_r:
; SVE2: // %bb.0:
; SVE2-NEXT: xar z0.b, z0.b, z1.b, #4
; SVE2-NEXT: ret
%a = xor <vscale x 16 x i8> %x, %y
%b = call <vscale x 16 x i8> @llvm.fshr.nxv16i8(<vscale x 16 x i8> %a, <vscale x 16 x i8> %a, <vscale x 16 x i8> splat (i8 4))
ret <vscale x 16 x i8> %b
}

; Shift is not a constant.
define <vscale x 2 x i64> @xar_nxv2i64_l_neg1(<vscale x 2 x i64> %x, <vscale x 2 x i64> %y, <vscale x 2 x i64> %z) {
; CHECK-LABEL: xar_nxv2i64_l_neg1:
; CHECK: // %bb.0:
; CHECK-NEXT: mov z3.d, z2.d
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: subr z2.d, z2.d, #0 // =0x0
; CHECK-NEXT: eor z0.d, z0.d, z1.d
; CHECK-NEXT: and z2.d, z2.d, #0x3f
; CHECK-NEXT: and z3.d, z3.d, #0x3f
; CHECK-NEXT: movprfx z1, z0
; CHECK-NEXT: lsl z1.d, p0/m, z1.d, z3.d
; CHECK-NEXT: lsr z0.d, p0/m, z0.d, z2.d
; CHECK-NEXT: orr z0.d, z1.d, z0.d
; CHECK-NEXT: ret
%a = xor <vscale x 2 x i64> %x, %y
%b = call <vscale x 2 x i64> @llvm.fshl.nxv2i64(<vscale x 2 x i64> %a, <vscale x 2 x i64> %a, <vscale x 2 x i64> %z)
ret <vscale x 2 x i64> %b
}

; OR instead of an XOR.
; TODO: We could use usra instruction here for SVE2.
define <vscale x 2 x i64> @xar_nxv2i64_l_neg2(<vscale x 2 x i64> %x, <vscale x 2 x i64> %y) {
; CHECK-LABEL: xar_nxv2i64_l_neg2:
; CHECK: // %bb.0:
; CHECK-NEXT: orr z0.d, z0.d, z1.d
; CHECK-NEXT: lsr z1.d, z0.d, #4
; CHECK-NEXT: lsl z0.d, z0.d, #60
; CHECK-NEXT: orr z0.d, z0.d, z1.d
; CHECK-NEXT: ret
%a = or <vscale x 2 x i64> %x, %y
%b = call <vscale x 2 x i64> @llvm.fshl.nxv2i64(<vscale x 2 x i64> %a, <vscale x 2 x i64> %a, <vscale x 2 x i64> splat (i64 60))
ret <vscale x 2 x i64> %b
}

; Rotate amount is 0.
define <vscale x 2 x i64> @xar_nxv2i64_l_neg3(<vscale x 2 x i64> %x, <vscale x 2 x i64> %y) {
; CHECK-LABEL: xar_nxv2i64_l_neg3:
; CHECK: // %bb.0:
; CHECK-NEXT: eor z0.d, z0.d, z1.d
; CHECK-NEXT: ret
%a = xor <vscale x 2 x i64> %x, %y
%b = call <vscale x 2 x i64> @llvm.fshl.nxv2i64(<vscale x 2 x i64> %a, <vscale x 2 x i64> %a, <vscale x 2 x i64> splat (i64 64))
ret <vscale x 2 x i64> %b
}

; Same rotate expressed with individual shifts instead of a funnel shift; just one test.
define <vscale x 2 x i64> @xar_nxv2i64_shifts(<vscale x 2 x i64> %x, <vscale x 2 x i64> %y) {
; SVE-LABEL: xar_nxv2i64_shifts:
; SVE: // %bb.0:
; SVE-NEXT: eor z0.d, z0.d, z1.d
; SVE-NEXT: lsr z1.d, z0.d, #4
; SVE-NEXT: lsl z0.d, z0.d, #60
; SVE-NEXT: orr z0.d, z0.d, z1.d
; SVE-NEXT: ret
;
; SVE2-LABEL: xar_nxv2i64_shifts:
; SVE2: // %bb.0:
; SVE2-NEXT: xar z0.d, z0.d, z1.d, #4
; SVE2-NEXT: ret
%xor = xor <vscale x 2 x i64> %x, %y
%shl = shl <vscale x 2 x i64> %xor, splat (i64 60)
%shr = lshr <vscale x 2 x i64> %xor, splat (i64 4)
%or = or <vscale x 2 x i64> %shl, %shr
ret <vscale x 2 x i64> %or
}

; Not a rotate operation as 60 + 3 != 64
define <vscale x 2 x i64> @xar_nxv2i64_shifts_neg(<vscale x 2 x i64> %x, <vscale x 2 x i64> %y) {
; CHECK-LABEL: xar_nxv2i64_shifts_neg:
; CHECK: // %bb.0:
; CHECK-NEXT: eor z0.d, z0.d, z1.d
; CHECK-NEXT: lsl z1.d, z0.d, #60
; CHECK-NEXT: lsr z0.d, z0.d, #3
; CHECK-NEXT: orr z0.d, z1.d, z0.d
; CHECK-NEXT: ret
%xor = xor <vscale x 2 x i64> %x, %y
%shl = shl <vscale x 2 x i64> %xor, splat (i64 60)
%shr = lshr <vscale x 2 x i64> %xor, splat (i64 3)
%or = or <vscale x 2 x i64> %shl, %shr
ret <vscale x 2 x i64> %or
}

declare <vscale x 2 x i64> @llvm.fshl.nxv2i64(<vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i64>)
declare <vscale x 4 x i32> @llvm.fshl.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32>)
declare <vscale x 8 x i16> @llvm.fshl.nxv8i16(<vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>)
declare <vscale x 16 x i8> @llvm.fshl.nxv16i8(<vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>)
declare <vscale x 2 x i64> @llvm.fshr.nxv2i64(<vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i64>)
declare <vscale x 4 x i32> @llvm.fshr.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32>)
declare <vscale x 8 x i16> @llvm.fshr.nxv8i16(<vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>)
declare <vscale x 16 x i8> @llvm.fshr.nxv16i8(<vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>)
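
As the NOTE at the top of the test file says, the CHECK lines are autogenerated. Assuming a built llc, they can be regenerated with something like the following (the --llc-binary path is illustrative):

llvm/utils/update_llc_test_checks.py --llc-binary ./build/bin/llc llvm/test/CodeGen/AArch64/sve2-xar.ll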