Skip to content

Commit b74e588

Browse files
authored
[NVPTX] Don't use stack memory when bitcasting to/from v2i8 (#113928)
`v2i8` is an unsupported type, so we hit the default legalization rules, which perform the bitcast through stack memory; this is very inefficient on GPUs. This adds a custom lowering where we pack `v2i8` into `i16` and from there use another bitcast node to reach the final desired type. It also adds the inverse: unpacking `i16` into `v2i8`.
1 parent 58f525a commit b74e588

File tree

3 files changed

+97
-0
lines changed

3 files changed

+97
-0
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,13 @@ VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
409409
return VectorInfo;
410410
}
411411

412+
/// Return \p Value reinterpreted as type \p VT.
///
/// If \p Value already has type \p VT it is returned unchanged, so no
/// redundant node is created; otherwise a new ISD::BITCAST node converting
/// \p Value to \p VT is built at location \p DL.
static SDValue MaybeBitcast(SelectionDAG &DAG, const SDLoc &DL, EVT VT,
                            SDValue Value) {
  // Compare against the type of this particular value (which honours its
  // result number), not against result 0 of the node that defines it.
  if (Value.getValueType() == VT)
    return Value;
  return DAG.getNode(ISD::BITCAST, DL, VT, Value);
}
418+
412419
// NVPTXTargetLowering Constructor.
413420
NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
414421
const NVPTXSubtarget &STI)
@@ -551,6 +558,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
551558
setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v4i8, Custom);
552559
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4i8, Custom);
553560
setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v4i8, Custom);
561+
562+
// Custom conversions to/from v2i8.
563+
setOperationAction(ISD::BITCAST, MVT::v2i8, Custom);
564+
554565
// Only logical ops can be done on v4i8 directly, others must be done
555566
// elementwise.
556567
setOperationAction(
@@ -2309,6 +2320,30 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
23092320
return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
23102321
}
23112322

2323+
/// Custom lowering for an ISD::BITCAST whose source operand is a v2i8.
///
/// This avoids the default promotion strategy, which goes through stack
/// memory. The two i8 elements are packed into a single i16 (element 0 in the
/// low byte, element 1 in the high byte) and that integer is then bitcast to
/// the destination type if the destination is not already i16.
///
/// \returns \p Op unchanged when the source type is not v2i8; otherwise a
/// value of the bitcast's result type built from the packed i16.
SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
  // Use the operand value's own type (which honours its result number)
  // instead of the type of result 0 of the node that defines it.
  SDValue Src = Op.getOperand(0);
  if (Src.getValueType() != MVT::v2i8)
    return Op;

  // Pack vector elements into i16 and bitcast to final type.
  SDLoc DL(Op);
  SDValue Elt0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8, Src,
                             DAG.getIntPtrConstant(0, DL));
  SDValue Elt1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8, Src,
                             DAG.getIntPtrConstant(1, DL));
  // Zero-extend so the upper byte of each widened element is clear before the
  // two halves are combined with OR.
  SDValue Lo = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Elt0);
  SDValue Hi = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Elt1);
  SDValue ShiftAmt = DAG.getConstant(8, DL, MVT::i16);
  SDValue HiShifted = DAG.getNode(ISD::SHL, DL, MVT::i16, Hi, ShiftAmt);
  SDValue Packed = DAG.getNode(ISD::OR, DL, MVT::i16, Lo, HiShifted);
  return MaybeBitcast(DAG, DL, Op.getValueType(), Packed);
}
2346+
23122347
// We can init constant f16x2/v2i16/v4i8 with a single .b32 move. Normally it
23132348
// would get lowered as two constant loads and vector-packing move.
23142349
// Instead we want just a constant move:
@@ -2817,6 +2852,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
28172852
return Op;
28182853
case ISD::BUILD_VECTOR:
28192854
return LowerBUILD_VECTOR(Op, DAG);
2855+
case ISD::BITCAST:
2856+
return LowerBITCAST(Op, DAG);
28202857
case ISD::EXTRACT_SUBVECTOR:
28212858
return Op;
28222859
case ISD::EXTRACT_VECTOR_ELT:
@@ -6127,6 +6164,28 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
61276164
return SDValue();
61286165
}
61296166

6167+
/// Custom result replacement for an ISD::BITCAST that produces a v2i8.
///
/// This avoids the default promotion strategy, which goes through stack
/// memory. The source operand is viewed as an i16 (bitcasting it first if it
/// has a different type), the low and high bytes are split out, and a v2i8
/// BUILD_VECTOR of {low byte, high byte} is appended to \p Results.
///
/// Nothing is appended to \p Results when the result type of \p Node is not
/// v2i8.
static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,
                           SmallVectorImpl<SDValue> &Results) {
  if (Node->getValueType(0) != MVT::v2i8)
    return;

  // Bitcast to i16 and unpack elements into a vector.
  SDLoc DL(Node);
  SDValue Packed = MaybeBitcast(DAG, DL, MVT::i16, Node->getOperand(0));

  // Element 0 is the low byte of the packed integer.
  SDValue Lo = DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, Packed);

  // Element 1 is the high byte: shift it down, then truncate.
  SDValue ShiftAmt = DAG.getConstant(8, DL, MVT::i16);
  SDValue HiBits = DAG.getNode(ISD::SRL, DL, MVT::i16, Packed, ShiftAmt);
  SDValue Hi = DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, HiBits);

  Results.push_back(DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2i8, {Lo, Hi}));
}
6188+
61306189
/// ReplaceLoadVector - Convert vector loads into multi-output scalar loads.
61316190
static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
61326191
SmallVectorImpl<SDValue> &Results) {
@@ -6412,6 +6471,9 @@ void NVPTXTargetLowering::ReplaceNodeResults(
64126471
switch (N->getOpcode()) {
64136472
default:
64146473
report_fatal_error("Unhandled custom legalization");
6474+
case ISD::BITCAST:
6475+
ReplaceBITCAST(N, DAG, Results);
6476+
return;
64156477
case ISD::LOAD:
64166478
ReplaceLoadVector(N, DAG, Results);
64176479
return;

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,8 @@ class NVPTXTargetLowering : public TargetLowering {
616616
const NVPTXSubtarget &STI; // cache the subtarget here
617617
SDValue getParamSymbol(SelectionDAG &DAG, int idx, EVT) const;
618618

619+
SDValue LowerBITCAST(SDValue Op, SelectionDAG &DAG) const;
620+
619621
SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
620622
SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
621623
SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
; RUN: llc < %s -mtriple=nvptx64-nvidia-cuda -mcpu=sm_90 -mattr=+ptx80 -asm-verbose=false \
2+
; RUN: -O0 -disable-post-ra -frame-pointer=all -verify-machineinstrs \
3+
; RUN: | FileCheck %s
4+
; RUN: %if ptxas %{ \
5+
; RUN: llc < %s -mtriple=nvptx64-nvidia-cuda -mcpu=sm_90 -asm-verbose=false \
6+
; RUN: -O0 -disable-post-ra -frame-pointer=all -verify-machineinstrs \
7+
; RUN: | %ptxas-verify -arch=sm_90 \
8+
; RUN: %}
9+
10+
target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"
11+
12+
; CHECK-LABEL: test_bitcast_2xi8_i16(
13+
; CHECK: ld.param.u32 %r1, [test_bitcast_2xi8_i16_param_0];
14+
; CHECK: mov.b32 {%rs1, %rs2}, %r1;
15+
; CHECK: shl.b16 %rs3, %rs2, 8;
16+
; CHECK: and.b16 %rs4, %rs1, 255;
17+
; CHECK: or.b16 %rs5, %rs4, %rs3;
18+
; CHECK: cvt.u32.u16 %r2, %rs5;
19+
; CHECK: st.param.b32 [func_retval0], %r2;
20+
; Bitcast a <2 x i8> argument to a single i16 and return it.
define i16 @test_bitcast_2xi8_i16(<2 x i8> %a) {
  %packed = bitcast <2 x i8> %a to i16
  ret i16 %packed
}
24+
25+
; CHECK-LABEL: test_bitcast_i16_2xi8(
26+
; CHECK: ld.param.u16 %rs1, [test_bitcast_i16_2xi8_param_0];
27+
; CHECK: shr.u16 %rs2, %rs1, 8;
28+
; CHECK: mov.b32 %r1, {%rs1, %rs2};
29+
; CHECK: st.param.b32 [func_retval0], %r1;
30+
; Bitcast an i16 argument to a <2 x i8> vector and return it.
define <2 x i8> @test_bitcast_i16_2xi8(i16 %a) {
  %unpacked = bitcast i16 %a to <2 x i8>
  ret <2 x i8> %unpacked
}

0 commit comments

Comments
 (0)