Skip to content

Commit 6efdcc1

Browse files
authored
[NVPTX] Fixup EXT_LOAD lowering for i128 values (#138049)
When custom lowering a vector load/store to a multi-output load/store node, confirm that the memory value type matches the type used by the node; if it does not (i.e. the load extends or the store truncates), skip the custom lowering. Also add some asserts for basic sanity checking of load and store widths. Fixes #138034.
1 parent e6d7f46 commit 6efdcc1

File tree

3 files changed

+58
-12
lines changed

3 files changed

+58
-12
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "llvm/Support/CommandLine.h"
2727
#include "llvm/Support/ErrorHandling.h"
2828
#include "llvm/Support/FormatVariadic.h"
29+
#include "llvm/Support/MathExtras.h"
2930
#include <optional>
3031

3132
using namespace llvm;
@@ -1141,6 +1142,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
11411142
else
11421143
FromType = getLdStRegType(ScalarVT);
11431144

1145+
assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
1146+
FromTypeWidth <= 128 && "Invalid width for load");
1147+
11441148
// Create the machine instruction DAG
11451149
SDValue Offset, Base;
11461150
SelectADDR(N->getOperand(1), Base, Offset);
@@ -1236,6 +1240,9 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
12361240
FromType = NVPTX::PTXLdStInstCode::Untyped;
12371241
}
12381242

1243+
assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
1244+
FromTypeWidth <= 128 && TotalWidth <= 128 && "Invalid width for load");
1245+
12391246
SDValue Offset, Base;
12401247
SelectADDR(N->getOperand(1), Base, Offset);
12411248
SDValue Ops[] = {getI32Imm(Ordering, DL),
@@ -1453,6 +1460,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
14531460
// Create the machine instruction DAG
14541461
SDValue Value = PlainStore ? PlainStore->getValue() : AtomicStore->getVal();
14551462

1463+
assert(isPowerOf2_32(ToTypeWidth) && ToTypeWidth >= 8 && ToTypeWidth <= 128 &&
1464+
"Invalid width for store");
1465+
14561466
SDValue Offset, Base;
14571467
SelectADDR(ST->getBasePtr(), Base, Offset);
14581468

@@ -1537,6 +1547,9 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
15371547
ToType = NVPTX::PTXLdStInstCode::Untyped;
15381548
}
15391549

1550+
assert(isPowerOf2_32(ToTypeWidth) && ToTypeWidth >= 8 && ToTypeWidth <= 128 &&
1551+
TotalWidth <= 128 && "Invalid width for store");
1552+
15401553
SDValue Offset, Base;
15411554
SelectADDR(N2, Base, Offset);
15421555

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3191,20 +3191,25 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
31913191

31923192
SDValue
31933193
NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
3194-
SDNode *N = Op.getNode();
3194+
MemSDNode *N = cast<MemSDNode>(Op.getNode());
31953195
SDValue Val = N->getOperand(1);
31963196
SDLoc DL(N);
3197-
EVT ValVT = Val.getValueType();
3197+
const EVT ValVT = Val.getValueType();
3198+
const EVT MemVT = N->getMemoryVT();
3199+
3200+
// If we're truncating as part of the store, avoid lowering to a StoreV node.
3201+
// TODO: consider relaxing this restriction.
3202+
if (ValVT != MemVT)
3203+
return SDValue();
31983204

31993205
const auto NumEltsAndEltVT = getVectorLoweringShape(ValVT);
32003206
if (!NumEltsAndEltVT)
32013207
return SDValue();
32023208
const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
32033209

3204-
MemSDNode *MemSD = cast<MemSDNode>(N);
32053210
const DataLayout &TD = DAG.getDataLayout();
32063211

3207-
Align Alignment = MemSD->getAlign();
3212+
Align Alignment = N->getAlign();
32083213
Align PrefAlign = TD.getPrefTypeAlign(ValVT.getTypeForEVT(*DAG.getContext()));
32093214
if (Alignment < PrefAlign) {
32103215
// This store is not sufficiently aligned, so bail out and let this vector
@@ -3267,7 +3272,7 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
32673272

32683273
SDValue NewSt =
32693274
DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
3270-
MemSD->getMemoryVT(), MemSD->getMemOperand());
3275+
N->getMemoryVT(), N->getMemOperand());
32713276

32723277
// return DCI.CombineTo(N, NewSt, true);
32733278
return NewSt;
@@ -5762,20 +5767,23 @@ static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,
57625767
/// ReplaceLoadVector - Convert vector loads into multi-output scalar loads.
57635768
static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
57645769
SmallVectorImpl<SDValue> &Results) {
5765-
const EVT ResVT = N->getValueType(0);
5766-
SDLoc DL(N);
5770+
LoadSDNode *LD = cast<LoadSDNode>(N);
5771+
const EVT ResVT = LD->getValueType(0);
5772+
const EVT MemVT = LD->getMemoryVT();
5773+
5774+
// If we're doing sign/zero extension as part of the load, avoid lowering to
5775+
// a LoadV node. TODO: consider relaxing this restriction.
5776+
if (ResVT != MemVT)
5777+
return;
57675778

57685779
const auto NumEltsAndEltVT = getVectorLoweringShape(ResVT);
57695780
if (!NumEltsAndEltVT)
57705781
return;
57715782
const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
57725783

5773-
LoadSDNode *LD = cast<LoadSDNode>(N);
5774-
57755784
Align Alignment = LD->getAlign();
57765785
const auto &TD = DAG.getDataLayout();
5777-
Align PrefAlign =
5778-
TD.getPrefTypeAlign(LD->getMemoryVT().getTypeForEVT(*DAG.getContext()));
5786+
Align PrefAlign = TD.getPrefTypeAlign(MemVT.getTypeForEVT(*DAG.getContext()));
57795787
if (Alignment < PrefAlign) {
57805788
// This load is not sufficiently aligned, so bail out and let this vector
57815789
// load be scalarized. Note that we may still be able to emit smaller
@@ -5806,9 +5814,10 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
58065814
break;
58075815
}
58085816
}
5817+
SDLoc DL(LD);
58095818

58105819
// Copy regular operands
5811-
SmallVector<SDValue, 8> OtherOps(N->ops());
5820+
SmallVector<SDValue, 8> OtherOps(LD->ops());
58125821

58135822
// The select routine does not have access to the LoadSDNode instance, so
58145823
// pass along the extension information

llvm/test/CodeGen/NVPTX/i128-ld-st.ll

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc < %s -O0 -mcpu=sm_20 | FileCheck %s
3+
; RUN: %if ptxas %{ llc < %s -O0 -mcpu=sm_20 | %ptxas-verify %}
4+
5+
target triple = "nvptx64-nvidia-cuda"
6+
7+
define i128 @foo(ptr %p, ptr %o) {
8+
; CHECK-LABEL: foo(
9+
; CHECK: {
10+
; CHECK-NEXT: .reg .b64 %rd<5>;
11+
; CHECK-EMPTY:
12+
; CHECK-NEXT: // %bb.0:
13+
; CHECK-NEXT: ld.param.u64 %rd2, [foo_param_1];
14+
; CHECK-NEXT: ld.param.u64 %rd1, [foo_param_0];
15+
; CHECK-NEXT: ld.u8 %rd3, [%rd1];
16+
; CHECK-NEXT: mov.b64 %rd4, 0;
17+
; CHECK-NEXT: st.v2.u64 [%rd2], {%rd3, %rd4};
18+
; CHECK-NEXT: st.param.v2.b64 [func_retval0], {%rd3, %rd4};
19+
; CHECK-NEXT: ret;
20+
%c = load i8, ptr %p, align 1
21+
%i = zext i8 %c to i128
22+
store i128 %i, ptr %o, align 16
23+
ret i128 %i
24+
}

0 commit comments

Comments
 (0)