[AArch64] Add tablegen patterns for fmla index with extract 0. #114976


Merged
merged 1 commit into from
Nov 8, 2024

Conversation

davemgreen (Collaborator) commented Nov 5, 2024

We have tablegen patterns to produce an indexed `fmla s0, s1, v2.s[2]` from
  `fma extract(Rn, lane), Rm, Ra -> fmla`
But for the case of lane==0, we want to prefer the simple `fmadd s0, s1, s2`, so we have patterns for
  `fma extract(Rn, 0), Rm, Ra -> fmadd`

The problem arises when we have two extracts, since tablegen starts to prefer the second pattern, as it looks more specialized. This patch adds additional patterns to catch this case:
  `fma extract(Rn, index), extract(Rm, 0), Ra -> fmla`
To make sure the simpler fmadd keeps being selected when both lanes are extracted from lane 0, we need to add patterns for that case too:
  `fma extract(Rn, 0), extract(Rm, 0), Ra -> fmadd`
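The two-extract case can be seen in the PR's own `neon-scalar-by-elem-fma.ll` test. A minimal sketch of the IR shape, adapted from the `test_fmla_ss4S_3_ext0` test below (the `llvm.fma` call is reconstructed from the extracts shown in the diff, as the diff truncates the function body):

```llvm
; One operand extracted from lane 0, the other from lane 3.
; Before this patch the lane-0 extract made tablegen pick the
; fmadd pattern, needing an extra mov to materialise lane 3:
;   mov   s2, v1.s[3]
;   fmadd s0, s1, s2, s0
; With the new patterns the indexed form is selected instead:
;   fmla  s0, s1, v1.s[3]
define float @test_fmla_ss4S_3_ext0(float %a, <4 x float> %v) {
  %tmp0 = extractelement <4 x float> %v, i32 0
  %tmp1 = extractelement <4 x float> %v, i32 3
  %fma = call float @llvm.fma.f32(float %tmp0, float %tmp1, float %a)
  ret float %fma
}
declare float @llvm.fma.f32(float, float, float)
```

When both extracts are from lane 0, the new fmadd-with-two-extracts patterns keep the scalar `fmadd` selected, so no regression is introduced for that case.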

llvmbot (Member) commented Nov 5, 2024

@llvm/pr-subscribers-backend-aarch64

Author: David Green (davemgreen)

Changes



Full diff: https://github.com/llvm/llvm-project/pull/114976.diff

4 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64InstrFormats.td (+36)
  • (modified) llvm/test/CodeGen/AArch64/complex-deinterleaving-f16-mul.ll (+4-4)
  • (modified) llvm/test/CodeGen/AArch64/fp16_intrinsic_lane.ll (+2-4)
  • (modified) llvm/test/CodeGen/AArch64/neon-scalar-by-elem-fma.ll (+8-16)
diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
index e44caef686be29..b5f6388ea00285 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -5821,6 +5821,13 @@ multiclass ThreeOperandFPData<bit isNegated, bit isSub,string asm,
                        (f16 FPR16:$Ra))),
             (!cast<Instruction>(NAME # Hrrr)
               (f16 (EXTRACT_SUBREG V128:$Rn, hsub)), FPR16:$Rm, FPR16:$Ra)>;
+
+  def : Pat<(f16 (node (f16 (extractelt (v8f16 V128:$Rn), (i64 0))),
+                       (f16 (extractelt (v8f16 V128:$Rm), (i64 0))),
+                       (f16 FPR16:$Ra))),
+            (!cast<Instruction>(NAME # Hrrr)
+              (f16 (EXTRACT_SUBREG V128:$Rn, hsub)),
+              (f16 (EXTRACT_SUBREG V128:$Rm, hsub)), FPR16:$Ra)>;
   }
 
   def : Pat<(f32 (node (f32 FPR32:$Rn),
@@ -5835,6 +5842,13 @@ multiclass ThreeOperandFPData<bit isNegated, bit isSub,string asm,
             (!cast<Instruction>(NAME # Srrr)
               (EXTRACT_SUBREG V128:$Rn, ssub), FPR32:$Rm, FPR32:$Ra)>;
 
+  def : Pat<(f32 (node (f32 (extractelt (v4f32 V128:$Rn), (i64 0))),
+                       (f32 (extractelt (v4f32 V128:$Rm), (i64 0))),
+                       (f32 FPR32:$Ra))),
+            (!cast<Instruction>(NAME # Srrr)
+              (EXTRACT_SUBREG V128:$Rn, ssub),
+              (EXTRACT_SUBREG V128:$Rm, ssub), FPR32:$Ra)>;
+
   def : Pat<(f64 (node (f64 FPR64:$Rn),
                        (f64 (extractelt (v2f64 V128:$Rm), (i64 0))),
                        (f64 FPR64:$Ra))),
@@ -5846,6 +5860,13 @@ multiclass ThreeOperandFPData<bit isNegated, bit isSub,string asm,
                        (f64 FPR64:$Ra))),
             (!cast<Instruction>(NAME # Drrr)
               (EXTRACT_SUBREG V128:$Rn, dsub), FPR64:$Rm, FPR64:$Ra)>;
+
+  def : Pat<(f64 (node (f64 (extractelt (v2f64 V128:$Rn), (i64 0))),
+                       (f64 (extractelt (v2f64 V128:$Rm), (i64 0))),
+                       (f64 FPR64:$Ra))),
+            (!cast<Instruction>(NAME # Drrr)
+              (EXTRACT_SUBREG V128:$Rn, dsub),
+              (EXTRACT_SUBREG V128:$Rm, dsub), FPR64:$Ra)>;
 }
 
 //---
@@ -9282,6 +9303,11 @@ multiclass SIMDFPIndexedTiedPatterns<string INST, SDPatternOperator OpNode> {
                          (vector_extract (v8f16 V128_lo:$Rm), VectorIndexH:$idx))),
             (!cast<Instruction>(INST # "v1i16_indexed") FPR16:$Rd, FPR16:$Rn,
                 V128_lo:$Rm, VectorIndexH:$idx)>;
+  def : Pat<(f16 (OpNode (f16 FPR16:$Rd),
+                         (vector_extract (v8f16 V128:$Rn), (i64 0)),
+                         (vector_extract (v8f16 V128_lo:$Rm), VectorIndexH:$idx))),
+            (!cast<Instruction>(INST # "v1i16_indexed") FPR16:$Rd,
+                (f16 (EXTRACT_SUBREG V128:$Rn, hsub)), V128_lo:$Rm, VectorIndexH:$idx)>;
   } // Predicates = [HasNEON, HasFullFP16]
 
   // 2 variants for the .2s version: DUPLANE from 128-bit and DUP scalar.
@@ -9323,12 +9349,22 @@ multiclass SIMDFPIndexedTiedPatterns<string INST, SDPatternOperator OpNode> {
                          (vector_extract (v4f32 V128:$Rm), VectorIndexS:$idx))),
             (!cast<Instruction>(INST # "v1i32_indexed") FPR32:$Rd, FPR32:$Rn,
                 V128:$Rm, VectorIndexS:$idx)>;
+  def : Pat<(f32 (OpNode (f32 FPR32:$Rd),
+                         (vector_extract (v4f32 V128:$Rn), (i64 0)),
+                         (vector_extract (v4f32 V128:$Rm), VectorIndexS:$idx))),
+            (!cast<Instruction>(INST # "v1i32_indexed") FPR32:$Rd,
+                (f32 (EXTRACT_SUBREG V128:$Rn, ssub)), V128:$Rm, VectorIndexS:$idx)>;
 
   // 1 variant for 64-bit scalar version: extract from .1d or from .2d
   def : Pat<(f64 (OpNode (f64 FPR64:$Rd), (f64 FPR64:$Rn),
                          (vector_extract (v2f64 V128:$Rm), VectorIndexD:$idx))),
             (!cast<Instruction>(INST # "v1i64_indexed") FPR64:$Rd, FPR64:$Rn,
                 V128:$Rm, VectorIndexD:$idx)>;
+  def : Pat<(f64 (OpNode (f64 FPR64:$Rd),
+                         (vector_extract (v2f64 V128:$Rn), (i64 0)),
+                         (vector_extract (v2f64 V128:$Rm), VectorIndexD:$idx))),
+            (!cast<Instruction>(INST # "v1i64_indexed") FPR64:$Rd,
+                (f64 (EXTRACT_SUBREG V128:$Rn, dsub)), V128:$Rm, VectorIndexD:$idx)>;
 }
 
 let mayRaiseFPException = 1, Uses = [FPCR] in
diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-f16-mul.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-f16-mul.ll
index fbe913e5472cc2..afcdb76067f433 100644
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-f16-mul.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-f16-mul.ll
@@ -11,10 +11,10 @@ define <2 x half> @complex_mul_v2f16(<2 x half> %a, <2 x half> %b) {
 ; CHECK-NEXT:    mov h2, v0.h[1]
 ; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
 ; CHECK-NEXT:    fmul h3, h0, v1.h[1]
-; CHECK-NEXT:    fmul h4, h2, v1.h[1]
-; CHECK-NEXT:    fmadd h2, h1, h2, h3
-; CHECK-NEXT:    fnmsub h0, h1, h0, h4
-; CHECK-NEXT:    mov v0.h[1], v2.h[0]
+; CHECK-NEXT:    fmul h2, h2, v1.h[1]
+; CHECK-NEXT:    fmla h3, h1, v0.h[1]
+; CHECK-NEXT:    fnmsub h0, h1, h0, h2
+; CHECK-NEXT:    mov v0.h[1], v3.h[0]
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $q0
 ; CHECK-NEXT:    ret
 entry:
diff --git a/llvm/test/CodeGen/AArch64/fp16_intrinsic_lane.ll b/llvm/test/CodeGen/AArch64/fp16_intrinsic_lane.ll
index 725c44c9788988..368683e2b93af4 100644
--- a/llvm/test/CodeGen/AArch64/fp16_intrinsic_lane.ll
+++ b/llvm/test/CodeGen/AArch64/fp16_intrinsic_lane.ll
@@ -120,8 +120,7 @@ define half @t_vfmah_lane_f16_3_0(half %a, <4 x half> %c) {
 ; CHECK-LABEL: t_vfmah_lane_f16_3_0:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
-; CHECK-NEXT:    mov h2, v1.h[3]
-; CHECK-NEXT:    fmadd h0, h1, h2, h0
+; CHECK-NEXT:    fmla h0, h1, v1.h[3]
 ; CHECK-NEXT:    ret
 entry:
   %b = extractelement <4 x half> %c, i32 0
@@ -310,8 +309,7 @@ define half @t_vfmsh_lane_f16_0_3(half %a, <4 x half> %c, i32 %lane) {
 ; CHECK-LABEL: t_vfmsh_lane_f16_0_3:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
-; CHECK-NEXT:    mov h2, v1.h[3]
-; CHECK-NEXT:    fmsub h0, h2, h1, h0
+; CHECK-NEXT:    fmls h0, h1, v1.h[3]
 ; CHECK-NEXT:    ret
 entry:
   %b = extractelement <4 x half> %c, i32 0
diff --git a/llvm/test/CodeGen/AArch64/neon-scalar-by-elem-fma.ll b/llvm/test/CodeGen/AArch64/neon-scalar-by-elem-fma.ll
index b2ea6ff200be1d..544d7680f01b80 100644
--- a/llvm/test/CodeGen/AArch64/neon-scalar-by-elem-fma.ll
+++ b/llvm/test/CodeGen/AArch64/neon-scalar-by-elem-fma.ll
@@ -84,8 +84,7 @@ define float @test_fmla_ss2S_1(float %a, float %b, <2 x float> %v) {
 define float @test_fmla_ss4S_3_ext0(float %a, <4 x float> %v) {
 ; CHECK-LABEL: test_fmla_ss4S_3_ext0:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov s2, v1.s[3]
-; CHECK-NEXT:    fmadd s0, s1, s2, s0
+; CHECK-NEXT:    fmla s0, s1, v1.s[3]
 ; CHECK-NEXT:    ret
   %tmp0 = extractelement <4 x float> %v, i32 0
   %tmp1 = extractelement <4 x float> %v, i32 3
@@ -96,8 +95,7 @@ define float @test_fmla_ss4S_3_ext0(float %a, <4 x float> %v) {
 define float @test_fmla_ss4S_3_ext0_swp(float %a, <4 x float> %v) {
 ; CHECK-LABEL: test_fmla_ss4S_3_ext0_swp:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov s2, v1.s[3]
-; CHECK-NEXT:    fmadd s0, s2, s1, s0
+; CHECK-NEXT:    fmla s0, s1, v1.s[3]
 ; CHECK-NEXT:    ret
   %tmp0 = extractelement <4 x float> %v, i32 0
   %tmp1 = extractelement <4 x float> %v, i32 3
@@ -120,8 +118,7 @@ define float @test_fmla_ss2S_3_ext0(float %a, <2 x float> %v) {
 ; CHECK-LABEL: test_fmla_ss2S_3_ext0:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
-; CHECK-NEXT:    mov s2, v1.s[1]
-; CHECK-NEXT:    fmadd s0, s1, s2, s0
+; CHECK-NEXT:    fmla s0, s1, v1.s[1]
 ; CHECK-NEXT:    ret
   %tmp0 = extractelement <2 x float> %v, i32 0
   %tmp1 = extractelement <2 x float> %v, i32 1
@@ -133,8 +130,7 @@ define float @test_fmla_ss2S_3_ext0_swp(float %a, <2 x float> %v) {
 ; CHECK-LABEL: test_fmla_ss2S_3_ext0_swp:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
-; CHECK-NEXT:    mov s2, v1.s[1]
-; CHECK-NEXT:    fmadd s0, s2, s1, s0
+; CHECK-NEXT:    fmla s0, s1, v1.s[1]
 ; CHECK-NEXT:    ret
   %tmp0 = extractelement <2 x float> %v, i32 0
   %tmp1 = extractelement <2 x float> %v, i32 1
@@ -218,8 +214,7 @@ define double @test_fmla_dd2D_1_swap(double %a, double %b, <2 x double> %v) {
 define double @test_fmla_ss2D_1_ext0(double %a, <2 x double> %v) {
 ; CHECK-LABEL: test_fmla_ss2D_1_ext0:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov d2, v1.d[1]
-; CHECK-NEXT:    fmadd d0, d1, d2, d0
+; CHECK-NEXT:    fmla d0, d1, v1.d[1]
 ; CHECK-NEXT:    ret
   %tmp0 = extractelement <2 x double> %v, i32 0
   %tmp1 = extractelement <2 x double> %v, i32 1
@@ -230,8 +225,7 @@ define double @test_fmla_ss2D_1_ext0(double %a, <2 x double> %v) {
 define double @test_fmla_ss2D_1_ext0_swp(double %a, <2 x double> %v) {
 ; CHECK-LABEL: test_fmla_ss2D_1_ext0_swp:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov d2, v1.d[1]
-; CHECK-NEXT:    fmadd d0, d2, d1, d0
+; CHECK-NEXT:    fmla d0, d1, v1.d[1]
 ; CHECK-NEXT:    ret
   %tmp0 = extractelement <2 x double> %v, i32 0
   %tmp1 = extractelement <2 x double> %v, i32 1
@@ -340,8 +334,7 @@ define float @test_fmls_ss2S_1(float %a, float %b, <2 x float> %v) {
 define float @test_fmls_ss4S_3_ext0(float %a, <4 x float> %v) {
 ; CHECK-LABEL: test_fmls_ss4S_3_ext0:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov s2, v1.s[3]
-; CHECK-NEXT:    fmsub s0, s1, s2, s0
+; CHECK-NEXT:    fmls s0, s1, v1.s[3]
 ; CHECK-NEXT:    ret
   %tmp0 = extractelement <4 x float> %v, i32 0
   %tmp1 = extractelement <4 x float> %v, i32 3
@@ -437,8 +430,7 @@ define double @test_fmls_dd2D_1_swap(double %a, double %b, <2 x double> %v) {
 define double @test_fmls_dd2D_1_ext0(double %a, <2 x double> %v) {
 ; CHECK-LABEL: test_fmls_dd2D_1_ext0:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov d2, v1.d[1]
-; CHECK-NEXT:    fmsub d0, d1, d2, d0
+; CHECK-NEXT:    fmls d0, d1, v1.d[1]
 ; CHECK-NEXT:    ret
   %tmp0 = extractelement <2 x double> %v, i32 0
   %tmp1 = extractelement <2 x double> %v, i32 1

SamTebbs33 (Collaborator) left a comment


LGTM

@davemgreen davemgreen merged commit 92a9bcc into llvm:main Nov 8, 2024
8 of 10 checks passed
@davemgreen davemgreen deleted the gh-a64-fmlaext0patterns branch November 8, 2024 16:18
Groverkss pushed a commit to iree-org/llvm-project that referenced this pull request Nov 15, 2024