Skip to content

Commit d6e25c4

Browse files
authored
[SelectionDAG] Take passthru into account when widening ISD::MLOAD (#144170)
#140595 used vp.load in the cases where we need to widen masked.load. However, we didn't account for the passthru operand, so it might miscompile when the passthru is not undef. While we could simply avoid using vp.load to widen when the passthru is not undef, doing so would run into the exact same crash described in #140198, so for scalable vectors, this patch manually merges the vp.load result with the passthru when the latter is not undef.
1 parent 577199f commit d6e25c4

File tree

3 files changed

+49
-3
lines changed

3 files changed

+49
-3
lines changed

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6149,20 +6149,33 @@ SDValue DAGTypeLegalizer::WidenVecRes_MLOAD(MaskedLoadSDNode *N) {
61496149

61506150
if (ExtType == ISD::NON_EXTLOAD &&
61516151
TLI.isOperationLegalOrCustom(ISD::VP_LOAD, WidenVT) &&
6152-
TLI.isTypeLegal(WideMaskVT)) {
6152+
TLI.isTypeLegal(WideMaskVT) &&
6153+
// If there is a passthru, we shouldn't use vp.load. However,
6154+
// type legalizer will struggle on masked.load with
6155+
// scalable vectors, so for scalable vectors, we still use vp.load
6156+
// but manually merge the load result with the passthru using vp.select.
6157+
(N->getPassThru()->isUndef() || VT.isScalableVector())) {
61536158
Mask = DAG.getInsertSubvector(dl, DAG.getUNDEF(WideMaskVT), Mask, 0);
61546159
SDValue EVL = DAG.getElementCount(dl, TLI.getVPExplicitVectorLengthTy(),
61556160
VT.getVectorElementCount());
61566161
SDValue NewLoad =
61576162
DAG.getLoadVP(N->getAddressingMode(), ISD::NON_EXTLOAD, WidenVT, dl,
61586163
N->getChain(), N->getBasePtr(), N->getOffset(), Mask, EVL,
61596164
N->getMemoryVT(), N->getMemOperand());
6165+
SDValue NewVal = NewLoad;
6166+
6167+
// Manually merge with vp.select
6168+
if (!N->getPassThru()->isUndef()) {
6169+
assert(WidenVT.isScalableVector());
6170+
NewVal =
6171+
DAG.getNode(ISD::VP_SELECT, dl, WidenVT, Mask, NewVal, PassThru, EVL);
6172+
}
61606173

61616174
// Modified the chain - switch anything that used the old chain to use
61626175
// the new one.
61636176
ReplaceValueWith(SDValue(N, 1), NewLoad.getValue(1));
61646177

6165-
return NewLoad;
6178+
return NewVal;
61666179
}
61676180

61686181
// The mask should be widened as well

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-load-int.ll

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,3 +341,16 @@ define <7 x i8> @masked_load_v7i8(ptr %a, <7 x i1> %mask) {
341341
ret <7 x i8> %load
342342
}
343343

344+
define <7 x i8> @masked_load_passthru_v7i8(ptr %a, <7 x i1> %mask) {
345+
; CHECK-LABEL: masked_load_passthru_v7i8:
346+
; CHECK: # %bb.0:
347+
; CHECK-NEXT: li a1, 127
348+
; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, mu
349+
; CHECK-NEXT: vmv.s.x v8, a1
350+
; CHECK-NEXT: vmand.mm v0, v0, v8
351+
; CHECK-NEXT: vmv.v.i v8, 0
352+
; CHECK-NEXT: vle8.v v8, (a0), v0.t
353+
; CHECK-NEXT: ret
354+
%load = call <7 x i8> @llvm.masked.load.v7i8(ptr %a, i32 8, <7 x i1> %mask, <7 x i8> zeroinitializer)
355+
ret <7 x i8> %load
356+
}

llvm/test/CodeGen/RISCV/rvv/masked-load-int.ll

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,27 @@ define <vscale x 1 x i8> @masked_load_nxv1i8(ptr %a, <vscale x 1 x i1> %mask) no
2121
%load = call <vscale x 1 x i8> @llvm.masked.load.nxv1i8(ptr %a, i32 1, <vscale x 1 x i1> %mask, <vscale x 1 x i8> undef)
2222
ret <vscale x 1 x i8> %load
2323
}
24-
declare <vscale x 1 x i8> @llvm.masked.load.nxv1i8(ptr, i32, <vscale x 1 x i1>, <vscale x 1 x i8>)
24+
25+
define <vscale x 1 x i8> @masked_load_passthru_nxv1i8(ptr %a, <vscale x 1 x i1> %mask) nounwind {
26+
; V-LABEL: masked_load_passthru_nxv1i8:
27+
; V: # %bb.0:
28+
; V-NEXT: vsetvli a1, zero, e8, mf8, ta, mu
29+
; V-NEXT: vmv.v.i v8, 0
30+
; V-NEXT: vle8.v v8, (a0), v0.t
31+
; V-NEXT: ret
32+
;
33+
; ZVE32-LABEL: masked_load_passthru_nxv1i8:
34+
; ZVE32: # %bb.0:
35+
; ZVE32-NEXT: csrr a1, vlenb
36+
; ZVE32-NEXT: srli a1, a1, 3
37+
; ZVE32-NEXT: vsetvli a2, zero, e8, mf4, ta, ma
38+
; ZVE32-NEXT: vmv.v.i v8, 0
39+
; ZVE32-NEXT: vsetvli zero, a1, e8, mf4, ta, mu
40+
; ZVE32-NEXT: vle8.v v8, (a0), v0.t
41+
; ZVE32-NEXT: ret
42+
%load = call <vscale x 1 x i8> @llvm.masked.load.nxv1i8(ptr %a, i32 1, <vscale x 1 x i1> %mask, <vscale x 1 x i8> zeroinitializer)
43+
ret <vscale x 1 x i8> %load
44+
}
2545

2646
define <vscale x 1 x i16> @masked_load_nxv1i16(ptr %a, <vscale x 1 x i1> %mask) nounwind {
2747
; V-LABEL: masked_load_nxv1i16:

0 commit comments

Comments
 (0)