Skip to content

Commit 173fcf7

Browse files
authored
[NVPTX] Lower 16xi8 and 8xi8 stores efficiently (llvm#73646)
Lower 16xi8 vector stores in NVPTX ISel efficiently using st.v4.b32 instead of multiple st.v4.u8 along the lines of vector loads and 8xf16. Similarly, 8xi8 using st.v2.u32.
1 parent dbbc7c3 commit 173fcf7

File tree

3 files changed

+69
-7
lines changed

3 files changed

+69
-7
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
508508
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2i16, Expand);
509509
setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2i16, Expand);
510510

511+
// Conversion to/from i8/i8x4 is always legal.
511512
setOperationAction(ISD::BUILD_VECTOR, MVT::v4i8, Custom);
512513
setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v4i8, Custom);
513514
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4i8, Custom);
@@ -717,8 +718,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
717718

718719
// We have some custom DAG combine patterns for these nodes
719720
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
720-
ISD::LOAD, ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM,
721-
ISD::VSELECT});
721+
ISD::LOAD, ISD::MUL, ISD::SHL, ISD::SREM, ISD::STORE,
722+
ISD::UREM, ISD::VSELECT});
722723

723724
// setcc for f16x2 and bf16x2 needs special handling to prevent
724725
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -2916,7 +2917,6 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
29162917
DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
29172918
MemSD->getMemoryVT(), MemSD->getMemOperand());
29182919

2919-
// return DCI.CombineTo(N, NewSt, true);
29202920
return NewSt;
29212921
}
29222922

@@ -5557,6 +5557,51 @@ static SDValue PerformLOADCombine(SDNode *N,
55575557
DL);
55585558
}
55595559

5560+
// Lower a v16i8 (or a v8i8) store into a StoreV4 (or StoreV2) operation with
5561+
// i32 results instead of letting ReplaceLoadVector split it into smaller stores
5562+
// during legalization. This is done at dag-combine1 time, so that vector
5563+
// operations with i8 elements can be optimised away instead of being needlessly
5564+
// split during legalization, which involves storing to the stack and loading it
5565+
// back.
5566+
static SDValue PerformSTORECombine(SDNode *N,
5567+
TargetLowering::DAGCombinerInfo &DCI) {
5568+
SelectionDAG &DAG = DCI.DAG;
5569+
StoreSDNode *ST = cast<StoreSDNode>(N);
5570+
EVT VT = ST->getValue().getValueType();
5571+
if (VT != MVT::v16i8 && VT != MVT::v8i8)
5572+
return SDValue();
5573+
5574+
// Create a v4i32 vector store operation, effectively <4 x v4i8>.
5575+
unsigned Opc = VT == MVT::v16i8 ? NVPTXISD::StoreV4 : NVPTXISD::StoreV2;
5576+
EVT NewVT = VT == MVT::v16i8 ? MVT::v4i32 : MVT::v2i32;
5577+
unsigned NumElts = NewVT.getVectorNumElements();
5578+
5579+
// Create a vector of the type required by the new store: v16i8 -> v4i32.
5580+
SDValue NewStoreValue = DCI.DAG.getBitcast(NewVT, ST->getValue());
5581+
5582+
// Operands for the store.
5583+
SmallVector<SDValue, 8> Ops;
5584+
Ops.reserve(N->getNumOperands() + NumElts - 1);
5585+
// Chain value.
5586+
Ops.push_back(N->ops().front());
5587+
5588+
SDLoc DL(N);
5589+
SmallVector<SDValue> Elts(NumElts);
5590+
// Break v4i32 (or v2i32) into four (or two) elements.
5591+
for (unsigned I = 0; I < NumElts; ++I)
5592+
Elts[I] = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL,
5593+
NewStoreValue.getValueType().getVectorElementType(),
5594+
NewStoreValue, DAG.getIntPtrConstant(I, DL));
5595+
Ops.append(Elts.begin(), Elts.end());
5596+
// Any remaining operands.
5597+
Ops.append(N->op_begin() + 2, N->op_end());
5598+
5599+
SDValue NewStore = DAG.getMemIntrinsicNode(Opc, DL, DAG.getVTList(MVT::Other),
5600+
Ops, NewVT, ST->getMemOperand());
5601+
// Return the new chain.
5602+
return NewStore.getValue(0);
5603+
}
5604+
55605605
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
55615606
DAGCombinerInfo &DCI) const {
55625607
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5578,6 +5623,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
55785623
return PerformSETCCCombine(N, DCI, STI.getSmVersion());
55795624
case ISD::LOAD:
55805625
return PerformLOADCombine(N, DCI);
5626+
case ISD::STORE:
5627+
return PerformSTORECombine(N, DCI);
55815628
case NVPTXISD::StoreRetval:
55825629
case NVPTXISD::StoreRetvalV2:
55835630
case NVPTXISD::StoreRetvalV4:

llvm/test/CodeGen/NVPTX/i8x4-instructions.ll

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -790,10 +790,9 @@ define void @test_ldst_v8i8(ptr %a, ptr %b) {
790790
; CHECK-NEXT: // %bb.0:
791791
; CHECK-NEXT: ld.param.u64 %rd2, [test_ldst_v8i8_param_1];
792792
; CHECK-NEXT: ld.param.u64 %rd1, [test_ldst_v8i8_param_0];
793-
; CHECK-NEXT: ld.u32 %r1, [%rd1];
794-
; CHECK-NEXT: ld.u32 %r2, [%rd1+4];
795-
; CHECK-NEXT: st.u32 [%rd2+4], %r2;
796-
; CHECK-NEXT: st.u32 [%rd2], %r1;
793+
; CHECK-NEXT: ld.u32 %r1, [%rd1+4];
794+
; CHECK-NEXT: ld.u32 %r2, [%rd1];
795+
; CHECK-NEXT: st.v2.u32 [%rd2], {%r2, %r1};
797796
; CHECK-NEXT: ret;
798797
%t1 = load <8 x i8>, ptr %a
799798
store <8 x i8> %t1, ptr %b, align 16

llvm/test/CodeGen/NVPTX/vector-stores.ll

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,19 @@ define void @v16i8(ptr %a, ptr %b) {
3737
store <16 x i8> %v, ptr %b
3838
ret void
3939
}
40+
41+
; CHECK-LABEL: .visible .func v16i8_store
42+
define void @v16i8_store(ptr %a, <16 x i8> %v) {
43+
; CHECK: ld.param.u64 %rd1, [v16i8_store_param_0];
44+
; CHECK-NEXT: ld.param.v4.u32 {%r1, %r2, %r3, %r4}, [v16i8_store_param_1];
45+
; CHECK-NEXT: st.v4.u32 [%rd1], {%r1, %r2, %r3, %r4};
46+
store <16 x i8> %v, ptr %a
47+
ret void
48+
}
49+
50+
; CHECK-LABEL: .visible .func v8i8_store
51+
define void @v8i8_store(ptr %a, <8 x i8> %v) {
52+
; CHECK: st.v2.u32
53+
store <8 x i8> %v, ptr %a
54+
ret void
55+
}

0 commit comments

Comments
 (0)