Skip to content

Commit 2fa8e2d

Browse files
committed
[X86] Correct the MaskVT for avx512 gather/scatter intrinsics to use the min of the number of index and data elements.
When the result type is v2i64/v2f64 and the index element size is i32, the index vector has two unused elements making the type v4i32. The mask VT should match the number of memory accesses that will be made. This is consistent with the isel patterns used for the target independent gather/scatter intrinsic. llvm-svn: 350687
1 parent 634a143 commit 2fa8e2d

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22248,14 +22248,16 @@ static SDValue getGatherNode(unsigned Opc, SDValue Op, SelectionDAG &DAG,
2224822248
SDValue Src, SDValue Mask, SDValue Base,
2224922249
SDValue Index, SDValue ScaleOp, SDValue Chain,
2225022250
const X86Subtarget &Subtarget) {
22251+
MVT VT = Op.getSimpleValueType();
2225122252
SDLoc dl(Op);
2225222253
auto *C = dyn_cast<ConstantSDNode>(ScaleOp);
2225322254
// Scale must be constant.
2225422255
if (!C)
2225522256
return SDValue();
2225622257
SDValue Scale = DAG.getTargetConstant(C->getZExtValue(), dl, MVT::i8);
22257-
MVT MaskVT = MVT::getVectorVT(MVT::i1,
22258-
Index.getSimpleValueType().getVectorNumElements());
22258+
unsigned MinElts = std::min(Index.getSimpleValueType().getVectorNumElements(),
22259+
VT.getVectorNumElements());
22260+
MVT MaskVT = MVT::getVectorVT(MVT::i1, MinElts);
2225922261

2226022262
SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl);
2226122263
SDVTList VTs = DAG.getVTList(Op.getValueType(), MaskVT, MVT::Other);
@@ -22284,8 +22286,9 @@ static SDValue getScatterNode(unsigned Opc, SDValue Op, SelectionDAG &DAG,
2228422286
SDValue Scale = DAG.getTargetConstant(C->getZExtValue(), dl, MVT::i8);
2228522287
SDValue Disp = DAG.getTargetConstant(0, dl, MVT::i32);
2228622288
SDValue Segment = DAG.getRegister(0, MVT::i32);
22287-
MVT MaskVT = MVT::getVectorVT(MVT::i1,
22288-
Index.getSimpleValueType().getVectorNumElements());
22289+
unsigned MinElts = std::min(Index.getSimpleValueType().getVectorNumElements(),
22290+
Src.getSimpleValueType().getVectorNumElements());
22291+
MVT MaskVT = MVT::getVectorVT(MVT::i1, MinElts);
2228922292

2229022293
SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl);
2229122294
SDVTList VTs = DAG.getVTList(MaskVT, MVT::Other);

0 commit comments

Comments
 (0)