
Commit ab75a17

[x86] convert masked store of one element to scalar store
Another opportunity to reduce masked stores: in D16691, we decided not to attempt the 'one mask element is set' transform in InstCombine, but this should be a win for any AVX machine. Code comments note that this transform could be extended for other targets / cases.

Differential Revision: http://reviews.llvm.org/D16828

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@260145 91177308-0d34-0410-b5e6-96231b3b80d8
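For context, here is a minimal sketch of what the combine achieves, expressed at the LLVM IR level (illustrative only: the function names below are invented, and the actual transform runs on the SelectionDAG rather than in IR). A non-truncating masked store whose mask has exactly one true lane is equivalent to extracting that lane and storing it as a scalar at the matching byte offset from the base pointer:

; Input: a masked store of <4 x i32> where only lane 2 of the mask is true.
declare void @llvm.masked.store.v4i32(<4 x i32>, <4 x i32>*, i32, <4 x i1>)

define void @masked_one_lane(<4 x i32>* %addr, <4 x i32> %val) {
  call void @llvm.masked.store.v4i32(<4 x i32> %val, <4 x i32>* %addr, i32 4,
                                     <4 x i1> <i1 false, i1 false, i1 true, i1 false>)
  ret void
}

; Conceptual result: extract lane 2 and store it 8 bytes (2 * 4 bytes) past the base pointer.
define void @scalar_store_equivalent(<4 x i32>* %addr, <4 x i32> %val) {
  %elt = extractelement <4 x i32> %val, i32 2
  %base = bitcast <4 x i32>* %addr to i32*
  %eltptr = getelementptr i32, i32* %base, i32 2
  store i32 %elt, i32* %eltptr, align 4
  ret void
}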
1 parent 1d15d57 commit ab75a17

2 files changed: +161 -29 lines

lib/Target/X86/X86ISelLowering.cpp

Lines changed: 75 additions & 2 deletions
@@ -26784,13 +26784,86 @@ static SDValue PerformMLOADCombine(SDNode *N, SelectionDAG &DAG,
   return DCI.CombineTo(N, NewVec, WideLd.getValue(1), true);
 }
 
-/// PerformMSTORECombine - Resolve truncating stores
+
+/// If exactly one element of the mask is set for a non-truncating masked store,
+/// it is a vector extract and scalar store.
+/// Note: It is expected that the degenerate cases of an all-zeros or all-ones
+/// mask have already been optimized in IR, so we don't bother with those here.
+static SDValue reduceMaskedStoreToScalarStore(MaskedStoreSDNode *MS,
+                                              SelectionDAG &DAG) {
+  // TODO: This is not x86-specific, so it could be lifted to DAGCombiner.
+  // However, some target hooks may need to be added to know when the transform
+  // is profitable. Endianness would also have to be considered.
+
+  // If V is a build vector of boolean constants and exactly one of those
+  // constants is true, return the operand index of that true element.
+  // Otherwise, return -1.
+  auto getOneTrueElt = [](SDValue V) {
+    // This needs to be a build vector of booleans.
+    // TODO: Checking for the i1 type matches the IR definition for the mask,
+    // but the mask check could be loosened to i8 or other types. That might
+    // also require checking more than 'allOnesValue'; eg, the x86 HW
+    // instructions only require that the MSB is set for each mask element.
+    // The ISD::MSTORE comments/definition do not specify how the mask operand
+    // is formatted.
+    auto *BV = dyn_cast<BuildVectorSDNode>(V);
+    if (!BV || BV->getValueType(0).getVectorElementType() != MVT::i1)
+      return -1;
+
+    int TrueIndex = -1;
+    unsigned NumElts = BV->getValueType(0).getVectorNumElements();
+    for (unsigned i = 0; i < NumElts; ++i) {
+      const SDValue &Op = BV->getOperand(i);
+      if (Op.getOpcode() == ISD::UNDEF)
+        continue;
+      auto *ConstNode = dyn_cast<ConstantSDNode>(Op);
+      if (!ConstNode)
+        return -1;
+      if (ConstNode->getAPIntValue().isAllOnesValue()) {
+        // If we already found a one, this is too many.
+        if (TrueIndex >= 0)
+          return -1;
+        TrueIndex = i;
+      }
+    }
+    return TrueIndex;
+  };
+
+  int TrueMaskElt = getOneTrueElt(MS->getMask());
+  if (TrueMaskElt < 0)
+    return SDValue();
+
+  SDLoc DL(MS);
+  EVT VT = MS->getValue().getValueType();
+  EVT EltVT = VT.getVectorElementType();
+
+  // Extract the one scalar element that is actually being stored.
+  SDValue ExtractIndex = DAG.getIntPtrConstant(TrueMaskElt, DL);
+  SDValue Extract = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT,
+                                MS->getValue(), ExtractIndex);
+
+  // Store that element at the appropriate offset from the base pointer.
+  SDValue StoreAddr = MS->getBasePtr();
+  unsigned EltSize = EltVT.getStoreSize();
+  if (TrueMaskElt != 0) {
+    unsigned StoreOffset = TrueMaskElt * EltSize;
+    SDValue StoreOffsetVal = DAG.getIntPtrConstant(StoreOffset, DL);
+    StoreAddr = DAG.getNode(ISD::ADD, DL, StoreAddr.getValueType(), StoreAddr,
+                            StoreOffsetVal);
+  }
+  unsigned Alignment = MinAlign(MS->getAlignment(), EltSize);
+  return DAG.getStore(MS->getChain(), DL, Extract, StoreAddr,
+                      MS->getPointerInfo(), MS->isVolatile(),
+                      MS->isNonTemporal(), Alignment);
+}
+
 static SDValue PerformMSTORECombine(SDNode *N, SelectionDAG &DAG,
                                     const X86Subtarget &Subtarget) {
   MaskedStoreSDNode *Mst = cast<MaskedStoreSDNode>(N);
   if (!Mst->isTruncatingStore())
-    return SDValue();
+    return reduceMaskedStoreToScalarStore(Mst, DAG);
 
+  // Resolve truncating stores.
   EVT VT = Mst->getValue().getValueType();
   unsigned NumElems = VT.getVectorNumElements();
   EVT StVT = Mst->getMemoryVT();
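As an informal sanity check of the address and alignment arithmetic above (not part of the commit): in the one_mask_bit_set2 test below, the single true mask lane is 2 and the element type is f32, so EltSize is 4, the scalar store lands at offset 2 * 4 = 8 bytes from the base pointer, and its alignment becomes MinAlign(4, 4) = 4. That matches the expected 'vextractps $2, %xmm0, 8(%rdi)'.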

test/CodeGen/X86/masked_memop.ll

Lines changed: 86 additions & 27 deletions
@@ -991,36 +991,92 @@ define void @test21(<4 x i32> %trigger, <4 x i32>* %addr, <4 x i32> %val) {
   ret void
 }
 
-define void @test22(<4 x i32> %trigger, <4 x i32>* %addr, <4 x i32> %val) {
-; AVX1-LABEL: test22:
-; AVX1: ## BB#0:
-; AVX1-NEXT: movl $-1, %eax
-; AVX1-NEXT: vmovd %eax, %xmm0
-; AVX1-NEXT: vmaskmovps %xmm1, %xmm0, (%rdi)
-; AVX1-NEXT: retq
+; When only one element of the mask is set, reduce to a scalar store.
+
+define void @one_mask_bit_set1(<4 x i32>* %addr, <4 x i32> %val) {
+; AVX-LABEL: one_mask_bit_set1:
+; AVX: ## BB#0:
+; AVX-NEXT: vmovd %xmm0, (%rdi)
+; AVX-NEXT: retq
 ;
-; AVX2-LABEL: test22:
-; AVX2: ## BB#0:
-; AVX2-NEXT: movl $-1, %eax
-; AVX2-NEXT: vmovd %eax, %xmm0
-; AVX2-NEXT: vpmaskmovd %xmm1, %xmm0, (%rdi)
-; AVX2-NEXT: retq
+; AVX512-LABEL: one_mask_bit_set1:
+; AVX512: ## BB#0:
+; AVX512-NEXT: vmovd %xmm0, (%rdi)
+; AVX512-NEXT: retq
+  call void @llvm.masked.store.v4i32(<4 x i32> %val, <4 x i32>* %addr, i32 4, <4 x i1><i1 true, i1 false, i1 false, i1 false>)
+  ret void
+}
+
+; Choose a different element to show that the correct address offset is produced.
+
+define void @one_mask_bit_set2(<4 x float>* %addr, <4 x float> %val) {
+; AVX-LABEL: one_mask_bit_set2:
+; AVX: ## BB#0:
+; AVX-NEXT: vextractps $2, %xmm0, 8(%rdi)
+; AVX-NEXT: retq
 ;
-; AVX512F-LABEL: test22:
-; AVX512F: ## BB#0:
-; AVX512F-NEXT: movl $-1, %eax
-; AVX512F-NEXT: vmovd %eax, %xmm0
-; AVX512F-NEXT: vpmaskmovd %xmm1, %xmm0, (%rdi)
-; AVX512F-NEXT: retq
+; AVX512-LABEL: one_mask_bit_set2:
+; AVX512: ## BB#0:
+; AVX512-NEXT: vextractps $2, %xmm0, 8(%rdi)
+; AVX512-NEXT: retq
+  call void @llvm.masked.store.v4f32(<4 x float> %val, <4 x float>* %addr, i32 4, <4 x i1><i1 false, i1 false, i1 true, i1 false>)
+  ret void
+}
+
+; Choose a different scalar type and a high element of a 256-bit vector because AVX doesn't support those evenly.
+
+define void @one_mask_bit_set3(<4 x i64>* %addr, <4 x i64> %val) {
+; AVX-LABEL: one_mask_bit_set3:
+; AVX: ## BB#0:
+; AVX-NEXT: vextractf128 $1, %ymm0, %xmm0
+; AVX-NEXT: vmovlps %xmm0, 16(%rdi)
+; AVX-NEXT: vzeroupper
+; AVX-NEXT: retq
 ;
-; SKX-LABEL: test22:
-; SKX: ## BB#0:
-; SKX-NEXT: movb $1, %al
-; SKX-NEXT: kmovw %eax, %k1
-; SKX-NEXT: vmovdqu32 %xmm1, (%rdi) {%k1}
-; SKX-NEXT: retq
-  %mask = icmp eq <4 x i32> %trigger, zeroinitializer
-  call void @llvm.masked.store.v4i32(<4 x i32>%val, <4 x i32>* %addr, i32 4, <4 x i1><i1 true, i1 false, i1 false, i1 false>)
+; AVX512-LABEL: one_mask_bit_set3:
+; AVX512: ## BB#0:
+; AVX512-NEXT: vextractf128 $1, %ymm0, %xmm0
+; AVX512-NEXT: vmovq %xmm0, 16(%rdi)
+; AVX512-NEXT: retq
+  call void @llvm.masked.store.v4i64(<4 x i64> %val, <4 x i64>* %addr, i32 4, <4 x i1><i1 false, i1 false, i1 true, i1 false>)
+  ret void
+}
+
+; Choose a different scalar type and a high element of a 256-bit vector because AVX doesn't support those evenly.
+
+define void @one_mask_bit_set4(<4 x double>* %addr, <4 x double> %val) {
+; AVX-LABEL: one_mask_bit_set4:
+; AVX: ## BB#0:
+; AVX-NEXT: vextractf128 $1, %ymm0, %xmm0
+; AVX-NEXT: vmovhpd %xmm0, 24(%rdi)
+; AVX-NEXT: vzeroupper
+; AVX-NEXT: retq
+;
+; AVX512-LABEL: one_mask_bit_set4:
+; AVX512: ## BB#0:
+; AVX512-NEXT: vextractf128 $1, %ymm0, %xmm0
+; AVX512-NEXT: vmovhpd %xmm0, 24(%rdi)
+; AVX512-NEXT: retq
+  call void @llvm.masked.store.v4f64(<4 x double> %val, <4 x double>* %addr, i32 4, <4 x i1><i1 false, i1 false, i1 false, i1 true>)
+  ret void
+}
+
+; Try a 512-bit vector to make sure AVX doesn't die and AVX512 works as expected.
+
+define void @one_mask_bit_set5(<8 x double>* %addr, <8 x double> %val) {
+; AVX-LABEL: one_mask_bit_set5:
+; AVX: ## BB#0:
+; AVX-NEXT: vextractf128 $1, %ymm1, %xmm0
+; AVX-NEXT: vmovlps %xmm0, 48(%rdi)
+; AVX-NEXT: vzeroupper
+; AVX-NEXT: retq
+;
+; AVX512-LABEL: one_mask_bit_set5:
+; AVX512: ## BB#0:
+; AVX512-NEXT: vextractf32x4 $3, %zmm0, %xmm0
+; AVX512-NEXT: vmovlpd %xmm0, 48(%rdi)
+; AVX512-NEXT: retq
+  call void @llvm.masked.store.v8f64(<8 x double> %val, <8 x double>* %addr, i32 4, <8 x i1><i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 true, i1 false>)
   ret void
 }
 
@@ -1030,8 +1086,10 @@ declare <2 x i32> @llvm.masked.load.v2i32(<2 x i32>*, i32, <2 x i1>, <2 x i32>)
 declare void @llvm.masked.store.v16i32(<16 x i32>, <16 x i32>*, i32, <16 x i1>)
 declare void @llvm.masked.store.v8i32(<8 x i32>, <8 x i32>*, i32, <8 x i1>)
 declare void @llvm.masked.store.v4i32(<4 x i32>, <4 x i32>*, i32, <4 x i1>)
+declare void @llvm.masked.store.v4i64(<4 x i64>, <4 x i64>*, i32, <4 x i1>)
 declare void @llvm.masked.store.v2f32(<2 x float>, <2 x float>*, i32, <2 x i1>)
 declare void @llvm.masked.store.v2i32(<2 x i32>, <2 x i32>*, i32, <2 x i1>)
+declare void @llvm.masked.store.v4f32(<4 x float>, <4 x float>*, i32, <4 x i1>)
 declare void @llvm.masked.store.v16f32(<16 x float>, <16 x float>*, i32, <16 x i1>)
 declare void @llvm.masked.store.v16f32p(<16 x float>*, <16 x float>**, i32, <16 x i1>)
 declare <16 x float> @llvm.masked.load.v16f32(<16 x float>*, i32, <16 x i1>, <16 x float>)
@@ -1043,6 +1101,7 @@ declare <8 x double> @llvm.masked.load.v8f64(<8 x double>*, i32, <8 x i1>, <8 x double>)
 declare <4 x double> @llvm.masked.load.v4f64(<4 x double>*, i32, <4 x i1>, <4 x double>)
 declare <2 x double> @llvm.masked.load.v2f64(<2 x double>*, i32, <2 x i1>, <2 x double>)
 declare void @llvm.masked.store.v8f64(<8 x double>, <8 x double>*, i32, <8 x i1>)
+declare void @llvm.masked.store.v4f64(<4 x double>, <4 x double>*, i32, <4 x i1>)
 declare void @llvm.masked.store.v2f64(<2 x double>, <2 x double>*, i32, <2 x i1>)
 declare void @llvm.masked.store.v2i64(<2 x i64>, <2 x i64>*, i32, <2 x i1>)
