-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[NVPTX] Use v2.u64 to load/store 128-bit values #136638
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NVPTX] Use v2.u64 to load/store 128-bit values #136638
Conversation
CC @dakersnar |
@llvm/pr-subscribers-backend-nvptx Author: Alex MacLean (AlexMaclean) ChangesPatch is 22.72 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136638.diff 8 Files Affected:
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index ec1f969494cd1..69bb4a097effa 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -1168,9 +1168,10 @@ static bool isVectorElementTypeUpsized(EVT EltVT) {
bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
MemSDNode *MemSD = cast<MemSDNode>(N);
- EVT LoadedVT = MemSD->getMemoryVT();
- if (!LoadedVT.isSimple())
+ EVT MemEVT = MemSD->getMemoryVT();
+ if (!MemEVT.isSimple())
return false;
+ MVT MemVT = MemEVT.getSimpleVT();
// Address Space Setting
unsigned int CodeAddrSpace = getCodeAddrSpace(MemSD);
@@ -1178,50 +1179,43 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
return tryLDGLDU(N);
}
+ EVT EltVT = N->getValueType(0);
SDLoc DL(N);
SDValue Chain = N->getOperand(0);
auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, MemSD);
- // Vector Setting
- MVT SimpleVT = LoadedVT.getSimpleVT();
-
// Type Setting: fromType + fromTypeWidth
//
// Sign : ISD::SEXTLOAD
// Unsign : ISD::ZEXTLOAD, ISD::NON_EXTLOAD or ISD::EXTLOAD and the
// type is integer
// Float : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
- MVT ScalarVT = SimpleVT.getScalarType();
// Read at least 8 bits (predicates are stored as 8-bit values)
- unsigned FromTypeWidth = std::max(8U, (unsigned)ScalarVT.getSizeInBits());
- unsigned int FromType;
// The last operand holds the original LoadSDNode::getExtensionType() value
- unsigned ExtensionType = cast<ConstantSDNode>(
- N->getOperand(N->getNumOperands() - 1))->getZExtValue();
- if (ExtensionType == ISD::SEXTLOAD)
- FromType = NVPTX::PTXLdStInstCode::Signed;
- else
- FromType = getLdStRegType(ScalarVT);
+ const unsigned TotalWidth = MemVT.getSizeInBits();
+ unsigned ExtensionType = N->getConstantOperandVal(N->getNumOperands() - 1);
+ unsigned FromType = (ExtensionType == ISD::SEXTLOAD)
+ ? NVPTX::PTXLdStInstCode::Signed
+ : getLdStRegType(MemVT.getScalarType());
unsigned VecType;
-
+ unsigned FromTypeWidth;
switch (N->getOpcode()) {
case NVPTXISD::LoadV2:
+ FromTypeWidth = TotalWidth / 2;
VecType = NVPTX::PTXLdStInstCode::V2;
break;
case NVPTXISD::LoadV4:
+ FromTypeWidth = TotalWidth / 4;
VecType = NVPTX::PTXLdStInstCode::V4;
break;
default:
return false;
}
- EVT EltVT = N->getValueType(0);
-
if (isVectorElementTypeUpsized(EltVT)) {
EltVT = MVT::i32;
FromType = NVPTX::PTXLdStInstCode::Untyped;
- FromTypeWidth = 32;
}
SDValue Offset, Base;
@@ -1271,9 +1265,14 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
// LDG/LDU SD node (from custom vector handling), then its the second operand
SDValue Op1 = N->getOperand(N->getOpcode() == ISD::INTRINSIC_W_CHAIN ? 2 : 1);
- EVT OrigType = N->getValueType(0);
+ const EVT OrigType = N->getValueType(0);
EVT EltVT = Mem->getMemoryVT();
unsigned NumElts = 1;
+
+ if (EltVT == MVT::i128 || EltVT == MVT::f128) {
+ EltVT = MVT::i64;
+ NumElts = 2;
+ }
if (EltVT.isVector()) {
NumElts = EltVT.getVectorNumElements();
EltVT = EltVT.getVectorElementType();
@@ -1293,11 +1292,9 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
// Build the "promoted" result VTList for the load. If we are really loading
// i8s, then the return type will be promoted to i16 since we do not expose
// 8-bit registers in NVPTX.
- EVT NodeVT = (EltVT == MVT::i8) ? MVT::i16 : EltVT;
+ const EVT NodeVT = (EltVT == MVT::i8) ? MVT::i16 : EltVT;
SmallVector<EVT, 5> InstVTs;
- for (unsigned i = 0; i != NumElts; ++i) {
- InstVTs.push_back(NodeVT);
- }
+ InstVTs.append(NumElts, NodeVT);
InstVTs.push_back(MVT::Other);
SDVTList InstVTList = CurDAG->getVTList(InstVTs);
SDValue Chain = N->getOperand(0);
@@ -1476,6 +1473,7 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
EVT EltVT = Op1.getValueType();
MemSDNode *MemSD = cast<MemSDNode>(N);
EVT StoreVT = MemSD->getMemoryVT();
+ assert(StoreVT.isSimple() && "Store value is not simple");
// Address Space Setting
unsigned CodeAddrSpace = getCodeAddrSpace(MemSD);
@@ -1490,35 +1488,35 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
// Type Setting: toType + toTypeWidth
// - for integer type, always use 'u'
- assert(StoreVT.isSimple() && "Store value is not simple");
- MVT ScalarVT = StoreVT.getSimpleVT().getScalarType();
- unsigned ToTypeWidth = ScalarVT.getSizeInBits();
- unsigned ToType = getLdStRegType(ScalarVT);
+ const unsigned TotalWidth = StoreVT.getSimpleVT().getSizeInBits();
+ unsigned ToType = getLdStRegType(StoreVT.getSimpleVT().getScalarType());
SmallVector<SDValue, 12> Ops;
SDValue N2;
unsigned VecType;
+ unsigned ToTypeWidth;
switch (N->getOpcode()) {
case NVPTXISD::StoreV2:
VecType = NVPTX::PTXLdStInstCode::V2;
Ops.append({N->getOperand(1), N->getOperand(2)});
N2 = N->getOperand(3);
+ ToTypeWidth = TotalWidth / 2;
break;
case NVPTXISD::StoreV4:
VecType = NVPTX::PTXLdStInstCode::V4;
Ops.append({N->getOperand(1), N->getOperand(2), N->getOperand(3),
N->getOperand(4)});
N2 = N->getOperand(5);
+ ToTypeWidth = TotalWidth / 4;
break;
default:
return false;
}
if (isVectorElementTypeUpsized(EltVT)) {
- EltVT = MVT::i32;
ToType = NVPTX::PTXLdStInstCode::Untyped;
- ToTypeWidth = 32;
+ EltVT = MVT::i32;
}
SDValue Offset, Base;
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 277a34173e7b8..0cf9a75dda443 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -178,18 +178,24 @@ static bool Is16bitsType(MVT VT) {
// 2. If we do want to handle it, returns two parameters:
// - unsigned int NumElts - The number of elements in the final vector
// - EVT EltVT - The type of the elements in the final vector
-static std::optional<std::pair<unsigned int, EVT>>
-getVectorLoweringShape(EVT VectorVT) {
- if (!VectorVT.isVector() || !VectorVT.isSimple())
+static std::optional<std::pair<unsigned int, MVT>>
+getVectorLoweringShape(EVT VectorEVT) {
+ if (!VectorEVT.isSimple())
return std::nullopt;
+ const MVT VectorVT = VectorEVT.getSimpleVT();
- EVT EltVT = VectorVT.getVectorElementType();
- unsigned NumElts = VectorVT.getVectorNumElements();
+ if (!VectorVT.isVector()) {
+ assert(VectorVT == MVT::i128 || VectorVT == MVT::f128);
+ return {{2, MVT::i64}};
+ }
+
+ const MVT EltVT = VectorVT.getVectorElementType();
+ const unsigned NumElts = VectorVT.getVectorNumElements();
// We only handle "native" vector sizes for now, e.g. <4 x double> is not
// legal. We can (and should) split that into 2 stores of <2 x double> here
// but I'm leaving that as a TODO for now.
- switch (VectorVT.getSimpleVT().SimpleTy) {
+ switch (VectorVT.SimpleTy) {
default:
return std::nullopt;
case MVT::v2i8:
@@ -225,8 +231,7 @@ getVectorLoweringShape(EVT VectorVT) {
// Number of elements to pack in one word.
unsigned NPerWord = 32 / EltVT.getSizeInBits();
- return std::pair(NumElts / NPerWord,
- MVT::getVectorVT(EltVT.getSimpleVT(), NPerWord));
+ return std::pair(NumElts / NPerWord, MVT::getVectorVT(EltVT, NPerWord));
}
llvm_unreachable("All cases in switch should return.");
@@ -749,13 +754,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::DEBUGTRAP, MVT::Other, Legal);
// Register custom handling for vector loads/stores
- for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
- if (IsPTXVectorType(VT)) {
- setOperationAction(ISD::LOAD, VT, Custom);
- setOperationAction(ISD::STORE, VT, Custom);
- setOperationAction(ISD::INTRINSIC_W_CHAIN, VT, Custom);
- }
- }
+ for (MVT VT : MVT::fixedlen_vector_valuetypes())
+ if (IsPTXVectorType(VT))
+ setOperationAction({ISD::LOAD, ISD::STORE, ISD::INTRINSIC_W_CHAIN}, VT,
+ Custom);
+
+ setOperationAction({ISD::LOAD, ISD::STORE, ISD::INTRINSIC_W_CHAIN},
+ {MVT::i128, MVT::f128}, Custom);
// Support varargs.
setOperationAction(ISD::VASTART, MVT::Other, Custom);
@@ -3144,10 +3149,7 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
if (Isv2x16VT(VT) || VT == MVT::v4i8)
return SDValue();
- if (VT.isVector())
- return LowerSTOREVector(Op, DAG);
-
- return SDValue();
+ return LowerSTOREVector(Op, DAG);
}
SDValue
@@ -3157,10 +3159,10 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
SDLoc DL(N);
EVT ValVT = Val.getValueType();
- auto NumEltsAndEltVT = getVectorLoweringShape(ValVT);
+ const auto NumEltsAndEltVT = getVectorLoweringShape(ValVT);
if (!NumEltsAndEltVT)
return SDValue();
- auto [NumElts, EltVT] = NumEltsAndEltVT.value();
+ const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
MemSDNode *MemSD = cast<MemSDNode>(N);
const DataLayout &TD = DAG.getDataLayout();
@@ -3176,14 +3178,7 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
return SDValue();
}
- // Since StoreV2 is a target node, we cannot rely on DAG type legalization.
- // Therefore, we must ensure the type is legal. For i1 and i8, we set the
- // stored type to i16 and propagate the "real" type as the memory type.
- bool NeedExt = false;
- if (EltVT.getSizeInBits() < 16)
- NeedExt = true;
-
- unsigned Opcode = 0;
+ unsigned Opcode;
switch (NumElts) {
default:
return SDValue();
@@ -3201,28 +3196,33 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
Ops.push_back(N->getOperand(0));
// Then the split values
- assert(NumElts <= ValVT.getVectorNumElements() &&
- "NumElts should not increase, only decrease or stay the same.");
- if (NumElts < ValVT.getVectorNumElements()) {
- // If the number of elements has decreased, getVectorLoweringShape has
- // upsized the element types
- assert(EltVT.isVector() && EltVT.getSizeInBits() == 32 &&
- EltVT.getVectorNumElements() <= 4 && "Unexpected upsized type.");
+ // assert(NumElts <= ValVT.getVectorNumElements() &&
+ // "NumElts should not increase, only decrease or stay the same.");
+ if (EltVT.isVector()) {
+ assert(EVT(EltVT.getVectorElementType()) == ValVT.getVectorElementType());
+ assert(NumElts * EltVT.getVectorNumElements() ==
+ ValVT.getVectorNumElements());
// Combine individual elements into v2[i,f,bf]16/v4i8 subvectors to be
// stored as b32s
- unsigned NumEltsPerSubVector = EltVT.getVectorNumElements();
- for (unsigned i = 0; i < NumElts; ++i) {
+ const unsigned NumEltsPerSubVector = EltVT.getVectorNumElements();
+ for (auto I : llvm::seq(NumElts)) {
SmallVector<SDValue, 4> SubVectorElts;
- DAG.ExtractVectorElements(Val, SubVectorElts, i * NumEltsPerSubVector,
+ DAG.ExtractVectorElements(Val, SubVectorElts, I * NumEltsPerSubVector,
NumEltsPerSubVector);
SDValue SubVector = DAG.getBuildVector(EltVT, DL, SubVectorElts);
Ops.push_back(SubVector);
}
} else {
- for (unsigned i = 0; i < NumElts; ++i) {
- SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
- DAG.getIntPtrConstant(i, DL));
- if (NeedExt)
+ SDValue V = DAG.getBitcast(MVT::getVectorVT(EltVT, NumElts), Val);
+ for (auto I : llvm::seq(NumElts)) {
+ SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, V,
+ DAG.getIntPtrConstant(I, DL));
+
+ // Since StoreV2 is a target node, we cannot rely on DAG type
+ // legalization. Therefore, we must ensure the type is legal. For i1 and
+ // i8, we set the stored type to i16 and propagate the "real" type as the
+ // memory type.
+ if (EltVT.getSizeInBits() < 16)
ExtVal = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i16, ExtVal);
Ops.push_back(ExtVal);
}
@@ -5756,20 +5756,18 @@ static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,
/// ReplaceVectorLoad - Convert vector loads into multi-output scalar loads.
static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
SmallVectorImpl<SDValue> &Results) {
- EVT ResVT = N->getValueType(0);
+ const EVT ResVT = N->getValueType(0);
SDLoc DL(N);
- assert(ResVT.isVector() && "Vector load must have vector type");
-
- auto NumEltsAndEltVT = getVectorLoweringShape(ResVT);
+ const auto NumEltsAndEltVT = getVectorLoweringShape(ResVT);
if (!NumEltsAndEltVT)
return;
- auto [NumElts, EltVT] = NumEltsAndEltVT.value();
+ const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
LoadSDNode *LD = cast<LoadSDNode>(N);
Align Alignment = LD->getAlign();
- auto &TD = DAG.getDataLayout();
+ const auto &TD = DAG.getDataLayout();
Align PrefAlign =
TD.getPrefTypeAlign(LD->getMemoryVT().getTypeForEVT(*DAG.getContext()));
if (Alignment < PrefAlign) {
@@ -5784,26 +5782,21 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
// Since LoadV2 is a target node, we cannot rely on DAG type legalization.
// Therefore, we must ensure the type is legal. For i1 and i8, we set the
// loaded type to i16 and propagate the "real" type as the memory type.
- bool NeedTrunc = false;
- if (EltVT.getSizeInBits() < 16) {
- EltVT = MVT::i16;
- NeedTrunc = true;
- }
+ const MVT LoadEltVT = (EltVT.getSizeInBits() < 16) ? MVT::i16 : EltVT;
- unsigned Opcode = 0;
+ unsigned Opcode;
SDVTList LdResVTs;
-
switch (NumElts) {
default:
return;
case 2:
Opcode = NVPTXISD::LoadV2;
- LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
+ LdResVTs = DAG.getVTList(LoadEltVT, LoadEltVT, MVT::Other);
break;
case 4: {
Opcode = NVPTXISD::LoadV4;
- EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
- LdResVTs = DAG.getVTList(ListVTs);
+ LdResVTs =
+ DAG.getVTList({LoadEltVT, LoadEltVT, LoadEltVT, LoadEltVT, MVT::Other});
break;
}
}
@@ -5820,34 +5813,32 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
LD->getMemOperand());
SmallVector<SDValue> ScalarRes;
- assert(NumElts <= ResVT.getVectorNumElements() &&
- "NumElts should not increase, only decrease or stay the same.");
- if (NumElts < ResVT.getVectorNumElements()) {
- // If the number of elements has decreased, getVectorLoweringShape has
- // upsized the element types
- assert(EltVT.isVector() && EltVT.getSizeInBits() == 32 &&
- EltVT.getVectorNumElements() <= 4 && "Unexpected upsized type.");
+ if (EltVT.isVector()) {
+ assert(EVT(EltVT.getVectorElementType()) == ResVT.getVectorElementType());
+ assert(NumElts * EltVT.getVectorNumElements() ==
+ ResVT.getVectorNumElements());
// Generate EXTRACT_VECTOR_ELTs to split v2[i,f,bf]16/v4i8 subvectors back
// into individual elements.
- for (unsigned i = 0; i < NumElts; ++i) {
- SDValue SubVector = NewLD.getValue(i);
+ for (auto I : llvm::seq(NumElts)) {
+ SDValue SubVector = NewLD.getValue(I);
DAG.ExtractVectorElements(SubVector, ScalarRes);
}
} else {
- for (unsigned i = 0; i < NumElts; ++i) {
- SDValue Res = NewLD.getValue(i);
- if (NeedTrunc)
- Res = DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
+ for (auto I : llvm::seq(NumElts)) {
+ SDValue Res = NewLD.getValue(I);
+ if (LoadEltVT != EltVT)
+ Res = DAG.getNode(ISD::TRUNCATE, DL, EltVT, Res);
ScalarRes.push_back(Res);
}
}
SDValue LoadChain = NewLD.getValue(NumElts);
- SDValue BuildVec = DAG.getBuildVector(ResVT, DL, ScalarRes);
+ EVT BuildVecVT = ResVT.isVector() ? ResVT : MVT::getVectorVT(EltVT, NumElts);
+ SDValue BuildVec = DAG.getBuildVector(BuildVecVT, DL, ScalarRes);
+ SDValue LoadValue = DAG.getBitcast(ResVT, BuildVec);
- Results.push_back(BuildVec);
- Results.push_back(LoadChain);
+ Results.append({LoadValue, LoadChain});
}
// Lower vector return type of tcgen05.ld intrinsics
diff --git a/llvm/test/CodeGen/NVPTX/fp128-storage-type.ll b/llvm/test/CodeGen/NVPTX/fp128-storage-type.ll
index 5b96f4978a7cb..6907edcd0e04e 100644
--- a/llvm/test/CodeGen/NVPTX/fp128-storage-type.ll
+++ b/llvm/test/CodeGen/NVPTX/fp128-storage-type.ll
@@ -23,11 +23,9 @@ define void @load_store(ptr %in, ptr %out) {
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u64 %rd1, [load_store_param_0];
-; CHECK-NEXT: ld.u64 %rd2, [%rd1+8];
-; CHECK-NEXT: ld.u64 %rd3, [%rd1];
+; CHECK-NEXT: ld.v2.u64 {%rd2, %rd3}, [%rd1];
; CHECK-NEXT: ld.param.u64 %rd4, [load_store_param_1];
-; CHECK-NEXT: st.u64 [%rd4], %rd3;
-; CHECK-NEXT: st.u64 [%rd4+8], %rd2;
+; CHECK-NEXT: st.v2.u64 [%rd4], {%rd2, %rd3};
; CHECK-NEXT: ret;
%val = load fp128, ptr %in
store fp128 %val, ptr %out
diff --git a/llvm/test/CodeGen/NVPTX/i128-array.ll b/llvm/test/CodeGen/NVPTX/i128-array.ll
index fb69224e87d11..dd6d48bd5862c 100644
--- a/llvm/test/CodeGen/NVPTX/i128-array.ll
+++ b/llvm/test/CodeGen/NVPTX/i128-array.ll
@@ -30,12 +30,10 @@ define [2 x i128] @foo2(ptr byval([2 x i128]) %a) {
; CHECK-NEXT: .reg .b64 %rd<7>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
-; CHECK-NEXT: ld.param.u64 %rd3, [foo2_param_0+8];
-; CHECK-NEXT: ld.param.u64 %rd4, [foo2_param_0];
-; CHECK-NEXT: ld.param.u64 %rd5, [foo2_param_0+24];
-; CHECK-NEXT: ld.param.u64 %rd6, [foo2_param_0+16];
-; CHECK-NEXT: st.param.v2.b64 [func_retval0], {%rd4, %rd3};
-; CHECK-NEXT: st.param.v2.b64 [func_retval0+16], {%rd6, %rd5};
+; CHECK-NEXT: ld.param.v2.u64 {%rd3, %rd4}, [foo2_param_0];
+; CHECK-NEXT: ld.param.v2.u64 {%rd5, %rd6}, [foo2_param_0+16];
+; CHECK-NEXT: st.param.v2.b64 [func_retval0], {%rd3, %rd4};
+; CHECK-NEXT: st.param.v2.b64 [func_retval0+16], {%rd5, %rd6};
; CHECK-NEXT: ret;
%ptr0 = getelementptr [2 x i128], ptr %a, i64 0, i32 0
%1 = load i128, i128* %ptr0
diff --git a/llvm/test/CodeGen/NVPTX/i128-retval.ll b/llvm/test/CodeGen/NVPTX/i128-retval.ll
index f9a23900484e4..a01d14d5ca776 100644
--- a/llvm/test/CodeGen/NVPTX/i128-retval.ll
+++ b/llvm/test/CodeGen/NVPTX/i128-retval.ll
@@ -21,8 +21,7 @@ start:
; CHECK: } // callseq 0
%a = call i128 @callee(i128 %0)
- ; CHECK-DAG: st.u64 [%[[OUT]]], %[[REG2]];
- ; CHECK-DAG: st.u64 [%[[OUT]]+8], %[[REG3]];
+ ; CHECK-DAG: st.v2.u64 [%[[OUT]]], {%[[REG2]], %[[REG3]]};
store i128 %a, ptr %1
ret void
diff --git a/llvm/test/CodeGen/NVPTX/inline-asm-b128-test1.ll b/llvm/test/CodeGen/NVPTX/inline-asm-b128-test1.ll
index 311741f737adc..67c074ca73156 100644
--- a/llvm/test/CodeGen/NVPTX/inline-asm-b128-test1.ll
+++ b/llvm/test/CodeGen/NVPTX/inline-asm-b128-test1.ll
@@ -35,11 +35,10 @@ define void @test_b128_input_from_load(ptr nocapture readonly %data) {
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u64 %rd2, [test_b128_input_from_load_param_0];
; CHECK-NEXT: cvta.to.global.u64 %rd3, %rd2;
-; CHECK-NEXT: ld.global.u64 %rd4, [%rd3+8];
-; CHECK-NEXT: ld.global.u64 %rd5, [%rd3];
-; CHECK-NEXT: mov.b128 %rq1, {%rd5, %rd4};
+; CHECK-NEXT: ld.global.v2.u64 {%rd4, %rd5}, [%rd3];
; CHECK-NEXT: mov.b64 %rd6, value;
; CHECK-NEXT: cvta.global.u64 %rd1, %rd6;
+; CHECK-NEXT: mov.b128 %rq1, {%rd4, %rd5};
; CHECK-NEXT: // begin inline asm
; CHECK-NEXT: { st.b128 [%rd1], %rq1; }
; CHECK-NEXT: // end inline asm
@@ -94,8 +93,7 @@ define void @test_store_b128_output() {
; CHECK-NEXT: mov.b128 {%rd1, %rd2}, %rq1;
; CHECK-NEXT: add.cc.s64 %rd3, %rd1, 1;
; CHECK-NEXT: addc.cc.s64 %rd4, %rd2, 0;
-; CHECK-NEXT: st.global.u64 [value+8], %rd4;
-; CHECK-NEXT: st.global.u64 [value], %rd3;
+; CHECK-NEXT: st.global.v2.u64 [value], {%rd3, %rd4};
; CHECK-NEXT: ret;
%1 = tail call i128 asm "{ mov.b128 $0, 41; }", "=q"()
%add = add nsw i128 %1, 1
@@ -113,17 +111,15 @@ define void @test_use_of_b128_output(ptr nocapture readonly %data) {
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u64 %rd1,...
[truncated]
|
3b81597
to
7f9e4b1
Compare
✅ With the latest revision this PR passed the C/C++ code formatter. |
7f9e4b1
to
8fe589e
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't have approval permissions but LGTM. Seems like a clean addition of the new coverage + some bundled simplifications that are nice to have.
if (EltVT.isVector()) { | ||
assert(EVT(EltVT.getVectorElementType()) == ValVT.getVectorElementType()); | ||
assert(NumElts * EltVT.getVectorNumElements() == | ||
ValVT.getVectorNumElements()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a good simplification, easier to follow than the old logic.
const MVT BuildVecVT = | ||
MVT::getVectorVT(EltVT.getScalarType(), ScalarRes.size()); | ||
SDValue BuildVec = DAG.getBuildVector(BuildVecVT, DL, ScalarRes); | ||
SDValue LoadValue = DAG.getBitcast(ResVT, BuildVec); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I assume this extra logic is needed to handle the new 128-bit cases, right?
PTX apparently supports 128-bit ld/st/mov on sm_70+ and it looks like we've already landed bits and pieces of support for it last fall, while I wasn't looking. We'll probably need to plumb it through to support 128-bit atomics on sm_90, and it would simplify things here, too (though we'll still need these changes for older GPUs). @AlexMaclean is that something already on your radar? I was about to start poking at the atomics (#122760), but will be happy if it's about to be done by NVIDIA. :-) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM overall w/ few nits.
VecType = NVPTX::PTXLdStInstCode::V4; | ||
break; | ||
default: | ||
return false; | ||
} | ||
|
||
EVT EltVT = N->getValueType(0); | ||
|
||
if (isVectorElementTypeUpsized(EltVT)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should probably rename it to something more meaningful. isSubVectorPackedInI32
?
Otherwise it's not clear why we're using i32 here without having to go and look at the implementation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good, I've switched to isSubVectorPackedInI32
. This function could be used several other places as well, but leaving that for a subsequent change.
SDValue SubVector = DAG.getBuildVector(EltVT, DL, SubVectorElts); | ||
Ops.push_back(SubVector); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: We could just push_back getBuildVector result directly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
// Generate EXTRACT_VECTOR_ELTs to split v2[i,f,bf]16/v4i8 subvectors back | ||
// into individual elements. | ||
for (unsigned i = 0; i < NumElts; ++i) { | ||
SDValue SubVector = NewLD.getValue(i); | ||
for (const auto I : llvm::seq(NumElts)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: I'd use the actual type instead of auto here. One less thing to think of. We're just enumerating elements.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
SDValue Res = NewLD.getValue(i); | ||
if (NeedTrunc) | ||
Res = DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res); | ||
for (const auto I : llvm::seq(NumElts)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
; CHECK-NEXT: ld.global.u64 %rd4, [%rd3+8]; | ||
; CHECK-NEXT: ld.global.u64 %rd5, [%rd3]; | ||
; CHECK-NEXT: mov.b128 %rq1, {%rd5, %rd4}; | ||
; CHECK-NEXT: ld.global.v2.u64 {%rd4, %rd5}, [%rd3]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
^^^ This is where we could've loaded .b128 directly. Probably makes no difference on the SASS, though, just nicer looking PTX.
Yeah, I've been thinking about improving our 128-bit register support for a while now. I'm hoping to get some time in the next week or two to take a deeper look and clean the situation up a bit. Sorry to have ignored #122760! I'll look into that too, since I'm already trying to do some cleanup around atomics as well. |
There's nothing to be sorry about. It's a low priority bug that I just happened to look at yesterday. I mostly brought it up to make sure we're not stepping on each other's toes. |
No description provided.