Skip to content

Commit c13be8f

Browse files
authored
[NVPTX] Add some basic folds for ADDRSPACECAST (#129157)
1 parent d9edca4 commit c13be8f

File tree

2 files changed

+53
-1
lines changed

2 files changed

+53
-1
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -824,7 +824,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
824824
// We have some custom DAG combine patterns for these nodes
825825
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
826826
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
827-
ISD::BUILD_VECTOR});
827+
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST});
828828

829829
// setcc for f16x2 and bf16x2 needs special handling to prevent
830830
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5536,6 +5536,21 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
55365536
return DAG.getNode(ISD::BITCAST, DL, VT, PRMT);
55375537
}
55385538

5539+
// Target DAG combine for ISD::ADDRSPACECAST nodes: removes a pair of
// back-to-back address space casts that ends in the address space it
// started from.
//
// N   - the node being combined. It is converted with cast<> rather than
//       dyn_cast<>, so it is expected to always be an address space cast
//       (PerformDAGCombine only dispatches here for ISD::ADDRSPACECAST).
// DCI - combiner state; not referenced anywhere in this function body.
//
// Returns the operand of the inner cast when the fold applies, otherwise an
// empty SDValue() -- the same value PerformDAGCombine returns for opcodes it
// does not handle, so presumably "nothing changed".
static SDValue combineADDRSPACECAST(SDNode *N,
5540+
TargetLowering::DAGCombinerInfo &DCI) {
5541+
auto *ASCN1 = cast<AddrSpaceCastSDNode>(N);
5542+
5543+
// Only a cast whose input is itself an address space cast is a candidate.
if (auto *ASCN2 = dyn_cast<AddrSpaceCastSDNode>(ASCN1->getOperand(0))) {
5544+
// Sanity check: the inner cast must produce a pointer in the address space
// the outer cast consumes.
assert(ASCN2->getDestAddressSpace() == ASCN1->getSrcAddressSpace());
5545+
5546+
// Fold asc[B -> A](asc[A -> B](x)) -> x
5547+
if (ASCN1->getDestAddressSpace() == ASCN2->getSrcAddressSpace())
5548+
return ASCN2->getOperand(0);
5549+
}
5550+
5551+
// No fold matched.
return SDValue();
5552+
}
5553+
55395554
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
55405555
DAGCombinerInfo &DCI) const {
55415556
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5570,6 +5585,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
55705585
return PerformVSELECTCombine(N, DCI);
55715586
case ISD::BUILD_VECTOR:
55725587
return PerformBUILD_VECTORCombine(N, DCI);
5588+
case ISD::ADDRSPACECAST:
5589+
return combineADDRSPACECAST(N, DCI);
55735590
}
55745591
return SDValue();
55755592
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc < %s -mcpu=sm_20 -O0 | FileCheck %s
3+
; RUN: %if ptxas %{ llc < %s -mcpu=sm_20 -O0 | %ptxas-verify %}
4+
5+
target triple = "nvptx64-unknown-unknown"
6+
7+
; Round trip: %p is cast from the default address space to addrspace(5) and
; straight back. The expected output stores the loaded parameter (%rd1)
; directly to the return value, with no cvta conversion in between.
define ptr @test1(ptr %p) {
8+
; CHECK-LABEL: test1(
9+
; CHECK: {
10+
; CHECK-NEXT: .reg .b64 %rd<2>;
11+
; CHECK-EMPTY:
12+
; CHECK-NEXT: // %bb.0:
13+
; CHECK-NEXT: ld.param.u64 %rd1, [test1_param_0];
14+
; CHECK-NEXT: st.param.b64 [func_retval0], %rd1;
15+
; CHECK-NEXT: ret;
16+
%a = addrspacecast ptr %p to ptr addrspace(5)
17+
%b = addrspacecast ptr addrspace(5) %a to ptr
18+
ret ptr %b
19+
}
20+
21+
; Not a round trip: %p goes from addrspace(5) to the default address space and
; then on to addrspace(1), so the final space differs from the starting one.
; The expected output keeps both conversions (cvta.local followed by
; cvta.to.global).
define ptr addrspace(1) @test2(ptr addrspace(5) %p) {
22+
; CHECK-LABEL: test2(
23+
; CHECK: {
24+
; CHECK-NEXT: .reg .b64 %rd<4>;
25+
; CHECK-EMPTY:
26+
; CHECK-NEXT: // %bb.0:
27+
; CHECK-NEXT: ld.param.u64 %rd1, [test2_param_0];
28+
; CHECK-NEXT: cvta.local.u64 %rd2, %rd1;
29+
; CHECK-NEXT: cvta.to.global.u64 %rd3, %rd2;
30+
; CHECK-NEXT: st.param.b64 [func_retval0], %rd3;
31+
; CHECK-NEXT: ret;
32+
%a = addrspacecast ptr addrspace(5) %p to ptr
33+
%b = addrspacecast ptr %a to ptr addrspace(1)
34+
ret ptr addrspace(1) %b
35+
}

0 commit comments

Comments
 (0)