Skip to content

Commit c197301

Browse files
committed
[NVPTX] Lower 16xi8 and 8xi8 stores efficiently
Lower 16xi8 vector stores in NVPTX ISel efficiently using st.v4.b32 instead of multiple st.v4.u8 along the lines of vector loads and 8xf16. Similarly, 8xi8 using st.v2.u32.
1 parent 246b8ea commit c197301

File tree

3 files changed

+69
-7
lines changed

3 files changed

+69
-7
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
508508
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2i16, Expand);
509509
setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2i16, Expand);
510510

511+
// Conversion to/from i8/i8x4 is always legal.
511512
setOperationAction(ISD::BUILD_VECTOR, MVT::v4i8, Custom);
512513
setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v4i8, Custom);
513514
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4i8, Custom);
@@ -717,8 +718,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
717718

718719
// We have some custom DAG combine patterns for these nodes
719720
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
720-
ISD::LOAD, ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM,
721-
ISD::VSELECT});
721+
ISD::LOAD, ISD::MUL, ISD::SHL, ISD::SREM, ISD::STORE,
722+
ISD::UREM, ISD::VSELECT});
722723

723724
// setcc for f16x2 and bf16x2 needs special handling to prevent
724725
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -2916,7 +2917,6 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
29162917
DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
29172918
MemSD->getMemoryVT(), MemSD->getMemOperand());
29182919

2919-
// return DCI.CombineTo(N, NewSt, true);
29202920
return NewSt;
29212921
}
29222922

@@ -5557,6 +5557,51 @@ static SDValue PerformLOADCombine(SDNode *N,
55575557
DL);
55585558
}
55595559

5560+
// Lower a v16i8 (or a v8i8) store into a StoreV4 (or StoreV2) operation with
5561+
// i32 results instead of letting ReplaceLoadVector split it into smaller stores
5562+
// during legalization. This is done at dag-combine1 time, so that vector
5563+
// operations with i8 elements can be optimised away instead of being needlessly
5564+
// split during legalization, which involves storing to the stack and loading it
5565+
// back.
5566+
static SDValue PerformSTORECombine(SDNode *N,
5567+
TargetLowering::DAGCombinerInfo &DCI) {
5568+
SelectionDAG &DAG = DCI.DAG;
5569+
StoreSDNode *ST = cast<StoreSDNode>(N);
5570+
EVT VT = ST->getValue().getValueType();
5571+
if (VT != MVT::v16i8 && VT != MVT::v8i8)
5572+
return SDValue();
5573+
5574+
// Create a v4i32 vector store operation, effectively <4 x v4i8>.
5575+
unsigned Opc = VT == MVT::v16i8 ? NVPTXISD::StoreV4 : NVPTXISD::StoreV2;
5576+
EVT NewVT = VT == MVT::v16i8 ? MVT::v4i32 : MVT::v2i32;
5577+
unsigned NumElts = NewVT.getVectorNumElements();
5578+
5579+
// Create a vector of the type required by the new store: v16i8 -> v4i32.
5580+
SDValue NewStoreValue = DCI.DAG.getBitcast(NewVT, ST->getValue());
5581+
5582+
// Operands for the store.
5583+
SmallVector<SDValue, 8> Ops;
5584+
Ops.reserve(N->getNumOperands() + NumElts - 1);
5585+
// Chain value.
5586+
Ops.push_back(N->ops().front());
5587+
5588+
SDLoc DL(N);
5589+
SmallVector<SDValue> Elts(NumElts);
5590+
// Break v4i32 (or v2i32) into four (or two) elements.
5591+
for (unsigned I = 0; I < NumElts; ++I)
5592+
Elts[I] = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL,
5593+
NewStoreValue.getValueType().getVectorElementType(),
5594+
NewStoreValue, DAG.getIntPtrConstant(I, DL));
5595+
Ops.append(Elts.begin(), Elts.end());
5596+
// Any remaining operands.
5597+
Ops.append(N->op_begin() + 2, N->op_end());
5598+
5599+
SDValue NewStore = DAG.getMemIntrinsicNode(Opc, DL, DAG.getVTList(MVT::Other),
5600+
Ops, NewVT, ST->getMemOperand());
5601+
// Return the new chain.
5602+
return NewStore.getValue(0);
5603+
}
5604+
55605605
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
55615606
DAGCombinerInfo &DCI) const {
55625607
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5578,6 +5623,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
55785623
return PerformSETCCCombine(N, DCI, STI.getSmVersion());
55795624
case ISD::LOAD:
55805625
return PerformLOADCombine(N, DCI);
5626+
case ISD::STORE:
5627+
return PerformSTORECombine(N, DCI);
55815628
case NVPTXISD::StoreRetval:
55825629
case NVPTXISD::StoreRetvalV2:
55835630
case NVPTXISD::StoreRetvalV4:

llvm/test/CodeGen/NVPTX/i8x4-instructions.ll

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -790,10 +790,9 @@ define void @test_ldst_v8i8(ptr %a, ptr %b) {
790790
; CHECK-NEXT: // %bb.0:
791791
; CHECK-NEXT: ld.param.u64 %rd2, [test_ldst_v8i8_param_1];
792792
; CHECK-NEXT: ld.param.u64 %rd1, [test_ldst_v8i8_param_0];
793-
; CHECK-NEXT: ld.u32 %r1, [%rd1];
794-
; CHECK-NEXT: ld.u32 %r2, [%rd1+4];
795-
; CHECK-NEXT: st.u32 [%rd2+4], %r2;
796-
; CHECK-NEXT: st.u32 [%rd2], %r1;
793+
; CHECK-NEXT: ld.u32 %r1, [%rd1+4];
794+
; CHECK-NEXT: ld.u32 %r2, [%rd1];
795+
; CHECK-NEXT: st.v2.u32 [%rd2], {%r2, %r1};
797796
; CHECK-NEXT: ret;
798797
%t1 = load <8 x i8>, ptr %a
799798
store <8 x i8> %t1, ptr %b, align 16

llvm/test/CodeGen/NVPTX/vector-stores.ll

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,19 @@ define void @v16i8(ptr %a, ptr %b) {
3737
store <16 x i8> %v, ptr %b
3838
ret void
3939
}
40+
41+
; CHECK-LABEL: .visible .func v16i8_store
42+
define void @v16i8_store(ptr %a, <16 x i8> %v) {
43+
; CHECK: ld.param.u64 %rd1, [v16i8_store_param_0];
44+
; CHECK-NEXT: ld.param.v4.u32 {%r1, %r2, %r3, %r4}, [v16i8_store_param_1];
45+
; CHECK-NEXT: st.v4.u32 [%rd1], {%r1, %r2, %r3, %r4};
46+
store <16 x i8> %v, ptr %a
47+
ret void
48+
}
49+
50+
; CHECK-LABEL: .visible .func v8i8_store
51+
define void @v8i8_store(ptr %a, <8 x i8> %v) {
52+
; CHECK: st.v2.u32
53+
store <8 x i8> %v, ptr %a
54+
ret void
55+
}

0 commit comments

Comments
 (0)