
[LLVM][SVE] Improve code generation for vector.insert into poison. #105665


Merged (1 commit, Aug 28, 2024)
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (4 additions, 0 deletions)
@@ -14880,6 +14880,10 @@ SDValue AArch64TargetLowering::LowerINSERT_SUBVECTOR(SDValue Op,
     return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Lo, Hi);
   }

+  // We can select these directly.
+  if (isTypeLegal(InVT) && Vec0.isUndef())
+    return Op;
+
   // Ensure the subvector is half the size of the main vector.
   if (VT.getVectorElementCount() != (InVT.getVectorElementCount() * 2))
     return SDValue();
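
Taken together with the new patterns below, this early exit lets a vector.insert of an already-legal subvector type into poison survive lowering untouched. A minimal IR example, constructed for illustration (the function name and shape are assumptions, not part of the patch):

define <vscale x 8 x half> @insert_nxv4f16_into_poison(<vscale x 4 x half> %v) {
  %r = call <vscale x 8 x half> @llvm.vector.insert.nxv8f16.nxv4f16(<vscale x 8 x half> poison, <vscale x 4 x half> %v, i64 0)
  ret <vscale x 8 x half> %r
}
declare <vscale x 8 x half> @llvm.vector.insert.nxv8f16.nxv4f16(<vscale x 8 x half>, <vscale x 4 x half>, i64)

Here VT is nxv8f16, InVT is nxv4f16 (legal under SVE), and Vec0 is poison, which SelectionDAG treats as undef, so the node is returned unchanged and left for the new TableGen patterns to select.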
llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td (20 additions, 4 deletions)
@@ -1938,19 +1938,35 @@ let Predicates = [HasSVEorSME] in {
   def : Pat<(nxv2bf16 (extract_subvector nxv8bf16:$Zs, (i64 6))),
             (UUNPKHI_ZZ_D (UUNPKHI_ZZ_S ZPR:$Zs))>;

+  // Insert subvectors into FP SVE vectors.
+  foreach VT = [nxv4f16, nxv4f32, nxv4bf16] in
+    foreach idx = [0, 2] in
+      def : Pat<(VT (vector_insert_subvec undef, SVEType<VT>.HalfLength:$src, (i64 idx))),
+                (UZP1_ZZZ_S $src, $src)>;
+
+  foreach VT = [nxv8f16, nxv8bf16] in {
+    foreach idx = [0, 4] in
+      def : Pat<(VT (vector_insert_subvec undef, SVEType<VT>.HalfLength:$src, (i64 idx))),
+                (UZP1_ZZZ_H $src, $src)>;
+
+    foreach idx = [0, 2, 4, 6] in
+      def : Pat<(VT (vector_insert_subvec undef, SVEType<VT>.QuarterLength:$src, (i64 idx))),
+                (UZP1_ZZZ_H (UZP1_ZZZ_H $src, $src), (UZP1_ZZZ_H $src, $src))>;
+  }
+
   // extract/insert 64-bit fixed length vector from/into a scalable vector
   foreach VT = [v8i8, v4i16, v2i32, v1i64, v4f16, v2f32, v1f64, v4bf16] in {
-    def : Pat<(VT (vector_extract_subvec SVEContainerVT<VT>.Value:$Zs, (i64 0))),
+    def : Pat<(VT (vector_extract_subvec NEONType<VT>.SVEContainer:$Zs, (i64 0))),
              (EXTRACT_SUBREG ZPR:$Zs, dsub)>;
-    def : Pat<(SVEContainerVT<VT>.Value (vector_insert_subvec undef, (VT V64:$src), (i64 0))),
+    def : Pat<(NEONType<VT>.SVEContainer (vector_insert_subvec undef, (VT V64:$src), (i64 0))),
              (INSERT_SUBREG (IMPLICIT_DEF), $src, dsub)>;
   }

   // extract/insert 128-bit fixed length vector from/into a scalable vector
   foreach VT = [v16i8, v8i16, v4i32, v2i64, v8f16, v4f32, v2f64, v8bf16] in {
-    def : Pat<(VT (vector_extract_subvec SVEContainerVT<VT>.Value:$Zs, (i64 0))),
+    def : Pat<(VT (vector_extract_subvec NEONType<VT>.SVEContainer:$Zs, (i64 0))),
              (EXTRACT_SUBREG ZPR:$Zs, zsub)>;
-    def : Pat<(SVEContainerVT<VT>.Value (vector_insert_subvec undef, (VT V128:$src), (i64 0))),
+    def : Pat<(NEONType<VT>.SVEContainer (vector_insert_subvec undef, (VT V128:$src), (i64 0))),
              (INSERT_SUBREG (IMPLICIT_DEF), $src, zsub)>;
   }

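The payoff for the unpacked cases is visible in the test updates below: an insert of nxv2f16 into poison nxv8f16 previously round-tripped through the stack (st1h followed by a load) and now selects to two uzp1 instructions via the QuarterLength pattern. A constructed example (function name assumed, not from the patch):

define <vscale x 8 x half> @insert_nxv2f16_into_poison(<vscale x 2 x half> %v) {
  %r = call <vscale x 8 x half> @llvm.vector.insert.nxv8f16.nxv2f16(<vscale x 8 x half> poison, <vscale x 2 x half> %v, i64 0)
  ret <vscale x 8 x half> %r
}
; Expected selection, per the pattern above and the sve-bitcast.ll updates:
;   uzp1 z0.h, z0.h, z0.h
;   uzp1 z0.h, z0.h, z0.h
;   ret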
llvm/lib/Target/AArch64/SVEInstrFormats.td (30 additions, 9 deletions)
@@ -10,11 +10,10 @@
 //
 //===----------------------------------------------------------------------===//

-// Helper class to find the largest legal scalable vector type that can hold VT.
-// Non-matches return VT, which often means VT is the container type.
-class SVEContainerVT<ValueType VT> {
-  ValueType Value = !cond(
-    // fixed length vectors
+// Helper class to hold conversions of legal fixed-length vector types.
+class NEONType<ValueType VT> {
+  // The largest legal scalable vector type that can hold VT.
+  ValueType SVEContainer = !cond(
     !eq(VT, v8i8): nxv16i8,
     !eq(VT, v16i8): nxv16i8,
     !eq(VT, v4i16): nxv8i16,
@@ -31,13 +30,35 @@ class SVEContainerVT<ValueType VT> {
     !eq(VT, v2f64): nxv2f64,
     !eq(VT, v4bf16): nxv8bf16,
     !eq(VT, v8bf16): nxv8bf16,
-    // unpacked scalable vectors
+    true : untyped);
+}
+
+// Helper class to hold conversions of legal scalable vector types.
+class SVEType<ValueType VT> {
+  // The largest legal scalable vector type that can hold VT.
+  // Non-matches return VT because only packed types remain.
+  ValueType Packed = !cond(
     !eq(VT, nxv2f16): nxv8f16,
     !eq(VT, nxv4f16): nxv8f16,
     !eq(VT, nxv2f32): nxv4f32,
     !eq(VT, nxv2bf16): nxv8bf16,
     !eq(VT, nxv4bf16): nxv8bf16,
     true : VT);
+
+  // The legal scalable vector that is half the length of VT.
+  ValueType HalfLength = !cond(
+    !eq(VT, nxv8f16): nxv4f16,
+    !eq(VT, nxv4f16): nxv2f16,
+    !eq(VT, nxv4f32): nxv2f32,
+    !eq(VT, nxv8bf16): nxv4bf16,
+    !eq(VT, nxv4bf16): nxv2bf16,
+    true : untyped);
+
+  // The legal scalable vector that is quarter the length of VT.
+  ValueType QuarterLength = !cond(
+    !eq(VT, nxv8f16): nxv2f16,
+    !eq(VT, nxv8bf16): nxv2bf16,
+    true : untyped);
 }

 def SDT_AArch64Setcc : SDTypeProfile<1, 4, [
@@ -2959,10 +2980,10 @@ multiclass sve_fp_2op_p_zd<bits<7> opc, string asm,
   def NAME : sve_fp_2op_p_zd<opc, asm, i_zprtype, o_zprtype, Sz>,
              SVEPseudo2Instr<NAME, 1>;
   // convert vt1 to a packed type for the intrinsic patterns
-  defvar packedvt1 = SVEContainerVT<vt1>.Value;
+  defvar packedvt1 = SVEType<vt1>.Packed;

   // convert vt3 to a packed type for the intrinsic patterns
-  defvar packedvt3 = SVEContainerVT<vt3>.Value;
+  defvar packedvt3 = SVEType<vt3>.Packed;

   def : SVE_3_Op_Pat<packedvt1, int_op, packedvt1, vt2, packedvt3, !cast<Instruction>(NAME)>;
   def : SVE_1_Op_Passthru_Pat<vt1, ir_op, vt2, vt3, !cast<Instruction>(NAME)>;
@@ -2982,7 +3003,7 @@ multiclass sve_fp_2op_p_zdr<bits<7> opc, string asm,
              SVEPseudo2Instr<NAME, 1>;

   // convert vt1 to a packed type for the intrinsic patterns
-  defvar packedvt1 = SVEContainerVT<vt1>.Value;
+  defvar packedvt1 = SVEType<vt1>.Packed;

   def : SVE_3_Op_Pat<packedvt1, int_op, packedvt1, vt2, vt3, !cast<Instruction>(NAME)>;
   def : SVE_1_Op_Passthru_Round_Pat<vt1, ir_op, vt2, vt3, !cast<Instruction>(NAME)>;
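
As a concrete reading of the new SVEType fields: SVEType<nxv4f32>.Packed is nxv4f32 (already packed), SVEType<nxv4f32>.HalfLength is nxv2f32, and SVEType<nxv8f16>.QuarterLength is nxv2f16. So an insert of nxv2f32 into a poison nxv4f32 should select the single-uzp1 pattern from AArch64SVEInstrInfo.td above. A constructed example (function name assumed, not from the patch):

define <vscale x 4 x float> @insert_nxv2f32_into_poison(<vscale x 2 x float> %v) {
  %r = call <vscale x 4 x float> @llvm.vector.insert.nxv4f32.nxv2f32(<vscale x 4 x float> poison, <vscale x 2 x float> %v, i64 0)
  ret <vscale x 4 x float> %r
}
; Expected selection, per the UZP1_ZZZ_S pattern: uzp1 z0.s, z0.s, z0.s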
llvm/test/CodeGen/AArch64/sve-bitcast.ll (22 additions, 60 deletions)
@@ -1426,11 +1426,8 @@ define <vscale x 1 x i64> @bitcast_nxv4f16_to_nxv1i64(<vscale x 4 x half> %v) #0
 ;
 ; CHECK_BE-LABEL: bitcast_nxv4f16_to_nxv1i64:
 ; CHECK_BE:       // %bb.0:
-; CHECK_BE-NEXT:    ptrue p0.h
-; CHECK_BE-NEXT:    ptrue p1.s
-; CHECK_BE-NEXT:    revb z0.h, p0/m, z0.h
-; CHECK_BE-NEXT:    revb z0.s, p1/m, z0.s
-; CHECK_BE-NEXT:    uzp1 z0.h, z0.h, z0.h
+; CHECK_BE-NEXT:    ptrue p0.h
+; CHECK_BE-NEXT:    revb z0.h, p0/m, z0.h
 ; CHECK_BE-NEXT:    ptrue p0.d
 ; CHECK_BE-NEXT:    revb z0.d, p0/m, z0.d
@@ -1447,13 +1444,11 @@ define <vscale x 1 x i64> @bitcast_nxv2f32_to_nxv1i64(<vscale x 2 x float> %v) #0
 ;
 ; CHECK_BE-LABEL: bitcast_nxv2f32_to_nxv1i64:
 ; CHECK_BE:       // %bb.0:
-; CHECK_BE-NEXT:    ptrue p0.s
-; CHECK_BE-NEXT:    ptrue p1.d
-; CHECK_BE-NEXT:    revb z0.s, p0/m, z0.s
-; CHECK_BE-NEXT:    revb z0.d, p1/m, z0.d
-; CHECK_BE-NEXT:    uzp1 z0.s, z0.s, z0.s
+; CHECK_BE-NEXT:    ptrue p0.s
+; CHECK_BE-NEXT:    revb z0.s, p0/m, z0.s
+; CHECK_BE-NEXT:    revb z0.d, p1/m, z0.d
 ; CHECK_BE-NEXT:    ptrue p0.d
 ; CHECK_BE-NEXT:    revb z0.d, p0/m, z0.d
 ; CHECK_BE-NEXT:    ret
   %bc = bitcast <vscale x 2 x float> %v to <vscale x 1 x i64>
   ret <vscale x 1 x i64> %bc
@@ -1479,11 +1474,8 @@ define <vscale x 1 x i64> @bitcast_nxv4bf16_to_nxv1i64(<vscale x 4 x bfloat> %v)
 ;
 ; CHECK_BE-LABEL: bitcast_nxv4bf16_to_nxv1i64:
 ; CHECK_BE:       // %bb.0:
-; CHECK_BE-NEXT:    ptrue p0.h
-; CHECK_BE-NEXT:    ptrue p1.s
-; CHECK_BE-NEXT:    revb z0.h, p0/m, z0.h
-; CHECK_BE-NEXT:    revb z0.s, p1/m, z0.s
-; CHECK_BE-NEXT:    uzp1 z0.h, z0.h, z0.h
+; CHECK_BE-NEXT:    ptrue p0.h
+; CHECK_BE-NEXT:    revb z0.h, p0/m, z0.h
 ; CHECK_BE-NEXT:    ptrue p0.d
 ; CHECK_BE-NEXT:    revb z0.d, p0/m, z0.d
@@ -1888,11 +1880,8 @@ define <vscale x 1 x double> @bitcast_nxv4f16_to_nxv1f64(<vscale x 4 x half> %v)
 ;
 ; CHECK_BE-LABEL: bitcast_nxv4f16_to_nxv1f64:
 ; CHECK_BE:       // %bb.0:
-; CHECK_BE-NEXT:    ptrue p0.h
-; CHECK_BE-NEXT:    ptrue p1.s
-; CHECK_BE-NEXT:    revb z0.h, p0/m, z0.h
-; CHECK_BE-NEXT:    revb z0.s, p1/m, z0.s
-; CHECK_BE-NEXT:    uzp1 z0.h, z0.h, z0.h
+; CHECK_BE-NEXT:    ptrue p0.h
+; CHECK_BE-NEXT:    revb z0.h, p0/m, z0.h
 ; CHECK_BE-NEXT:    ptrue p0.d
 ; CHECK_BE-NEXT:    revb z0.d, p0/m, z0.d
@@ -1909,13 +1898,11 @@ define <vscale x 1 x double> @bitcast_nxv2f32_to_nxv1f64(<vscale x 2 x float> %v)
 ;
 ; CHECK_BE-LABEL: bitcast_nxv2f32_to_nxv1f64:
 ; CHECK_BE:       // %bb.0:
-; CHECK_BE-NEXT:    ptrue p0.s
-; CHECK_BE-NEXT:    ptrue p1.d
-; CHECK_BE-NEXT:    revb z0.s, p0/m, z0.s
-; CHECK_BE-NEXT:    revb z0.d, p1/m, z0.d
-; CHECK_BE-NEXT:    uzp1 z0.s, z0.s, z0.s
+; CHECK_BE-NEXT:    ptrue p0.s
+; CHECK_BE-NEXT:    revb z0.s, p0/m, z0.s
+; CHECK_BE-NEXT:    revb z0.d, p1/m, z0.d
 ; CHECK_BE-NEXT:    ptrue p0.d
 ; CHECK_BE-NEXT:    revb z0.d, p0/m, z0.d
 ; CHECK_BE-NEXT:    ret
   %bc = bitcast <vscale x 2 x float> %v to <vscale x 1 x double>
   ret <vscale x 1 x double> %bc
@@ -1929,11 +1916,8 @@ define <vscale x 1 x double> @bitcast_nxv4bf16_to_nxv1f64(<vscale x 4 x bfloat>
 ;
 ; CHECK_BE-LABEL: bitcast_nxv4bf16_to_nxv1f64:
 ; CHECK_BE:       // %bb.0:
-; CHECK_BE-NEXT:    ptrue p0.h
-; CHECK_BE-NEXT:    ptrue p1.s
-; CHECK_BE-NEXT:    revb z0.h, p0/m, z0.h
-; CHECK_BE-NEXT:    revb z0.s, p1/m, z0.s
-; CHECK_BE-NEXT:    uzp1 z0.h, z0.h, z0.h
+; CHECK_BE-NEXT:    ptrue p0.h
+; CHECK_BE-NEXT:    revb z0.h, p0/m, z0.h
 ; CHECK_BE-NEXT:    ptrue p0.d
 ; CHECK_BE-NEXT:    revb z0.d, p0/m, z0.d
@@ -2333,29 +2317,18 @@ define <vscale x 1 x i32> @bitcast_nxv2i16_to_nxv1i32(<vscale x 2 x i16> %v) #0
 define <vscale x 1 x i32> @bitcast_nxv2f16_to_nxv1i32(<vscale x 2 x half> %v) #0 {
 ; CHECK-LABEL: bitcast_nxv2f16_to_nxv1i32:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
-; CHECK-NEXT:    addvl sp, sp, #-1
-; CHECK-NEXT:    ptrue p0.d
-; CHECK-NEXT:    ptrue p1.s
-; CHECK-NEXT:    st1h { z0.d }, p0, [sp]
-; CHECK-NEXT:    ld1w { z0.s }, p1/z, [sp]
-; CHECK-NEXT:    addvl sp, sp, #1
-; CHECK-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    uzp1 z0.h, z0.h, z0.h
+; CHECK-NEXT:    uzp1 z0.h, z0.h, z0.h
 ; CHECK-NEXT:    ret
 ;
 ; CHECK_BE-LABEL: bitcast_nxv2f16_to_nxv1i32:
 ; CHECK_BE:       // %bb.0:
-; CHECK_BE-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
-; CHECK_BE-NEXT:    addvl sp, sp, #-1
-; CHECK_BE-NEXT:    ptrue p0.d
-; CHECK_BE-NEXT:    ptrue p1.h
-; CHECK_BE-NEXT:    st1h { z0.d }, p0, [sp]
+; CHECK_BE-NEXT:    uzp1 z0.h, z0.h, z0.h
+; CHECK_BE-NEXT:    ptrue p0.h
+; CHECK_BE-NEXT:    uzp1 z0.h, z0.h, z0.h
+; CHECK_BE-NEXT:    revb z0.h, p0/m, z0.h
 ; CHECK_BE-NEXT:    ptrue p0.s
-; CHECK_BE-NEXT:    ld1h { z0.h }, p1/z, [sp]
-; CHECK_BE-NEXT:    revb z0.h, p1/m, z0.h
 ; CHECK_BE-NEXT:    revb z0.s, p0/m, z0.s
-; CHECK_BE-NEXT:    addvl sp, sp, #1
-; CHECK_BE-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
 ; CHECK_BE-NEXT:    ret
   %bc = bitcast <vscale x 2 x half> %v to <vscale x 1 x i32>
   ret <vscale x 1 x i32> %bc
@@ -2366,29 +2339,18 @@ define <vscale x 1 x i32> @bitcast_nxv2f16_to_nxv1i32(<vscale x 2 x half> %v) #0
 define <vscale x 1 x i32> @bitcast_nxv2bf16_to_nxv1i32(<vscale x 2 x bfloat> %v) #0 {
 ; CHECK-LABEL: bitcast_nxv2bf16_to_nxv1i32:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
-; CHECK-NEXT:    addvl sp, sp, #-1
-; CHECK-NEXT:    ptrue p0.d
-; CHECK-NEXT:    ptrue p1.s
-; CHECK-NEXT:    st1h { z0.d }, p0, [sp]
-; CHECK-NEXT:    ld1w { z0.s }, p1/z, [sp]
-; CHECK-NEXT:    addvl sp, sp, #1
-; CHECK-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    uzp1 z0.h, z0.h, z0.h
+; CHECK-NEXT:    uzp1 z0.h, z0.h, z0.h
 ; CHECK-NEXT:    ret
 ;
 ; CHECK_BE-LABEL: bitcast_nxv2bf16_to_nxv1i32:
 ; CHECK_BE:       // %bb.0:
-; CHECK_BE-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
-; CHECK_BE-NEXT:    addvl sp, sp, #-1
-; CHECK_BE-NEXT:    ptrue p0.d
-; CHECK_BE-NEXT:    ptrue p1.h
-; CHECK_BE-NEXT:    st1h { z0.d }, p0, [sp]
+; CHECK_BE-NEXT:    uzp1 z0.h, z0.h, z0.h
+; CHECK_BE-NEXT:    ptrue p0.h
+; CHECK_BE-NEXT:    uzp1 z0.h, z0.h, z0.h
+; CHECK_BE-NEXT:    revb z0.h, p0/m, z0.h
 ; CHECK_BE-NEXT:    ptrue p0.s
-; CHECK_BE-NEXT:    ld1h { z0.h }, p1/z, [sp]
-; CHECK_BE-NEXT:    revb z0.h, p1/m, z0.h
 ; CHECK_BE-NEXT:    revb z0.s, p0/m, z0.s
-; CHECK_BE-NEXT:    addvl sp, sp, #1
-; CHECK_BE-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
 ; CHECK_BE-NEXT:    ret
   %bc = bitcast <vscale x 2 x bfloat> %v to <vscale x 1 x i32>
   ret <vscale x 1 x i32> %bc
llvm/test/CodeGen/AArch64/sve-extract-fixed-vector.ll
@@ -296,15 +296,9 @@ define <4 x i64> @extract_v4i64_nxv8i64_0(<vscale x 8 x i64> %arg) {
 define <4 x half> @extract_v4f16_nxv2f16_0(<vscale x 2 x half> %arg) {
 ; CHECK-LABEL: extract_v4f16_nxv2f16_0:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
-; CHECK-NEXT:    addvl sp, sp, #-1
-; CHECK-NEXT:    .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
-; CHECK-NEXT:    .cfi_offset w29, -16
-; CHECK-NEXT:    ptrue p0.d
-; CHECK-NEXT:    st1h { z0.d }, p0, [sp]
-; CHECK-NEXT:    ldr d0, [sp]
-; CHECK-NEXT:    addvl sp, sp, #1
-; CHECK-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    uzp1 z0.h, z0.h, z0.h
+; CHECK-NEXT:    uzp1 z0.h, z0.h, z0.h
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
 ; CHECK-NEXT:    ret
   %ext = call <4 x half> @llvm.vector.extract.v4f16.nxv2f16(<vscale x 2 x half> %arg, i64 0)
   ret <4 x half> %ext
@@ -313,18 +307,10 @@ define <4 x half> @extract_v4f16_nxv2f16_0(<vscale x 2 x half> %arg) {
 define <4 x half> @extract_v4f16_nxv2f16_4(<vscale x 2 x half> %arg) {
 ; CHECK-LABEL: extract_v4f16_nxv2f16_4:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
-; CHECK-NEXT:    addvl sp, sp, #-1
-; CHECK-NEXT:    .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
-; CHECK-NEXT:    .cfi_offset w29, -16
-; CHECK-NEXT:    ptrue p0.d
-; CHECK-NEXT:    ptrue p1.h
-; CHECK-NEXT:    st1h { z0.d }, p0, [sp]
-; CHECK-NEXT:    ld1h { z0.h }, p1/z, [sp]
+; CHECK-NEXT:    uzp1 z0.h, z0.h, z0.h
+; CHECK-NEXT:    uzp1 z0.h, z0.h, z0.h
 ; CHECK-NEXT:    ext z0.b, z0.b, z0.b, #8
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
-; CHECK-NEXT:    addvl sp, sp, #1
-; CHECK-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
 ; CHECK-NEXT:    ret
   %ext = call <4 x half> @llvm.vector.extract.v4f16.nxv2f16(<vscale x 2 x half> %arg, i64 4)
   ret <4 x half> %ext