[NVPTX] Lower invalid ISD::ADDRSPACECAST #125607


Merged

justinfargnoli merged 4 commits into llvm:main on Feb 11, 2025

Conversation

justinfargnoli
Contributor

Avoid crashing when lowering addrspacecast ptr addrspace(<non-zero>) %ptr to ptr addrspace(<non-zero>).

@llvmbot
Member

llvmbot commented Feb 4, 2025

@llvm/pr-subscribers-backend-nvptx

Author: Justin Fargnoli (justinfargnoli)

Changes

Avoid crashing when lowering addrspacecast ptr addrspace(<non-zero>) %ptr to ptr addrspace(<non-zero>).


Full diff: https://github.com/llvm/llvm-project/pull/125607.diff

3 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+24)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+1)
  • (modified) llvm/test/CodeGen/NVPTX/addrspacecast.ll (+14)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 773c97f7b4dc0f..962c971e6970cd 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -989,6 +989,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
     setOperationAction(ISD::FLOG2, {MVT::v2f16, MVT::v2bf16}, Expand);
   }
 
+  setOperationAction(ISD::ADDRSPACECAST, {MVT::i32, MVT::i64}, Custom);
+
   // No FPOW or FREM in PTX.
 
   // Now deduce the information based on the above mentioned
@@ -2652,6 +2654,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return SDValue();
   case ISD::FRAMEADDR:
     return SDValue();
+  case ISD::ADDRSPACECAST:
+    return LowerADDRSPACECAST(Op, DAG);
   case ISD::GlobalAddress:
     return LowerGlobalAddress(Op, DAG);
   case ISD::INTRINSIC_W_CHAIN:
@@ -2767,6 +2771,26 @@ unsigned NVPTXTargetLowering::getJumpTableEncoding() const {
   return MachineJumpTableInfo::EK_Inline;
 }
 
+SDValue NVPTXTargetLowering::LowerADDRSPACECAST(SDValue Op,
+                                                SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  AddrSpaceCastSDNode *N = cast<AddrSpaceCastSDNode>(Op.getNode());
+
+  EVT OperandVT = Op.getOperand(0).getValueType();
+  unsigned SrcAS = N->getSrcAddressSpace();
+  EVT ResultVT = Op.getValueType();
+  unsigned DestAS = N->getDestAddressSpace();
+
+  if (SrcAS == llvm::ADDRESS_SPACE_GENERIC ||
+      DestAS == llvm::ADDRESS_SPACE_GENERIC)
+    return Op;
+
+  SDValue ToGeneric = DAG.getAddrSpaceCast(DL, OperandVT, Op.getOperand(0),
+                                           SrcAS, llvm::ADDRESS_SPACE_GENERIC);
+  return DAG.getAddrSpaceCast(DL, ResultVT, ToGeneric,
+                              llvm::ADDRESS_SPACE_GENERIC, DestAS);
+}
+
 // This function is almost a copy of SelectionDAG::expandVAArg().
 // The only diff is that this one produces loads from local address space.
 SDValue NVPTXTargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 5adf69d621552f..74ec14ba5f8e32 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -264,6 +264,7 @@ class NVPTXTargetLowering : public TargetLowering {
   const NVPTXSubtarget &STI; // cache the subtarget here
   SDValue getParamSymbol(SelectionDAG &DAG, int idx, EVT) const;
 
+  SDValue LowerADDRSPACECAST(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerBITCAST(SDValue Op, SelectionDAG &DAG) const;
 
   SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/test/CodeGen/NVPTX/addrspacecast.ll b/llvm/test/CodeGen/NVPTX/addrspacecast.ll
index 23428b3728674e..e3ebb2f458d46a 100644
--- a/llvm/test/CodeGen/NVPTX/addrspacecast.ll
+++ b/llvm/test/CodeGen/NVPTX/addrspacecast.ll
@@ -99,6 +99,20 @@ define i32 @conv8(ptr %ptr) {
   ret i32 %val
 }
 
+; ALL-LABEL: conv9
+define i32 @conv9(ptr addrspace(1) %ptr) {
+; CLS32: cvta.global.u32
+; CLS32: cvta.to.shared.u32
+; CLS64: cvta.global.u64
+; CLS64: cvta.to.shared.u64
+; PTRCONV: cvt.u32.u64
+; NOPTRCONV-NOT: cvt.u32.u64
+; ALL: ld.shared.u32
+  %specptr = addrspacecast ptr addrspace(1) %ptr to ptr addrspace(3)
+  %val = load i32, ptr addrspace(3) %specptr
+  ret i32 %val
+}
+
 ; Check that we support addrspacecast when splitting the vector
 ; result (<2 x ptr> => 2 x <1 x ptr>).
 ; This also checks that scalarization works for addrspacecast

@AlexMaclean
Member

Are there any cases where an addrspacecast like this and the PTX we're emitting after this change would be well defined?

; PTRCONV: cvt.u32.u64
; NOPTRCONV-NOT: cvt.u32.u64
; ALL: ld.shared.u32
%specptr = addrspacecast ptr addrspace(1) %ptr to ptr addrspace(3)
Member

I'm not convinced that allowing ASCs via generic AS for any AS combination is the right thing to do here.

While we technically can generate PTX that compiles, doing so when it's clearly an error is not a great choice, IMO. It pushes error detection from the compilation phase to runtime and substantially raises the cost of dealing with the consequences. While I agree that diagnosing by crashing is not a good user interface, not diagnosing the problem at all is worse.

I think incompatible ASC in IR should still be an error. In this particular case we have all the information we need to diagnose invalid ASC early on (IR validation pass on load, perhaps?) and may be able to fail with a somewhat more sensible diagnostic via llvm::report_fatal_error.
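To make the constraint concrete, here is a standalone C++ sketch of the validity check being discussed (not actual LLVM code; the address-space numbering follows NVPTX's convention, and `directlyLowerable` is a hypothetical name):

```cpp
// Standalone sketch, not LLVM code. Address-space numbers follow the
// NVPTX convention.
enum AddressSpace : unsigned {
  Generic = 0,
  Global = 1,
  Shared = 3,
  Const = 4,
  Local = 5
};

// PTX's cvta/cvta.to instructions only convert to and from the generic
// space, so a cast has a direct lowering only when one side is generic.
// Anything else must either be split into src -> generic -> dst (as this
// PR does) or be rejected up front, as suggested above.
bool directlyLowerable(unsigned SrcAS, unsigned DstAS) {
  return SrcAS == Generic || DstAS == Generic;
}
```

An early IR validation pass could walk every addrspacecast, apply a check like this, and call llvm::report_fatal_error with both address spaces in the message.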

Contributor

I disagree; I think we ought to handle invalid addrspacecasts as poison and stop treating them as a backend error. As it stands, it is possible to write an assume that introduces UB, resulting in a compiler error that depends on the optimization level, which is a bad property to have.

Member

Are you saying that if such an invalid ASC were placed in a conditional branch that may or may not be eliminated, it would result in a back-end crash if that branch was not eliminated by the optimizations?

If that's the case, then it's exactly the problem we have now, with the back-end crashing when we have no way to lower the bad ASC, and I agree that crash in the back-end is not something we want. It's way too late.

I was thinking diagnosing the error early on, if possible. I.e. treat it as if it was a target-specific syntax error, triggered when the back-end knows up-front that such a combination is invalid.

Treating invalid ASC as poison is indeed fundamentally more sound, at the expense of practical usability. We do know that the input is wrong, but have no way to make user aware of that. LLVM's feedback mechanisms are not great.

If we do need compilation to succeed, then I'd rather generate a trap for the invalid ASC, possibly with inline asm comment, explaining what's going on. At least the failure will be obvious, and, maybe, even somewhat debuggable. It's less bad than having to chase mysteriously mangled invalid pointer created by nonsensical ASC laundered via conversion to generic AS and back.

Contributor

AMDGPU currently uses DiagnosticInfo to report the invalid cast, plus lower to undef. I've been meaning to remove the report error case for a really long time

Contributor

Done in #127487. This happened to come up again since 64-bit flat atomicrmws have apparently been broken at -O0 since October.

@Artem-B
Member

Artem-B commented Feb 4, 2025

Are there any cases where an addrspacecast like this and the PTX we're emitting after this change would be well defined?

It depends on how much we'd be willing to peek under the hood and rely on ptxas implementation details.
https://docs.nvidia.com/cuda/parallel-thread-execution/#generic-addressing says:

The state spaces .const, Kernel Function Parameters (.param), .local and .shared are modeled as windows within the generic address space. Each window is defined by a window base and a window size that is equal to the size of the corresponding state space. A generic address maps to global memory unless it falls within the window for const, local, or shared memory. The Kernel Function Parameters (.param) window is contained within the .global window. Within each window, a generic address maps to an address in the underlying state space by subtracting the window base from the generic address.

The way I see it, ASC from a global pointer to some other AS may happen to work, because global->local conversion may, effectively, be a no-op, assuming that PTX conversion does not actually do any checking on the actual pointer value, and trusts us that the input is indeed in global AS.
Someone could get a shared pointer in generic AS, convert it to an integer, and then pass it around as a global pointer. In that case global->generic->shared would happen to work, but it would depend on too many implementation details of how ptxas handles AS conversion operations. The spec does not give us any promises on that.

I guess the short answer is: if we know that we can't generate sensible code for the given IR, we should diagnose it, the sooner the better. LLVM sort of assumes that IR is not only syntactically valid, but is also "sensible" for the target. One would not expect the compiler to do anything sensible with syntactically valid IR that loads from AS(1234567) on any target, and ASCs that aren't supported by NVPTX fall into the same category.
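As a toy model of the quoted window scheme (the window base and size below are made up for illustration; the real values are implementation details of the driver and ptxas):

```cpp
#include <cassert>
#include <cstdint>

// Toy model of PTX generic addressing. The window placement here is
// invented for illustration only.
struct Window {
  uint64_t base;
  uint64_t size;
};

constexpr Window SharedWin{0x7f0000000000, 1 << 20};

// Per the quoted spec: a generic address that falls inside a state-space
// window maps to that state space by subtracting the window base;
// anything outside every window is treated as a global address.
uint64_t genericToShared(uint64_t generic) {
  assert(generic >= SharedWin.base &&
         generic < SharedWin.base + SharedWin.size);
  return generic - SharedWin.base;
}
```

This is why laundering a pointer through the generic space only "works" if the original value really does lie in the window the final conversion assumes.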

@arsenm
Contributor

arsenm commented Feb 4, 2025

I guess the short answer is if we know that we can't generate sensible code for the given IR, we should diagnose it, the sooner, the better.

I think we need to never error on addrspacecast, and add an addrspacecast sanitizer.

Consider this OpenCLish example:


void foo(volatile generic int* x) {
  __builtin_assume(is_shared(x));
  *x = 4;
}

void bar() {
  private int y;
  foo(&y); // violation, wrong address space
}

After inlining + infer address spaces, we could end up with an invalid private->shared addrspacecast. This will of course be optimization level dependent. This cast should produce poison, and should not be a hard compiler error.

@Artem-B
Member

Artem-B commented Feb 4, 2025

After inlining + infer address spaces, we could end up with an invalid private->shared addrspacecast.

Ugh. I see. C++ CUDA code is often sprinkled with a lot of indiscriminate pointer casts that may paint us into the same corner. That's why we can't have good things.

OK, poison it is.

What would be the best place to do it? Ideally we'd do it as soon as we create such an invalid ASC, but for that we must have target info available.

@justinfargnoli justinfargnoli changed the title [NVPTX] Custom lower ADDRSPACECAST [NVPTX] Lower invalid ADDRSPACECAST Feb 7, 2025
@justinfargnoli justinfargnoli changed the title [NVPTX] Lower invalid ADDRSPACECAST [NVPTX] Lower invalid ISD::ADDRSPACECAST Feb 7, 2025
@justinfargnoli
Contributor Author

justinfargnoli commented Feb 7, 2025

Thank you all for your comments. I agree that my initial approach did not make sense.

In the meantime, I've modified the PR to custom lower invalid ISD::ADDRSPACECASTs to ISD::UNDEF.

I do agree that:

Ideally we'd do it as soon as we create such an invalid ASC, but for that we must have target info available.

However, I'm unsure of the best place to do this. e.g., InstCombine would not want to contain this optimization because it's target-dependent, correct?

@Artem-B
Member

Artem-B commented Feb 7, 2025

Ideally we'd do it as soon as we create such an invalid ASC, but for that we must have target info available.

However, I'm unsure of the best place to do this. e.g., InstCombine would not want to contain this optimization because it's target-dependent, correct?

InstCombine does happen to have access to the TargetTransformInfo, so we may be able to do it there but...

The thing is that we must work correctly regardless of whether we ran any optimizations. That kind of paints us into the corner and means that we can't really do it in the IR. That means that it must be done during lowering.
We may consider replacing bad casts with undef or poison, or figure out better up-front diagnostics, but that would have to be done in addition to the back-end being able to handle those bad casts without crashing, unless we can guarantee that bad casts never get to the back-end.

Here's an idea, instead of laundering the pointer, what if we return an obviously broken pointer with the value that would provide a strong hint that things went wrong. E.g. 0xebadca51 ("Error BAD CAST") for 32-bit pointers and (0xebadca51 | dst_as<<16 | src_as). This will most likely crash and the distinct constant value would be reasonably easy to search for in the sources.
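A sketch of that encoding in standalone C++ (the comment above does not pin down the exact bit layout; this assumes the 0xebadca51 marker occupies the high 32 bits of a 64-bit pointer, with the address spaces packed into the low bits):

```cpp
#include <cstdint>

// Sketch of the suggested "obviously broken pointer" sentinel. The bit
// layout is an assumption: marker in the high 32 bits, destination and
// source address spaces in the low 32 bits.
constexpr uint64_t kBadCastMarker = 0xebadca51; // "Error BAD CAST"

uint64_t makeBadCastSentinel64(uint32_t DstAS, uint32_t SrcAS) {
  return (kBadCastMarker << 32) | (uint64_t(DstAS) << 16) | uint64_t(SrcAS);
}
```

Dereferencing such a pointer would almost certainly fault, and the distinct constant is easy to grep for, which is the point of the suggestion.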

@AlexMaclean AlexMaclean left a comment (Member)

LGTM

Here's an idea, instead of laundering the pointer, what if we return an obviously broken pointer with the value that would provide a strong hint that things went wrong. E.g. 0xebadca51 ("Error BAD CAST") for 32-bit pointers and (0xebadca51 | dst_as<<16 | src_as). This will most likely crash and the distinct constant value would be reasonably easy to search for in the sources.

I personally prefer using undef in this case, as opposed to a value with a message in the hex representation. undef (or implicit-def in PTX) seems like it will allow for more optimization both during SDAG and in ptxas. If we want to be more user-friendly, using DiagnosticInfo to report the invalid cast, similar to AMDGPU, would be the way to go.

I think a good follow-up would be using InstCombiner::isValidAddrSpaceCast in InstCombine to convert invalid casts to poison, though there may be something I'm missing about the semantics here.

@justinfargnoli justinfargnoli merged commit 022c9c9 into llvm:main Feb 11, 2025
8 checks passed
flovent pushed a commit to flovent/llvm-project that referenced this pull request Feb 13, 2025
Avoid [crashing](https://godbolt.org/z/8T58vcM68) when lowering
`addrspacecast ptr addrspace(<non-zero>) %ptr to ptr
addrspace(<non-zero>)`.
joaosaffran pushed a commit to joaosaffran/llvm-project that referenced this pull request Feb 14, 2025
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Feb 24, 2025
5 participants