Skip to content

[SDISel] Teach the type legalizer about ADDRSPACECAST #90969

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue ScalarizeVecRes_InregOp(SDNode *N);
SDValue ScalarizeVecRes_VecInregOp(SDNode *N);

SDValue ScalarizeVecRes_ADDRSPACECAST(SDNode *N);
SDValue ScalarizeVecRes_BITCAST(SDNode *N);
SDValue ScalarizeVecRes_BUILD_VECTOR(SDNode *N);
SDValue ScalarizeVecRes_EXTRACT_SUBVECTOR(SDNode *N);
Expand Down Expand Up @@ -852,6 +853,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
void SplitVecRes_BinOp(SDNode *N, SDValue &Lo, SDValue &Hi);
void SplitVecRes_TernaryOp(SDNode *N, SDValue &Lo, SDValue &Hi);
void SplitVecRes_UnaryOp(SDNode *N, SDValue &Lo, SDValue &Hi);
void SplitVecRes_ADDRSPACECAST(SDNode *N, SDValue &Lo, SDValue &Hi);
void SplitVecRes_FFREXP(SDNode *N, unsigned ResNo, SDValue &Lo, SDValue &Hi);
void SplitVecRes_ExtendOp(SDNode *N, SDValue &Lo, SDValue &Hi);
void SplitVecRes_InregOp(SDNode *N, SDValue &Lo, SDValue &Hi);
Expand Down Expand Up @@ -955,6 +957,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
// Widen Vector Result Promotion.
void WidenVectorResult(SDNode *N, unsigned ResNo);
SDValue WidenVecRes_MERGE_VALUES(SDNode* N, unsigned ResNo);
SDValue WidenVecRes_ADDRSPACECAST(SDNode *N);
SDValue WidenVecRes_AssertZext(SDNode* N);
SDValue WidenVecRes_BITCAST(SDNode* N);
SDValue WidenVecRes_BUILD_VECTOR(SDNode* N);
Expand Down
65 changes: 65 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/TypeSize.h"
Expand Down Expand Up @@ -116,6 +117,9 @@ void DAGTypeLegalizer::ScalarizeVectorResult(SDNode *N, unsigned ResNo) {
case ISD::FCANONICALIZE:
R = ScalarizeVecRes_UnaryOp(N);
break;
case ISD::ADDRSPACECAST:
R = ScalarizeVecRes_ADDRSPACECAST(N);
break;
case ISD::FFREXP:
R = ScalarizeVecRes_FFREXP(N, ResNo);
break;
Expand Down Expand Up @@ -475,6 +479,31 @@ SDValue DAGTypeLegalizer::ScalarizeVecRes_VecInregOp(SDNode *N) {
llvm_unreachable("Illegal extend_vector_inreg opcode");
}

/// Scalarize the result of a vector ADDRSPACECAST: emit a scalar address-space
/// cast of the single source element, reusing the source and destination
/// address spaces recorded on \p N.
SDValue DAGTypeLegalizer::ScalarizeVecRes_ADDRSPACECAST(SDNode *N) {
  SDLoc DL(N);
  EVT ResEltVT = N->getValueType(0).getVectorElementType();
  SDValue Src = N->getOperand(0);
  EVT SrcVT = Src.getValueType();

  // Only the result is known to need scalarizing here; the source operand may
  // have a different type action. When the source is not being scalarized,
  // extract its element 0 explicitly instead of asking for a scalarized value.
  // See the similar logic in ScalarizeVecRes_SETCC.
  if (getTypeAction(SrcVT) != TargetLowering::TypeScalarizeVector)
    Src = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL,
                      SrcVT.getVectorElementType(), Src,
                      DAG.getVectorIdxConstant(0, DL));
  else
    Src = GetScalarizedVector(Src);

  auto *CastN = cast<AddrSpaceCastSDNode>(N);
  return DAG.getAddrSpaceCast(DL, ResEltVT, Src, CastN->getSrcAddressSpace(),
                              CastN->getDestAddressSpace());
}

SDValue DAGTypeLegalizer::ScalarizeVecRes_SCALAR_TO_VECTOR(SDNode *N) {
// If the operand is wider than the vector element type then it is implicitly
// truncated. Make that explicit here.
Expand Down Expand Up @@ -1122,6 +1151,9 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
case ISD::FCANONICALIZE:
SplitVecRes_UnaryOp(N, Lo, Hi);
break;
case ISD::ADDRSPACECAST:
SplitVecRes_ADDRSPACECAST(N, Lo, Hi);
break;
case ISD::FFREXP:
SplitVecRes_FFREXP(N, ResNo, Lo, Hi);
break;
Expand Down Expand Up @@ -2353,6 +2385,26 @@ void DAGTypeLegalizer::SplitVecRes_UnaryOp(SDNode *N, SDValue &Lo,
Hi = DAG.getNode(Opcode, dl, HiVT, {Hi, MaskHi, EVLHi}, Flags);
}

/// Split the result of a vector ADDRSPACECAST into two halves. On return
/// \p Lo and \p Hi hold address-space casts of the low and high halves of the
/// source operand, each using the address spaces recorded on \p N.
void DAGTypeLegalizer::SplitVecRes_ADDRSPACECAST(SDNode *N, SDValue &Lo,
                                                 SDValue &Hi) {
  SDLoc DL(N);
  EVT LoVT, HiVT;
  std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(N->getValueType(0));

  // Reuse the already-split halves when the input is itself being split (a
  // compile time speedup); otherwise split the operand by hand.
  SDValue Src = N->getOperand(0);
  if (getTypeAction(Src.getValueType()) != TargetLowering::TypeSplitVector)
    std::tie(Lo, Hi) = DAG.SplitVectorOperand(N, 0);
  else
    GetSplitVector(Src, Lo, Hi);

  // Cast each half separately with the same address-space pair.
  auto *CastN = cast<AddrSpaceCastSDNode>(N);
  const unsigned FromAS = CastN->getSrcAddressSpace();
  const unsigned ToAS = CastN->getDestAddressSpace();
  Lo = DAG.getAddrSpaceCast(DL, LoVT, Lo, FromAS, ToAS);
  Hi = DAG.getAddrSpaceCast(DL, HiVT, Hi, FromAS, ToAS);
}

void DAGTypeLegalizer::SplitVecRes_FFREXP(SDNode *N, unsigned ResNo,
SDValue &Lo, SDValue &Hi) {
SDLoc dl(N);
Expand Down Expand Up @@ -4121,6 +4173,9 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
report_fatal_error("Do not know how to widen the result of this operator!");

case ISD::MERGE_VALUES: Res = WidenVecRes_MERGE_VALUES(N, ResNo); break;
case ISD::ADDRSPACECAST:
Res = WidenVecRes_ADDRSPACECAST(N);
break;
case ISD::AssertZext: Res = WidenVecRes_AssertZext(N); break;
case ISD::BITCAST: Res = WidenVecRes_BITCAST(N); break;
case ISD::BUILD_VECTOR: Res = WidenVecRes_BUILD_VECTOR(N); break;
Expand Down Expand Up @@ -5086,6 +5141,16 @@ SDValue DAGTypeLegalizer::WidenVecRes_MERGE_VALUES(SDNode *N, unsigned ResNo) {
return GetWidenedVector(WidenVec);
}

/// Widen the result of a vector ADDRSPACECAST: cast the widened source
/// operand to the widened result type, keeping the address spaces of \p N.
SDValue DAGTypeLegalizer::WidenVecRes_ADDRSPACECAST(SDNode *N) {
  SDLoc DL(N);
  EVT WideResVT =
      TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));

  // NOTE(review): the operand is fetched with GetWidenedVector
  // unconditionally, so this presumably relies on the source vector type also
  // being widened -- confirm for targets where pointer widths differ between
  // address spaces.
  SDValue WideSrc = GetWidenedVector(N->getOperand(0));

  auto *CastN = cast<AddrSpaceCastSDNode>(N);
  const unsigned FromAS = CastN->getSrcAddressSpace();
  const unsigned ToAS = CastN->getDestAddressSpace();
  return DAG.getAddrSpaceCast(DL, WideResVT, WideSrc, FromAS, ToAS);
}

SDValue DAGTypeLegalizer::WidenVecRes_BITCAST(SDNode *N) {
SDValue InOp = N->getOperand(0);
EVT InVT = InOp.getValueType();
Expand Down
92 changes: 92 additions & 0 deletions llvm/test/CodeGen/NVPTX/addrspacecast.ll
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,95 @@ define i32 @conv8(ptr %ptr) {
%val = load i32, ptr addrspace(5) %specptr
ret i32 %val
}

; Check that we support addrspacecast when splitting the vector
; result (<2 x ptr> => 2 x <1 x ptr>).
; This also checks that scalarization works for addrspacecast
; (when going from <1 x ptr> to ptr.)
; ALL-LABEL: split1To0
define void @split1To0(ptr nocapture noundef readonly %xs) {
; Expect one cvta.global per vector element (the u32 or u64 form, depending on
; which prefix is active), followed by one store per element.
; CLS32: cvta.global.u32
; CLS32: cvta.global.u32
; CLS64: cvta.global.u64
; CLS64: cvta.global.u64
; ALL: st.u32
; ALL: st.u32
; Load a vector of two addrspace(1) pointers and cast the whole vector to
; the default address space in one addrspacecast.
%vec_addr = load <2 x ptr addrspace(1)>, ptr %xs, align 16
%addrspacecast = addrspacecast <2 x ptr addrspace(1)> %vec_addr to <2 x ptr>
; Store through each casted element so that neither lane of the cast is dead.
%extractelement0 = extractelement <2 x ptr> %addrspacecast, i64 0
store float 0.5, ptr %extractelement0, align 4
%extractelement1 = extractelement <2 x ptr> %addrspacecast, i64 1
store float 1.0, ptr %extractelement1, align 4
ret void
}

; Same as split1To0 but from 0 to 1, to make sure the addrspacecast preserves
; the source and destination addrspaces properly.
; ALL-LABEL: split0To1
define void @split0To1(ptr nocapture noundef readonly %xs) {
; Expect one cvta.to.global per vector element (the u32 or u64 form, depending
; on which prefix is active), followed by one st.global per element.
; CLS32: cvta.to.global.u32
; CLS32: cvta.to.global.u32
; CLS64: cvta.to.global.u64
; CLS64: cvta.to.global.u64
; ALL: st.global.u32
; ALL: st.global.u32
; Load a vector of two default-address-space pointers and cast the whole
; vector to addrspace(1) in one addrspacecast.
%vec_addr = load <2 x ptr>, ptr %xs, align 16
%addrspacecast = addrspacecast <2 x ptr> %vec_addr to <2 x ptr addrspace(1)>
; Store through each casted element so that neither lane of the cast is dead.
%extractelement0 = extractelement <2 x ptr addrspace(1)> %addrspacecast, i64 0
store float 0.5, ptr addrspace(1) %extractelement0, align 4
%extractelement1 = extractelement <2 x ptr addrspace(1)> %addrspacecast, i64 1
store float 1.0, ptr addrspace(1) %extractelement1, align 4
ret void
}

; Check that we support addrspacecast when a widening is required
; (3 x ptr => 4 x ptr).
; ALL-LABEL: widen1To0
define void @widen1To0(ptr nocapture noundef readonly %xs) {
; Expect exactly three conversions to be matched (one per used element),
; in the u32 or u64 form depending on which prefix is active.
; CLS32: cvta.global.u32
; CLS32: cvta.global.u32
; CLS32: cvta.global.u32

; CLS64: cvta.global.u64
; CLS64: cvta.global.u64
; CLS64: cvta.global.u64

; ALL: st.u32
; ALL: st.u32
; ALL: st.u32
; Load a vector of three addrspace(1) pointers and cast the whole vector to
; the default address space in one addrspacecast.
%vec_addr = load <3 x ptr addrspace(1)>, ptr %xs, align 16
%addrspacecast = addrspacecast <3 x ptr addrspace(1)> %vec_addr to <3 x ptr>
; Store through all three casted elements so that every lane is used.
%extractelement0 = extractelement <3 x ptr> %addrspacecast, i64 0
store float 0.5, ptr %extractelement0, align 4
%extractelement1 = extractelement <3 x ptr> %addrspacecast, i64 1
store float 1.0, ptr %extractelement1, align 4
%extractelement2 = extractelement <3 x ptr> %addrspacecast, i64 2
store float 1.5, ptr %extractelement2, align 4
ret void
}

; Same as widen1To0 but from 0 to 1, to make sure the addrspacecast preserves
; the source and destination addrspaces properly.
; ALL-LABEL: widen0To1
define void @widen0To1(ptr nocapture noundef readonly %xs) {
; Expect exactly three conversions to be matched (one per used element),
; in the u32 or u64 form depending on which prefix is active.
; CLS32: cvta.to.global.u32
; CLS32: cvta.to.global.u32
; CLS32: cvta.to.global.u32

; CLS64: cvta.to.global.u64
; CLS64: cvta.to.global.u64
; CLS64: cvta.to.global.u64

; ALL: st.global.u32
; ALL: st.global.u32
; ALL: st.global.u32
; Load a vector of three default-address-space pointers and cast the whole
; vector to addrspace(1) in one addrspacecast.
%vec_addr = load <3 x ptr>, ptr %xs, align 16
%addrspacecast = addrspacecast <3 x ptr> %vec_addr to <3 x ptr addrspace(1)>
; Store through all three casted elements so that every lane is used.
%extractelement0 = extractelement <3 x ptr addrspace(1)> %addrspacecast, i64 0
store float 0.5, ptr addrspace(1) %extractelement0, align 4
%extractelement1 = extractelement <3 x ptr addrspace(1)> %addrspacecast, i64 1
store float 1.0, ptr addrspace(1) %extractelement1, align 4
%extractelement2 = extractelement <3 x ptr addrspace(1)> %addrspacecast, i64 2
store float 1.5, ptr addrspace(1) %extractelement2, align 4
ret void
}