
[AArch64][SME] Extend FORM_TRANSPOSED pseudos to all multi-vector intrinsics #124258


Merged — 7 commits merged into llvm:main from form-transpose-expand on Feb 4, 2025

Conversation

@kmclaughlin-arm (Contributor) commented on Jan 24, 2025

All patterns for multi-vector intrinsics should try to use the FORM_TRANSPOSED_REG_TUPLE
pseudos so that they can benefit from register allocation hints when SME is available.

This patch removes the post-isel hook for the pseudo and instead extends the
SMEPeepholeOpt pass to replace a REG_SEQUENCE with the pseudo when the
expected pattern of StridedOrContiguous copies is found. With this change,
the TableGen patterns for the intrinsics can remain unchanged.

One test has been added for each multiclass this affects.
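
For readers unfamiliar with the pass, the sketch below illustrates the shape of the rewrite described above: a REG_SEQUENCE whose sources are copies out of a strided/contiguous load tuple is rebuilt as the FORM_TRANSPOSED_REG_TUPLE_X2/X4 pseudo so that register allocation hints can apply. The helper name, parameters, and the simplified legality check are assumptions for illustration only; the actual logic lives in SMEPeepholeOpt.

```cpp
// Hedged sketch only -- not the code from this patch. Match a REG_SEQUENCE of
// 2 or 4 vector sources that are each COPYs (in the real pass, copies following
// the StridedOrContiguous pattern) and rebuild it as a FORM_TRANSPOSED_REG_TUPLE
// pseudo so the register allocator can apply its hints.
#include "llvm/ADT/SmallVector.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetInstrInfo.h"

using namespace llvm;

static bool tryFormTransposedTuple(MachineInstr &MI, MachineRegisterInfo &MRI,
                                   const TargetInstrInfo &TII,
                                   unsigned TupleX2Opc, unsigned TupleX4Opc) {
  if (!MI.isRegSequence())
    return false;

  // REG_SEQUENCE operands after the def come in (register, subreg-index) pairs.
  unsigned NumSrcs = (MI.getNumOperands() - 1) / 2;
  if (NumSrcs != 2 && NumSrcs != 4)
    return false;

  SmallVector<Register, 4> Srcs;
  for (unsigned I = 1; I < MI.getNumOperands(); I += 2) {
    Register Src = MI.getOperand(I).getReg();
    // Simplified legality check: every source must be defined by a COPY. The
    // real pass additionally requires the copies to match the expected
    // StridedOrContiguous pattern before forming the tuple.
    MachineInstr *Def = MRI.getVRegDef(Src);
    if (!Def || !Def->isCopy())
      return false;
    Srcs.push_back(Src);
  }

  unsigned Opc = NumSrcs == 2 ? TupleX2Opc : TupleX4Opc;
  auto MIB = BuildMI(*MI.getParent(), MI, MI.getDebugLoc(), TII.get(Opc),
                     MI.getOperand(0).getReg());
  for (Register Src : Srcs)
    MIB.addReg(Src);
  MI.eraseFromParent();
  return true;
}
```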

@llvmbot (Member) commented on Jan 24, 2025

@llvm/pr-subscribers-backend-aarch64

Author: Kerry McLaughlin (kmclaughlin-arm)

Changes

All uses of REG_SEQUENCE by multiclasses contained in SMEInstrFormats.td
now use the FORM_TRANSPOSED_REG_TUPLE pseudos so they can benefit
from register allocation hints.
One test has been added for each multiclass changed.


Patch is 144.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/124258.diff

8 Files Affected:

  • (modified) llvm/lib/Target/AArch64/SMEInstrFormats.td (+12-12)
  • (modified) llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-cvt.ll (+133-13)
  • (modified) llvm/test/CodeGen/AArch64/sme2-intrinsics-add.ll (+78-61)
  • (modified) llvm/test/CodeGen/AArch64/sme2-intrinsics-fp-dots.ll (+92-33)
  • (modified) llvm/test/CodeGen/AArch64/sme2-intrinsics-insert-mova.ll (+153-145)
  • (modified) llvm/test/CodeGen/AArch64/sme2-intrinsics-int-dots.ll (+280)
  • (modified) llvm/test/CodeGen/AArch64/sme2-intrinsics-qcvt.ll (+108-1)
  • (modified) llvm/test/CodeGen/AArch64/sme2-intrinsics-qrshr.ll (+128-1)
diff --git a/llvm/lib/Target/AArch64/SMEInstrFormats.td b/llvm/lib/Target/AArch64/SMEInstrFormats.td
index a01d59d0e5c43d..01c3d3af5f2e92 100644
--- a/llvm/lib/Target/AArch64/SMEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SMEInstrFormats.td
@@ -177,15 +177,15 @@ class SME2_ZA_TwoOp_VG4_Multi_Single_Pat<string name, SDPatternOperator intrinsi
 class SME2_ZA_TwoOp_VG2_Multi_Multi_Pat<string name, SDPatternOperator intrinsic, Operand index_ty, ValueType vt, ComplexPattern tileslice>
     : Pat<(intrinsic (i32 (tileslice MatrixIndexGPR32Op8_11:$base, index_ty:$offset)), vt:$Zn1, vt:$Zn2, vt:$Zm1, vt:$Zm2),
           (!cast<Instruction>(name # _PSEUDO) $base, $offset,
-                                              (REG_SEQUENCE ZPR2Mul2, vt:$Zn1, zsub0, vt:$Zn2, zsub1),
-                                              (REG_SEQUENCE ZPR2Mul2, vt:$Zm1, zsub0, vt:$Zm2, zsub1))>;
+                                              (FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO vt:$Zn1, vt:$Zn2),
+                                              (FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO vt:$Zm1, vt:$Zm2))>;
 
 class SME2_ZA_TwoOp_VG4_Multi_Multi_Pat<string name, SDPatternOperator intrinsic, Operand index_ty, ValueType vt, ComplexPattern tileslice>
     : Pat<(intrinsic (i32 (tileslice MatrixIndexGPR32Op8_11:$base, index_ty:$offset)),
                      vt:$Zn1, vt:$Zn2, vt:$Zn3, vt:$Zn4, vt:$Zm1, vt:$Zm2, vt:$Zm3, vt:$Zm4),
           (!cast<Instruction>(name # _PSEUDO) $base, $offset,
-                                              (REG_SEQUENCE ZPR4Mul4, vt:$Zn1, zsub0, vt:$Zn2, zsub1, vt:$Zn3, zsub2, vt:$Zn4, zsub3),
-                                              (REG_SEQUENCE ZPR4Mul4, vt:$Zm1, zsub0, vt:$Zm2, zsub1, vt:$Zm3, zsub2, vt:$Zm4, zsub3))>;
+                                              (FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO vt:$Zn1, vt:$Zn2, vt:$Zn3, vt:$Zn4),
+                                              (FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO vt:$Zm1, vt:$Zm2, vt:$Zm3, vt:$Zm4))>;
 
 class SME2_ZA_TwoOp_Multi_Index_Pat<string name, SDPatternOperator intrinsic, Operand index_ty, ZPRRegOp zpr_ty, ValueType vt,
                                     Operand imm_ty, ComplexPattern tileslice>
@@ -209,32 +209,32 @@ class SME2_ZA_TwoOp_VG4_Multi_Index_Pat<string name, SDPatternOperator intrinsic
 
 class SME2_Sat_Shift_VG2_Pat<string name, SDPatternOperator intrinsic, ValueType out_vt, ValueType in_vt, Operand imm_ty>
     : Pat<(out_vt (intrinsic in_vt:$Zn1, in_vt:$Zn2, (i32 imm_ty:$i))),
-                  (!cast<Instruction>(name) (REG_SEQUENCE ZPR2Mul2, in_vt:$Zn1, zsub0, in_vt:$Zn2, zsub1), imm_ty:$i)>;
+                  (!cast<Instruction>(name) (FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO in_vt:$Zn1, in_vt:$Zn2), imm_ty:$i)>;
 
 class SME2_Sat_Shift_VG4_Pat<string name, SDPatternOperator intrinsic, ValueType out_vt, ValueType in_vt, Operand imm_ty>
     : Pat<(out_vt (intrinsic in_vt:$Zn1, in_vt:$Zn2, in_vt:$Zn3, in_vt:$Zn4, (i32 imm_ty:$i))),
-                  (!cast<Instruction>(name) (REG_SEQUENCE ZPR4Mul4, in_vt:$Zn1, zsub0, in_vt:$Zn2, zsub1, in_vt:$Zn3, zsub2, in_vt:$Zn4, zsub3),
+                  (!cast<Instruction>(name) (FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO in_vt:$Zn1, in_vt:$Zn2, in_vt:$Zn3, in_vt:$Zn4),
                                             imm_ty:$i)>;
 
 class SME2_Cvt_VG4_Pat<string name, SDPatternOperator intrinsic, ValueType out_vt, ValueType in_vt>
     : Pat<(out_vt (intrinsic in_vt:$Zn1, in_vt:$Zn2, in_vt:$Zn3, in_vt:$Zn4)),
-                  (!cast<Instruction>(name) (REG_SEQUENCE ZPR4Mul4, in_vt:$Zn1, zsub0, in_vt:$Zn2, zsub1, in_vt:$Zn3, zsub2, in_vt:$Zn4, zsub3))>;
+                  (!cast<Instruction>(name) (FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO in_vt:$Zn1, in_vt:$Zn2, in_vt:$Zn3, in_vt:$Zn4))>;
 
 class SME2_ZA_VG1x2_Multi_Pat<string name, SDPatternOperator intrinsic, ValueType vt, Operand index_ty, ComplexPattern tileslice>
     : Pat<(intrinsic (i32 (tileslice MatrixIndexGPR32Op8_11:$base, index_ty:$offset)), vt:$Zn1, vt:$Zn2),
-          (!cast<Instruction>(name # _PSEUDO) $base, $offset, (REG_SEQUENCE ZPR2Mul2, vt:$Zn1, zsub0, vt:$Zn2, zsub1))>;
+          (!cast<Instruction>(name # _PSEUDO) $base, $offset, (FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO vt:$Zn1, vt:$Zn2))>;
 
 class SME2_ZA_VG1x4_Multi_Pat<string name, SDPatternOperator intrinsic, ValueType vt, Operand index_ty, ComplexPattern tileslice>
     : Pat<(intrinsic (i32 (tileslice MatrixIndexGPR32Op8_11:$base, index_ty:$offset)), vt:$Zn1, vt:$Zn2, vt:$Zn3, vt:$Zn4),
-          (!cast<Instruction>(name # _PSEUDO) $base, $offset, (REG_SEQUENCE ZPR4Mul4, vt:$Zn1, zsub0, vt:$Zn2, zsub1, vt:$Zn3, zsub2, vt:$Zn4, zsub3))>;
+          (!cast<Instruction>(name # _PSEUDO) $base, $offset, (FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO vt:$Zn1, vt:$Zn2, vt:$Zn3, vt:$Zn4))>;
 
 class SME2_Tile_VG2_Multi_Pat<string name, SDPatternOperator intrinsic, Operand tile_imm, ValueType vt, Operand index_ty, ComplexPattern tileslice>
     : Pat<(intrinsic tile_imm:$tile, (i32 (tileslice MatrixIndexGPR32Op12_15:$base, index_ty:$offset)), vt:$Zn1, vt:$Zn2),
-          (!cast<Instruction>(name # _PSEUDO) $tile, $base, $offset, (REG_SEQUENCE ZPR2Mul2, vt:$Zn1, zsub0, vt:$Zn2, zsub1))>;
+          (!cast<Instruction>(name # _PSEUDO) $tile, $base, $offset, (FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO vt:$Zn1, vt:$Zn2))>;
 
 class SME2_Tile_VG4_Multi_Pat<string name, SDPatternOperator intrinsic, Operand tile_imm, ValueType vt, Operand index_ty, ComplexPattern tileslice>
     : Pat<(intrinsic tile_imm:$tile, (i32 (tileslice MatrixIndexGPR32Op12_15:$base, index_ty:$offset)), vt:$Zn1, vt:$Zn2, vt:$Zn3, vt:$Zn4),
-          (!cast<Instruction>(name # _PSEUDO) $tile, $base, $offset, (REG_SEQUENCE ZPR4Mul4, vt:$Zn1, zsub0, vt:$Zn2, zsub1, vt:$Zn3, zsub2, vt:$Zn4, zsub3))>;
+          (!cast<Instruction>(name # _PSEUDO) $tile, $base, $offset, (FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO vt:$Zn1, vt:$Zn2, vt:$Zn3, vt:$Zn4))>;
 
 class SME2_Zero_Matrix_Pat<string name, SDPatternOperator intrinsic, Operand offset_ty, ComplexPattern tileslice>
     : Pat<(intrinsic (i32 (tileslice MatrixIndexGPR32Op8_11:$base, offset_ty:$offset))),
@@ -2446,7 +2446,7 @@ multiclass sme2_fp8_cvt_vg2_single<string mnemonic, bit op, ValueType in_vt, SDP
     let Uses = [FPMR, FPCR];
   }
   def : Pat<(nxv16i8 (intrinsic in_vt:$Zn1, in_vt:$Zn2)),
-            (!cast<Instruction>(NAME) (REG_SEQUENCE ZPR2Mul2, in_vt:$Zn1, zsub0, in_vt:$Zn2, zsub1))>;
+            (!cast<Instruction>(NAME) (FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO in_vt:$Zn1, in_vt:$Zn2))>;
 }
 
 class sme2_cvt_unpk_vector_vg2<bits<2>sz, bits<3> op, bit u, RegisterOperand first_ty,
diff --git a/llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-cvt.ll b/llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-cvt.ll
index 38d3bed2eaf907..dc71773140a2cc 100644
--- a/llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-cvt.ll
+++ b/llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-cvt.ll
@@ -1,13 +1,11 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 2
-; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme2,+fp8 -verify-machineinstrs -force-streaming < %s | FileCheck %s
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme2,+fp8 -verify-machineinstrs -force-streaming -enable-subreg-liveness < %s | FileCheck %s
 
 ; FCVT / FCVTN / BFCVT
 
 define <vscale x 16 x i8> @fcvt_x2(<vscale x 8 x half> %zn0, <vscale x 8 x half> %zn1) {
 ; CHECK-LABEL: fcvt_x2:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    // kill: def $z1 killed $z1 killed $z0_z1 def $z0_z1
-; CHECK-NEXT:    // kill: def $z0 killed $z0 killed $z0_z1 def $z0_z1
 ; CHECK-NEXT:    fcvt z0.b, { z0.h, z1.h }
 ; CHECK-NEXT:    ret
   %res = call <vscale x 16 x i8> @llvm.aarch64.sve.fp8.cvt.x2.nxv8f16(<vscale x 8 x half> %zn0, <vscale x 8 x half> %zn1)
@@ -17,10 +15,6 @@ define <vscale x 16 x i8> @fcvt_x2(<vscale x 8 x half> %zn0, <vscale x 8 x half>
 define <vscale x 16 x i8> @fcvt_x4(<vscale x 4 x float> %zn0, <vscale x 4 x float> %zn1, <vscale x 4 x float> %zn2, <vscale x 4 x float> %zn3) {
 ; CHECK-LABEL: fcvt_x4:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    // kill: def $z3 killed $z3 killed $z0_z1_z2_z3 def $z0_z1_z2_z3
-; CHECK-NEXT:    // kill: def $z2 killed $z2 killed $z0_z1_z2_z3 def $z0_z1_z2_z3
-; CHECK-NEXT:    // kill: def $z1 killed $z1 killed $z0_z1_z2_z3 def $z0_z1_z2_z3
-; CHECK-NEXT:    // kill: def $z0 killed $z0 killed $z0_z1_z2_z3 def $z0_z1_z2_z3
 ; CHECK-NEXT:    fcvt z0.b, { z0.s - z3.s }
 ; CHECK-NEXT:    ret
   %res = call <vscale x 16 x i8> @llvm.aarch64.sve.fp8.cvt.x4(<vscale x 4 x float> %zn0, <vscale x 4 x float> %zn1,
@@ -28,13 +22,100 @@ define <vscale x 16 x i8> @fcvt_x4(<vscale x 4 x float> %zn0, <vscale x 4 x floa
   ret <vscale x 16 x i8> %res
 }
 
+define { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } @fcvt_x4_tuple(i64 %stride, ptr %ptr) {
+; CHECK-LABEL: fcvt_x4_tuple:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    addvl sp, sp, #-9
+; CHECK-NEXT:    str p8, [sp, #7, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    str z15, [sp, #1, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z14, [sp, #2, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z13, [sp, #3, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z12, [sp, #4, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z11, [sp, #5, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z10, [sp, #6, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z9, [sp, #7, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z8, [sp, #8, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    .cfi_escape 0x0f, 0x0d, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0xc8, 0x00, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 72 * VG
+; CHECK-NEXT:    .cfi_offset w29, -16
+; CHECK-NEXT:    .cfi_escape 0x10, 0x48, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x78, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d8 @ cfa - 16 - 8 * VG
+; CHECK-NEXT:    .cfi_escape 0x10, 0x49, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x70, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d9 @ cfa - 16 - 16 * VG
+; CHECK-NEXT:    .cfi_escape 0x10, 0x4a, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x68, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d10 @ cfa - 16 - 24 * VG
+; CHECK-NEXT:    .cfi_escape 0x10, 0x4b, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x60, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d11 @ cfa - 16 - 32 * VG
+; CHECK-NEXT:    .cfi_escape 0x10, 0x4c, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x58, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d12 @ cfa - 16 - 40 * VG
+; CHECK-NEXT:    .cfi_escape 0x10, 0x4d, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x50, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d13 @ cfa - 16 - 48 * VG
+; CHECK-NEXT:    .cfi_escape 0x10, 0x4e, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x48, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d14 @ cfa - 16 - 56 * VG
+; CHECK-NEXT:    .cfi_escape 0x10, 0x4f, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x40, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d15 @ cfa - 16 - 64 * VG
+; CHECK-NEXT:    lsl x9, x0, #1
+; CHECK-NEXT:    ptrue pn8.b
+; CHECK-NEXT:    add x8, x1, x0
+; CHECK-NEXT:    ld1w { z0.s, z4.s, z8.s, z12.s }, pn8/z, [x1]
+; CHECK-NEXT:    ld1w { z1.s, z5.s, z9.s, z13.s }, pn8/z, [x8]
+; CHECK-NEXT:    add x10, x1, x9
+; CHECK-NEXT:    add x8, x8, x9
+; CHECK-NEXT:    ld1w { z2.s, z6.s, z10.s, z14.s }, pn8/z, [x10]
+; CHECK-NEXT:    ld1w { z3.s, z7.s, z11.s, z15.s }, pn8/z, [x8]
+; CHECK-NEXT:    mov z24.d, z8.d
+; CHECK-NEXT:    mov z25.d, z5.d
+; CHECK-NEXT:    mov z26.d, z10.d
+; CHECK-NEXT:    mov z27.d, z11.d
+; CHECK-NEXT:    fcvt z0.b, { z0.s - z3.s }
+; CHECK-NEXT:    fcvt z1.b, { z4.s - z7.s }
+; CHECK-NEXT:    fcvt z2.b, { z24.s - z27.s }
+; CHECK-NEXT:    fcvt z3.b, { z12.s - z15.s }
+; CHECK-NEXT:    ldr z15, [sp, #1, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z14, [sp, #2, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z13, [sp, #3, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z12, [sp, #4, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z11, [sp, #5, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z10, [sp, #6, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z9, [sp, #7, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z8, [sp, #8, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr p8, [sp, #7, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    addvl sp, sp, #9
+; CHECK-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+entry:
+  %0 = tail call target("aarch64.svcount") @llvm.aarch64.sve.ptrue.c8()
+  %1 = tail call { <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> } @llvm.aarch64.sve.ld1.pn.x4.nxv4f32(target("aarch64.svcount") %0, ptr %ptr)
+  %2 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> } %1, 0
+  %3 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> } %1, 1
+  %4 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> } %1, 2
+  %5 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> } %1, 3
+  %arrayidx2 = getelementptr inbounds i8, ptr %ptr, i64 %stride
+  %6 = tail call { <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> } @llvm.aarch64.sve.ld1.pn.x4.nxv4f32(target("aarch64.svcount") %0, ptr %arrayidx2)
+  %7 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> } %6, 0
+  %8 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> } %6, 1
+  %9 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> } %6, 2
+  %10 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> } %6, 3
+  %mul3 = shl i64 %stride, 1
+  %arrayidx4 = getelementptr inbounds i8, ptr %ptr, i64 %mul3
+  %11 = tail call { <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> } @llvm.aarch64.sve.ld1.pn.x4.nxv4f32(target("aarch64.svcount") %0, ptr %arrayidx4)
+  %12 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> } %11, 0
+  %13 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> } %11, 1
+  %14 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> } %11, 2
+  %15 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> } %11, 3
+  %mul5 = mul i64 %stride, 3
+  %arrayidx6 = getelementptr inbounds i8, ptr %ptr, i64 %mul5
+  %16 = tail call { <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> } @llvm.aarch64.sve.ld1.pn.x4.nxv4f32(target("aarch64.svcount") %0, ptr %arrayidx6)
+  %17 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> } %16, 0
+  %18 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> } %16, 1
+  %19 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> } %16, 2
+  %20 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> } %16, 3
+  %res1 = call <vscale x 16 x i8> @llvm.aarch64.sve.fp8.cvt.x4(<vscale x 4 x float> %2, <vscale x 4 x float> %7, <vscale x 4 x float> %12, <vscale x 4 x float> %17)
+  %res2 = call <vscale x 16 x i8> @llvm.aarch64.sve.fp8.cvt.x4(<vscale x 4 x float> %3, <vscale x 4 x float> %8, <vscale x 4 x float> %13, <vscale x 4 x float> %18)
+  %res3 = call <vscale x 16 x i8> @llvm.aarch64.sve.fp8.cvt.x4(<vscale x 4 x float> %4, <vscale x 4 x float> %8, <vscale x 4 x float> %14, <vscale x 4 x float> %19)
+  %res4 = call <vscale x 16 x i8> @llvm.aarch64.sve.fp8.cvt.x4(<vscale x 4 x float> %5, <vscale x 4 x float> %10, <vscale x 4 x float> %15, <vscale x 4 x float> %20)
+  %ins1 = insertvalue { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } poison, <vscale x 16 x i8> %res1, 0
+  %ins2 = insertvalue { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } %ins1, <vscale x 16 x i8> %res2, 1
+  %ins3 = insertvalue { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } %ins2, <vscale x 16 x i8> %res3, 2
+  %ins4 = insertvalue { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } %ins3, <vscale x 16 x i8> %res4, 3
+  ret { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } %ins4
+}
+
 define <vscale x 16 x i8> @fcvtn(<vscale x 4 x float> %zn0, <vscale x 4 x float> %zn1, <vscale x 4 x float> %zn2, <vscale x 4 x float> %zn3) {
 ; CHECK-LABEL: fcvtn:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    // kill: def $z3 killed $z3 killed $z0_z1_z2_z3 def $z0_z1_z2_z3
-; CHECK-NEXT:    // kill: def $z2 killed $z2 killed $z0_z1_z2_z3 def $z0_z1_z2_z3
-; CHECK-NEXT:    // kill: def $z1 killed $z1 killed $z0_z1_z2_z3 def $z0_z1_z2_z3
-; CHECK-NEXT:    // kill: def $z0 killed $z0 killed $z0_z1_z2_z3 def $z0_z1_z2_z3
 ; CHECK-NEXT:    fcvtn z0.b, { z0.s - z3.s }
 ; CHECK-NEXT:    ret
   %res = call <vscale x 16 x i8> @llvm.aarch64.sve.fp8.cvtn.x4(<vscale x 4 x float> %zn0, <vscale x 4 x float> %zn1,
@@ -45,14 +126,53 @@ define <vscale x 16 x i8> @fcvtn(<vscale x 4 x float> %zn0, <vscale x 4 x float>
 define <vscale x 16 x i8> @bfcvt(<vscale x 8 x bfloat> %zn0, <vscale x 8 x bfloat> %zn1) {
 ; CHECK-LABEL: bfcvt:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    // kill: def $z1 killed $z1 killed $z0_z1 def $z0_z1
-; CHECK-NEXT:    // kill: def $z0 killed $z0 killed $z0_z1 def $z0_z1
 ; CHECK-NEXT:    bfcvt z0.b, { z0.h, z1.h }
 ; CHECK-NEXT:    ret
   %res = call <vscale x 16 x i8> @llvm.aarch64.sve.fp8.cvt.x2.nxv8bf16(<vscale x 8 x bfloat> %zn0, <vscale x 8 x bfloat> %zn1)
   ret <vscale x 16 x i8> %res
 }
 
+
+define { <vscale x 16 x i8>, <vscale x 16 x i8> } @bfcvt_tuple(i64 %stride, ptr %ptr) {
+; CHECK-LABEL: bfcvt_tuple:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    addvl sp, sp, #-3
+; CHECK-NEXT:    str p8, [sp, #7, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    str z9, [sp, #1, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z8, [sp, #2, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x18, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 24 * VG
+; CHECK-NEXT:    .cfi_offset w29, -16
+; CHECK-NEXT:    .cfi_escape 0x10, 0x48, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x78, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d8 @ cfa - 16 - 8 * VG
+; CHECK-NEXT:    .cfi_escape 0x10, 0x49, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x70, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d9 @ cfa - 16 - 16 * VG
+; CHECK-NEXT:    ptrue pn8.b
+; CHECK-NEXT:    add x8, x1, x0
+; CHECK-NEXT:    ld1h { z0.h, z8.h }, pn8/z, [x1]
+; CHECK-NEXT:    ld1h { z1.h, z9.h }, pn8/z, [x8]
+; CHECK-NEXT:    bfcvt z0.b, { z0.h, z1.h }
+; CHECK-NEXT:    bfcvt z1.b, { z8.h, z9.h }
+; CHECK-NEXT:    ldr z9, [sp, #1, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z8, [sp, #2, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr p8, [sp, #7, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    addvl sp, sp, #3
+; CHECK-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+entry:
+  %0 = tail call target("aarch64.svcount") @llvm.aarch64.sve.ptrue.c8()
+  %1 = tail call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sve.ld1.pn.x2.nxv8bf16(target("aarch64.svcount") %0, ptr %ptr)
+  %2 = extractvalue { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } %1, 0
+  %3 = extractvalue { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } %1, 1
+  %arrayidx2 = getelementptr inbounds i8, ptr %ptr, i64 %stride
+  %4 = tail call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sve.ld1.pn.x2.nxv8bf16(target("aarch64.svcount") %0, ptr %arrayidx2)
+  %5 = extractvalue { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } %4, 0
+  %6 = extractvalue { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } %4, 1
+  %res1 = call <vscale x 16 x i8> @llvm.aarch64.sve.fp8.cvt.x2.nxv8bf16(<vscale x 8 x bfloat> %2, <vscale x 8 x bfloat> %5)
+  %res2 = call <vscale x 16 x i8> @llvm.aarch64.sve.fp8.cvt.x2.nxv8bf16(<vscale x 8 x...
[truncated]

On Feb 3, 2025, @kmclaughlin-arm changed the title from "[AArch64][SME] Extend FORM_TRANSPOSED pseudos to more SME multi-vector intrinsics" to "[AArch64][SME] Extend FORM_TRANSPOSED pseudos to all multi-vector intrinsics".
@sdesmalen-arm (Collaborator) left a comment:

LGTM, with two nits.

Review comment on lines 267 to 268:

if (!MI.getOperand(I).isReg())
  return false;

sdesmalen-arm: I think this is guaranteed to be a reg?
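
One possible reading of this nit, sketched under an assumed surrounding context (this is not code from the patch): for a REG_SEQUENCE, the source operands at odd positions are registers by construction, so the check could become an assert rather than an early return.

```cpp
// Sketch of the suggestion (assumed context): operands 1, 3, 5, ... of a
// REG_SEQUENCE are source registers and 2, 4, 6, ... are subregister indices,
// so isReg() always holds for the odd positions and can be asserted.
assert(MI.isRegSequence() && "Expected a REG_SEQUENCE");
for (unsigned I = 1; I < MI.getNumOperands(); I += 2)
  assert(MI.getOperand(I).isReg() &&
         "REG_SEQUENCE source operands are always registers");
```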

Review comment on:

if (SubReg == MCRegister::NoRegister)
  SubReg = OpSubReg;

MachineOperand *CopySrcOp = MRI.getOneDef(CopySrc.getReg());

sdesmalen-arm: nit: getOneDef only returns a value when the MIR is in SSA form. There is an assert for this in runOnMachineFunction, but it may be worth adding one to this function too, in case that assert is ever accidentally removed.
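
A minimal sketch of where the suggested assert could live. The function signature is assumed for illustration (the commit notes mention a visitRegSequence helper, but its exact interface is not shown on this page):

```cpp
// Hedged sketch: guard getOneDef locally, since it only returns a single
// definition while the MIR is still in SSA form.
static bool visitRegSequence(MachineInstr &MI, MachineRegisterInfo &MRI) {
  assert(MRI.isSSA() && "Expected MIR to be in SSA form");
  // ... REG_SEQUENCE / StridedOrContiguous pattern matching continues here ...
  return false;
}
```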

@kmclaughlin-arm force-pushed the form-transpose-expand branch 2 times, most recently from afc9acc to 0bab43d on February 4, 2025 at 11:27
…r intrinsics

All uses of REG_SEQUENCE by multiclasses contained in SMEInstrFormats.td
now use the FORM_TRANSPOSED_REG_TUPLE pseudos so that they can benefit
from register allocation hints.
One test has been added for each multiclass changed.
- Added visitRegSequence to the AArch64MIPeepholeOpt pass to create the
  pseudo if a REG_SEQUENCE matches the pattern
- Removed uses of FORM_TRANSPOSED_REG_TUPLE from SME2 multiclasses
- Added tests for every other multiclass which can now use the pseudo
- Reworded comment above the FORM_TRANSPOSED_REG_TUPLE definitions
@kmclaughlin-arm merged commit 25daf7b into llvm:main on Feb 4, 2025
8 checks passed
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
…rinsics (llvm#124258)

All patterns for multi-vector intrinsics should try to use the FORM_TRANSPOSED
pseudos so that they can benefit from register allocation hints when SME is available.

This patch removes the post-isel hook for the pseudo and instead extends the
SMEPeepholeOpt pass to replace a REG_SEQUENCE with the pseudo if the
expected pattern of StridedOrContiguous copies is found. With this change,
the tablegen patterns for the intrinsics can remain unchanged.

One test has been added for each multiclass this affects.
@kmclaughlin-arm deleted the form-transpose-expand branch on May 27, 2025 at 10:22