Skip to content

Fix for logic in combineExtract() #108208

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 25, 2024
Merged

Conversation

JonPsson1
Copy link
Contributor

@JonPsson1 JonPsson1 commented Sep 11, 2024

A (csmith) test case appeared where combineExtract() crashed when the input vector was a bitcast into a vector of i1:s. Add a check using canTreatAsByteVector() for the immediate (first) Op as well. This takes care of (avoids) this case and does not seem to change any benchmarks or tests.

In this i1 vector case the logic gets confused: 'End' becomes 128 (bytes), so the 'Op.getOperand(End / OpBytesPerElement - 1)' call uses an argument of 15, but Op is 'v2i64 = BUILD_VECTOR Constant:i64<3>, Constant:i64<3>'.

@llvmbot
Copy link
Member

llvmbot commented Sep 11, 2024

@llvm/pr-subscribers-backend-systemz

Author: Jonas Paulsson (JonPsson1)

Changes

A (csmith) test case appeared where combineExtract() crashed when the input vector was a bitcast into a vector of i1:s. Add a check using canTreatAsByteVector() for the immediate (first) Op as well. This takes care of (avoids) this case and does not seem to change any benchmarks or tests.

I am not sure how combineExtract() is supposed to work with various theoretical vectors like <3 x i24> or similar, considering the use of getStoreSize() of the vector element type. A vector element when part of a vector would have the same store size as the element size, but when extracted and used as a scalar it would become the next bigger legal integer type, or?

I guess I am confused about the use of getStoreSize() on vector elements: an i16 and i32 element would have the same store size, so I think it's a little weird how the computations works when BITCASTing between i16 and i32 vectors.

It is clear that in the i1 vector case the logic gets confused: 'End' becomes 128 (bytes), so the 'Op.getOperand(End / OpBytesPerElement - 1)' call uses an argument of 15, but Op is 'v2i64 = BUILD_VECTOR Constant:i64<3>, Constant:i64<3>'.


Full diff: https://github.com/llvm/llvm-project/pull/108208.diff

2 Files Affected:

  • (modified) llvm/lib/Target/SystemZ/SystemZISelLowering.cpp (+3-6)
  • (added) llvm/test/CodeGen/SystemZ/DAGCombine_extract_vector_elt.ll (+20)
diff --git a/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp b/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp
index 582a8c139b2937..bcb09e59ffb0c8 100644
--- a/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp
+++ b/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp
@@ -6569,13 +6569,12 @@ SDValue SystemZTargetLowering::combineExtract(const SDLoc &DL, EVT ResVT,
   // The number of bytes being extracted.
   unsigned BytesPerElement = VecVT.getVectorElementType().getStoreSize();
 
-  for (;;) {
+  while (canTreatAsByteVector(Op.getValueType())) {
     unsigned Opcode = Op.getOpcode();
     if (Opcode == ISD::BITCAST)
       // Look through bitcasts.
       Op = Op.getOperand(0);
-    else if ((Opcode == ISD::VECTOR_SHUFFLE || Opcode == SystemZISD::SPLAT) &&
-             canTreatAsByteVector(Op.getValueType())) {
+    else if (Opcode == ISD::VECTOR_SHUFFLE || Opcode == SystemZISD::SPLAT) {
       // Get a VPERM-like permute mask and see whether the bytes covered
       // by the extracted element are a contiguous sequence from one
       // source operand.
@@ -6597,8 +6596,7 @@ SDValue SystemZTargetLowering::combineExtract(const SDLoc &DL, EVT ResVT,
       Index = Byte / BytesPerElement;
       Op = Op.getOperand(unsigned(First) / Bytes.size());
       Force = true;
-    } else if (Opcode == ISD::BUILD_VECTOR &&
-               canTreatAsByteVector(Op.getValueType())) {
+    } else if (Opcode == ISD::BUILD_VECTOR) {
       // We can only optimize this case if the BUILD_VECTOR elements are
       // at least as wide as the extracted value.
       EVT OpVT = Op.getValueType();
@@ -6627,7 +6625,6 @@ SDValue SystemZTargetLowering::combineExtract(const SDLoc &DL, EVT ResVT,
     } else if ((Opcode == ISD::SIGN_EXTEND_VECTOR_INREG ||
                 Opcode == ISD::ZERO_EXTEND_VECTOR_INREG ||
                 Opcode == ISD::ANY_EXTEND_VECTOR_INREG) &&
-               canTreatAsByteVector(Op.getValueType()) &&
                canTreatAsByteVector(Op.getOperand(0).getValueType())) {
       // Make sure that only the unextended bits are significant.
       EVT ExtVT = Op.getValueType();
diff --git a/llvm/test/CodeGen/SystemZ/DAGCombine_extract_vector_elt.ll b/llvm/test/CodeGen/SystemZ/DAGCombine_extract_vector_elt.ll
new file mode 100644
index 00000000000000..d568af47dbafd0
--- /dev/null
+++ b/llvm/test/CodeGen/SystemZ/DAGCombine_extract_vector_elt.ll
@@ -0,0 +1,20 @@
+; RUN: llc -mtriple=s390x-linux-gnu -mcpu=z16 < %s  | FileCheck %s
+;
+; Check that DAGCombiner doesn't crash in SystemZ combineExtract()
+; when handling EXTRACT_VECTOR_ELT with a vector of i1:s.
+
+define i32 @fun(i32 %arg) {
+; CHECK-LABEL: fun:
+entry:
+  %cc = icmp eq i32 %arg, 0
+  br label %loop
+
+loop:
+  %P = phi <128 x i1> [ zeroinitializer, %entry ], [ bitcast (<2 x i64> <i64 3, i64 3> to <128 x i1>), %loop ]
+  br i1 %cc, label %exit, label %loop
+
+exit:
+  %E = extractelement <128 x i1> %P, i64 0
+  %Res = zext i1 %E to i32
+  ret i32 %Res
+}

@uweigand
Copy link
Member

uweigand commented Sep 19, 2024

Hmm, I think the intent was for the type of the initial Op to be deliberately ignored, as it is supposed to be considered implicitly bitcast to VecVT. But it is true that the routine assumes that VecVT must be a byte vector. One of the two call sites (in combineTruncateExtract) ensures that, but the other call site (in combineEXTRACT_VECTOR_ELT) does not. So I think the better fix would be to add a check to that call site instead.

@JonPsson1
Copy link
Contributor Author

I guess I am a little confused here as to the lack of the bytevector check for BITCAST inside combineExtract(). Couldn't it at least in theory be that - as this is in a loop - there are more BITCASTs reached which could then involve i1 vectors?

@uweigand
Copy link
Member

I still don't think we need a check for the BITCAST case; it shouldn't matter what the resulting type is. The important thing is that

  • wherever the type of Op is actually used, it is verified to be a byte type (this is already done); and
  • the desired output type VecVT (including the Index which is based on VecVT) is a byte type - this is where we have a problem as it is only validated by one caller but not the other

In your test case, the cause of the crash is that Index (which comes from the caller) is based on a i1 vector type in VecVT - this is the actual bug. Once we fix the caller, that problem should be resolved.

@JonPsson1
Copy link
Contributor Author

ok, I think I get it now: there is a special check for this in each of the different handlings that would follow a BITCAST handling, so that should not be a problem. Patch updated per your suggestion.

Copy link
Member

@uweigand uweigand left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM now

@JonPsson1 JonPsson1 merged commit 0ef24aa into llvm:main Sep 25, 2024
8 checks passed
@JonPsson1 JonPsson1 deleted the CombineExtract branch September 25, 2024 10:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants