Skip to content

[AArch64][InstCombine] Bail from combining SRAD on +/-1 divisor #109274

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 20, 2024

Conversation

MDevereau
Copy link
Contributor

This fixes a crash when svdiv's third parameter is svdup_s64(1)

This fixes a crash when svdiv's third parameter is svdup_s64(1)
@MDevereau
Copy link
Contributor Author

Related to #109155

@llvmbot
Copy link
Member

llvmbot commented Sep 19, 2024

@llvm/pr-subscribers-backend-aarch64

Author: Matthew Devereau (MDevereau)

Changes

This fixes a crash when svdiv's third parameter is svdup_s64(1)


Full diff: https://github.com/llvm/llvm-project/pull/109274.diff

2 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+3)
  • (modified) llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-sdiv.ll (+24)
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 649ba1ac8749f5..686503a794d7be 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -1972,7 +1972,10 @@ static std::optional<Instruction *> instCombineSVESDIV(InstCombiner &IC,
   ConstantInt *SplatConstantInt = dyn_cast_or_null<ConstantInt>(SplatValue);
   if (!SplatConstantInt)
     return std::nullopt;
+
   APInt Divisor = SplatConstantInt->getValue();
+  if (Divisor.abs().getZExtValue() == 1)
+    return std::nullopt;
 
   if (Divisor.isPowerOf2()) {
     Constant *DivisorLog2 = ConstantInt::get(Int32Ty, Divisor.logBase2());
diff --git a/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-sdiv.ll b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-sdiv.ll
index 68d871355e2a23..dc134ef109fe75 100644
--- a/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-sdiv.ll
+++ b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-sdiv.ll
@@ -68,6 +68,30 @@ define <vscale x 4 x i32> @sdiv_i32_not_zero(<vscale x 4 x i32> %a, <vscale x 4
   ret <vscale x 4 x i32> %out
 }
 
+; Don't instcombine to SRAD when the divisor is (+/-)1
+define <vscale x 2 x i64> @divide_by_1(<vscale x 16 x i1> %p, <vscale x 2 x i64> %a) #0 {
+; CHECK-LABEL: @divide_by_1(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1> [[P:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 2 x i64> @llvm.aarch64.sve.sdiv.nxv2i64(<vscale x 2 x i1> [[TMP1]], <vscale x 2 x i64> [[A:%.*]], <vscale x 2 x i64> shufflevector (<vscale x 2 x i64> insertelement (<vscale x 2 x i64> poison, i64 1, i64 0), <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 2 x i64> [[TMP2]]
+;
+  %1 = call <vscale x 2 x i64> @llvm.aarch64.sve.dup.x.nxv2i64(i64 1)
+  %2 = call <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1> %p)
+  %3 = call <vscale x 2 x i64> @llvm.aarch64.sve.sdiv.nxv2i64(<vscale x 2 x i1> %2, <vscale x 2 x i64> %a, <vscale x 2 x i64> %1)
+  ret <vscale x 2 x i64> %3
+}
+
+define <vscale x 2 x i64> @divide_by_m1(<vscale x 16 x i1> %p, <vscale x 2 x i64> %a) #0 {
+; CHECK-LABEL: @divide_by_m1(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1> [[P:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 2 x i64> @llvm.aarch64.sve.sdiv.nxv2i64(<vscale x 2 x i1> [[TMP1]], <vscale x 2 x i64> [[A:%.*]], <vscale x 2 x i64> shufflevector (<vscale x 2 x i64> insertelement (<vscale x 2 x i64> poison, i64 -1, i64 0), <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 2 x i64> [[TMP2]]
+;
+  %1 = call <vscale x 2 x i64> @llvm.aarch64.sve.dup.x.nxv2i64(i64 -1)
+  %2 = call <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1> %p)
+  %3 = call <vscale x 2 x i64> @llvm.aarch64.sve.sdiv.nxv2i64(<vscale x 2 x i1> %2, <vscale x 2 x i64> %a, <vscale x 2 x i64> %1)
+  ret <vscale x 2 x i64> %3
+}
 
 declare <vscale x 4 x i32> @llvm.aarch64.sve.sdiv.nxv4i32(<vscale x 4 x i1>, <vscale x 4 x i32>, <vscale x 4 x i32>)
 declare <vscale x 2 x i64> @llvm.aarch64.sve.sdiv.nxv2i64(<vscale x 2 x i1>, <vscale x 2 x i64>, <vscale x 2 x i64>)

@llvmbot
Copy link
Member

llvmbot commented Sep 19, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Matthew Devereau (MDevereau)

Changes

This fixes a crash when svdiv's third parameter is svdup_s64(1)


Full diff: https://github.com/llvm/llvm-project/pull/109274.diff

2 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+3)
  • (modified) llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-sdiv.ll (+24)
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 649ba1ac8749f5..686503a794d7be 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -1972,7 +1972,10 @@ static std::optional<Instruction *> instCombineSVESDIV(InstCombiner &IC,
   ConstantInt *SplatConstantInt = dyn_cast_or_null<ConstantInt>(SplatValue);
   if (!SplatConstantInt)
     return std::nullopt;
+
   APInt Divisor = SplatConstantInt->getValue();
+  if (Divisor.abs().getZExtValue() == 1)
+    return std::nullopt;
 
   if (Divisor.isPowerOf2()) {
     Constant *DivisorLog2 = ConstantInt::get(Int32Ty, Divisor.logBase2());
diff --git a/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-sdiv.ll b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-sdiv.ll
index 68d871355e2a23..dc134ef109fe75 100644
--- a/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-sdiv.ll
+++ b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-sdiv.ll
@@ -68,6 +68,30 @@ define <vscale x 4 x i32> @sdiv_i32_not_zero(<vscale x 4 x i32> %a, <vscale x 4
   ret <vscale x 4 x i32> %out
 }
 
+; Don't instcombine to SRAD when the divisor is (+/-)1
+define <vscale x 2 x i64> @divide_by_1(<vscale x 16 x i1> %p, <vscale x 2 x i64> %a) #0 {
+; CHECK-LABEL: @divide_by_1(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1> [[P:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 2 x i64> @llvm.aarch64.sve.sdiv.nxv2i64(<vscale x 2 x i1> [[TMP1]], <vscale x 2 x i64> [[A:%.*]], <vscale x 2 x i64> shufflevector (<vscale x 2 x i64> insertelement (<vscale x 2 x i64> poison, i64 1, i64 0), <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 2 x i64> [[TMP2]]
+;
+  %1 = call <vscale x 2 x i64> @llvm.aarch64.sve.dup.x.nxv2i64(i64 1)
+  %2 = call <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1> %p)
+  %3 = call <vscale x 2 x i64> @llvm.aarch64.sve.sdiv.nxv2i64(<vscale x 2 x i1> %2, <vscale x 2 x i64> %a, <vscale x 2 x i64> %1)
+  ret <vscale x 2 x i64> %3
+}
+
+define <vscale x 2 x i64> @divide_by_m1(<vscale x 16 x i1> %p, <vscale x 2 x i64> %a) #0 {
+; CHECK-LABEL: @divide_by_m1(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1> [[P:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 2 x i64> @llvm.aarch64.sve.sdiv.nxv2i64(<vscale x 2 x i1> [[TMP1]], <vscale x 2 x i64> [[A:%.*]], <vscale x 2 x i64> shufflevector (<vscale x 2 x i64> insertelement (<vscale x 2 x i64> poison, i64 -1, i64 0), <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 2 x i64> [[TMP2]]
+;
+  %1 = call <vscale x 2 x i64> @llvm.aarch64.sve.dup.x.nxv2i64(i64 -1)
+  %2 = call <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1> %p)
+  %3 = call <vscale x 2 x i64> @llvm.aarch64.sve.sdiv.nxv2i64(<vscale x 2 x i1> %2, <vscale x 2 x i64> %a, <vscale x 2 x i64> %1)
+  ret <vscale x 2 x i64> %3
+}
 
 declare <vscale x 4 x i32> @llvm.aarch64.sve.sdiv.nxv4i32(<vscale x 4 x i1>, <vscale x 4 x i32>, <vscale x 4 x i32>)
 declare <vscale x 2 x i64> @llvm.aarch64.sve.sdiv.nxv2i64(<vscale x 2 x i1>, <vscale x 2 x i64>, <vscale x 2 x i64>)

APInt Divisor = SplatConstantInt->getValue();
if (Divisor.abs().getZExtValue() == 1)
return std::nullopt;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the +1 case, would it be wrong to always return Vec here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that makes sense. I guess the predicate wouldn't matter in that case since it would be a no-op on the whole vector

@MDevereau MDevereau merged commit 1808fc1 into llvm:main Sep 20, 2024
8 checks passed
@MDevereau MDevereau deleted the srad branch September 20, 2024 12:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants