
Commit c7dde5f

[x86] convert masked load of exactly one element to scalar load
This is the load counterpart to the store optimization that was added in:
http://reviews.llvm.org/rL260145

llvm-svn: 260325
1 parent 1e5d7e2 commit c7dde5f
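
The transform can be illustrated in LLVM IR (this sketch is not part of the commit; the function names are illustrative, and the actual combine runs on the SelectionDAG rather than in IR): a masked load whose mask has exactly one true lane yields the same value as a scalar load of that lane inserted into the pass-through operand.

declare <4 x i32> @llvm.masked.load.v4i32(<4 x i32>*, i32, <4 x i1>, <4 x i32>)

; A masked load with only lane 2 enabled...
define <4 x i32> @single_lane_masked(<4 x i32>* %p, <4 x i32> %passthru) {
  %v = call <4 x i32> @llvm.masked.load.v4i32(<4 x i32>* %p, i32 4, <4 x i1> <i1 false, i1 false, i1 true, i1 false>, <4 x i32> %passthru)
  ret <4 x i32> %v
}

; ...produces the same result as loading that one element and inserting it:
define <4 x i32> @single_lane_scalar(<4 x i32>* %p, <4 x i32> %passthru) {
  %ep = getelementptr inbounds <4 x i32>, <4 x i32>* %p, i64 0, i64 2
  %e = load i32, i32* %ep, align 4
  %v = insertelement <4 x i32> %passthru, i32 %e, i32 2
  ret <4 x i32> %v
}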

2 files changed: 158 additions, 0 deletions


llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 43 additions & 0 deletions
@@ -26748,10 +26748,53 @@ static int getOneTrueElt(SDValue V) {
   return TrueIndex;
 }
 
+/// If exactly one element of the mask is set for a non-extending masked load,
+/// it is a scalar load and vector insert.
+/// Note: It is expected that the degenerate cases of an all-zeros or all-ones
+/// mask have already been optimized in IR, so we don't bother with those here.
+static SDValue
+reduceMaskedLoadToScalarLoad(MaskedLoadSDNode *ML, SelectionDAG &DAG,
+                             TargetLowering::DAGCombinerInfo &DCI) {
+  // FIXME: Refactor shared/similar logic with reduceMaskedStoreToScalarStore().
+
+  // TODO: This is not x86-specific, so it could be lifted to DAGCombiner.
+  // However, some target hooks may need to be added to know when the transform
+  // is profitable. Endianness would also have to be considered.
+
+  int TrueMaskElt = getOneTrueElt(ML->getMask());
+  if (TrueMaskElt < 0)
+    return SDValue();
+
+  SDLoc DL(ML);
+  EVT VT = ML->getValueType(0);
+  EVT EltVT = VT.getVectorElementType();
+
+  // Load the one scalar element that is specified by the mask using the
+  // appropriate offset from the base pointer.
+  SDValue Addr = ML->getBasePtr();
+  if (TrueMaskElt != 0) {
+    unsigned Offset = TrueMaskElt * EltVT.getStoreSize();
+    Addr = DAG.getMemBasePlusOffset(Addr, Offset, DL);
+  }
+  unsigned Alignment = MinAlign(ML->getAlignment(), EltVT.getStoreSize());
+  SDValue Load = DAG.getLoad(EltVT, DL, ML->getChain(), Addr,
+                             ML->getPointerInfo(), ML->isVolatile(),
+                             ML->isNonTemporal(), ML->isInvariant(), Alignment);
+
+  // Insert the loaded element into the appropriate place in the vector.
+  SDValue InsertIndex = DAG.getIntPtrConstant(TrueMaskElt, DL);
+  SDValue Insert = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT, ML->getSrc0(),
+                               Load, InsertIndex);
+  return DCI.CombineTo(ML, Insert, Load.getValue(1), true);
+}
+
 static SDValue PerformMLOADCombine(SDNode *N, SelectionDAG &DAG,
                                    TargetLowering::DAGCombinerInfo &DCI,
                                    const X86Subtarget &Subtarget) {
   MaskedLoadSDNode *Mld = cast<MaskedLoadSDNode>(N);
+  if (Mld->getExtensionType() == ISD::NON_EXTLOAD)
+    return reduceMaskedLoadToScalarLoad(Mld, DAG, DCI);
+
   if (Mld->getExtensionType() != ISD::SEXTLOAD)
     return SDValue();
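
The scalar load address is the base pointer plus TrueMaskElt * EltVT.getStoreSize(). That arithmetic is visible in the tests below: the <4 x float> case with lane 2 set loads from 8(%rdi) (2 * 4 bytes), the <4 x i64> case with lane 2 set loads from 16(%rdi) (2 * 8 bytes), and the <4 x double> case with lane 3 set loads from 24(%rdi) (3 * 8 bytes).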

llvm/test/CodeGen/X86/masked_memop.ll

Lines changed: 115 additions & 0 deletions
@@ -1080,9 +1080,124 @@ define void @one_mask_bit_set5(<8 x double>* %addr, <8 x double> %val) {
   ret void
 }
 
+; When only one element of the mask is set, reduce to a scalar load.
+
+define <4 x i32> @load_one_mask_bit_set1(<4 x i32>* %addr, <4 x i32> %val) {
+; AVX-LABEL: load_one_mask_bit_set1:
+; AVX: ## BB#0:
+; AVX-NEXT: vpinsrd $0, (%rdi), %xmm0, %xmm0
+; AVX-NEXT: retq
+;
+; AVX512-LABEL: load_one_mask_bit_set1:
+; AVX512: ## BB#0:
+; AVX512-NEXT: vpinsrd $0, (%rdi), %xmm0, %xmm0
+; AVX512-NEXT: retq
+  %res = call <4 x i32> @llvm.masked.load.v4i32(<4 x i32>* %addr, i32 4, <4 x i1><i1 true, i1 false, i1 false, i1 false>, <4 x i32> %val)
+  ret <4 x i32> %res
+}
+
+; Choose a different element to show that the correct address offset is produced.
+
+define <4 x float> @load_one_mask_bit_set2(<4 x float>* %addr, <4 x float> %val) {
+; AVX-LABEL: load_one_mask_bit_set2:
+; AVX: ## BB#0:
+; AVX-NEXT: vinsertps $32, 8(%rdi), %xmm0, %xmm0 ## xmm0 = xmm0[0,1],mem[0],xmm0[3]
+; AVX-NEXT: retq
+;
+; AVX512-LABEL: load_one_mask_bit_set2:
+; AVX512: ## BB#0:
+; AVX512-NEXT: vinsertps $32, 8(%rdi), %xmm0, %xmm0 ## xmm0 = xmm0[0,1],mem[0],xmm0[3]
+; AVX512-NEXT: retq
+  %res = call <4 x float> @llvm.masked.load.v4f32(<4 x float>* %addr, i32 4, <4 x i1><i1 false, i1 false, i1 true, i1 false>, <4 x float> %val)
+  ret <4 x float> %res
+}
+
+; Choose a different scalar type and a high element of a 256-bit vector because AVX doesn't support those evenly.
+
+define <4 x i64> @load_one_mask_bit_set3(<4 x i64>* %addr, <4 x i64> %val) {
+; AVX1-LABEL: load_one_mask_bit_set3:
+; AVX1: ## BB#0:
+; AVX1-NEXT: vextractf128 $1, %ymm0, %xmm1
+; AVX1-NEXT: vpinsrq $0, 16(%rdi), %xmm1, %xmm1
+; AVX1-NEXT: vinsertf128 $1, %xmm1, %ymm0, %ymm0
+; AVX1-NEXT: retq
+;
+; AVX2-LABEL: load_one_mask_bit_set3:
+; AVX2: ## BB#0:
+; AVX2-NEXT: vextracti128 $1, %ymm0, %xmm1
+; AVX2-NEXT: vpinsrq $0, 16(%rdi), %xmm1, %xmm1
+; AVX2-NEXT: vinserti128 $1, %xmm1, %ymm0, %ymm0
+; AVX2-NEXT: retq
+;
+; AVX512F-LABEL: load_one_mask_bit_set3:
+; AVX512F: ## BB#0:
+; AVX512F-NEXT: vextracti128 $1, %ymm0, %xmm1
+; AVX512F-NEXT: vpinsrq $0, 16(%rdi), %xmm1, %xmm1
+; AVX512F-NEXT: vinserti128 $1, %xmm1, %ymm0, %ymm0
+; AVX512F-NEXT: retq
+;
+; SKX-LABEL: load_one_mask_bit_set3:
+; SKX: ## BB#0:
+; SKX-NEXT: vextracti128 $1, %ymm0, %xmm1
+; SKX-NEXT: vpinsrq $0, 16(%rdi), %xmm1, %xmm1
+; SKX-NEXT: vinserti32x4 $1, %xmm1, %ymm0, %ymm0
+; SKX-NEXT: retq
+  %res = call <4 x i64> @llvm.masked.load.v4i64(<4 x i64>* %addr, i32 4, <4 x i1><i1 false, i1 false, i1 true, i1 false>, <4 x i64> %val)
+  ret <4 x i64> %res
+}
+
+; Choose a different scalar type and a high element of a 256-bit vector because AVX doesn't support those evenly.
+
+define <4 x double> @load_one_mask_bit_set4(<4 x double>* %addr, <4 x double> %val) {
+; AVX-LABEL: load_one_mask_bit_set4:
+; AVX: ## BB#0:
+; AVX-NEXT: vextractf128 $1, %ymm0, %xmm1
+; AVX-NEXT: vmovhpd 24(%rdi), %xmm1, %xmm1 ## xmm1 = xmm1[0],mem[0]
+; AVX-NEXT: vinsertf128 $1, %xmm1, %ymm0, %ymm0
+; AVX-NEXT: retq
+;
+; AVX512F-LABEL: load_one_mask_bit_set4:
+; AVX512F: ## BB#0:
+; AVX512F-NEXT: vextractf128 $1, %ymm0, %xmm1
+; AVX512F-NEXT: vmovhpd 24(%rdi), %xmm1, %xmm1 ## xmm1 = xmm1[0],mem[0]
+; AVX512F-NEXT: vinsertf128 $1, %xmm1, %ymm0, %ymm0
+; AVX512F-NEXT: retq
+;
+; SKX-LABEL: load_one_mask_bit_set4:
+; SKX: ## BB#0:
+; SKX-NEXT: vextractf128 $1, %ymm0, %xmm1
+; SKX-NEXT: vmovhpd 24(%rdi), %xmm1, %xmm1 ## xmm1 = xmm1[0],mem[0]
+; SKX-NEXT: vinsertf32x4 $1, %xmm1, %ymm0, %ymm0
+; SKX-NEXT: retq
+  %res = call <4 x double> @llvm.masked.load.v4f64(<4 x double>* %addr, i32 4, <4 x i1><i1 false, i1 false, i1 false, i1 true>, <4 x double> %val)
+  ret <4 x double> %res
+}
+
+; Try a 512-bit vector to make sure AVX doesn't die and AVX512 works as expected.
+
+define <8 x double> @load_one_mask_bit_set5(<8 x double>* %addr, <8 x double> %val) {
+; AVX-LABEL: load_one_mask_bit_set5:
+; AVX: ## BB#0:
+; AVX-NEXT: vextractf128 $1, %ymm1, %xmm2
+; AVX-NEXT: vmovsd {{.*#+}} xmm3 = mem[0],zero
+; AVX-NEXT: vunpcklpd {{.*#+}} xmm2 = xmm2[0],xmm3[0]
+; AVX-NEXT: vinsertf128 $1, %xmm2, %ymm1, %ymm1
+; AVX-NEXT: retq
+;
+; AVX512-LABEL: load_one_mask_bit_set5:
+; AVX512: ## BB#0:
+; AVX512-NEXT: vextractf32x4 $3, %zmm0, %xmm1
+; AVX512-NEXT: vmovhpd {{.*#+}} xmm1 = xmm1[0],mem[0]
+; AVX512-NEXT: vinsertf32x4 $3, %xmm1, %zmm0, %zmm0
+; AVX512-NEXT: retq
+  %res = call <8 x double> @llvm.masked.load.v8f64(<8 x double>* %addr, i32 4, <8 x i1><i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 true>, <8 x double> %val)
+  ret <8 x double> %res
+}
+
 declare <16 x i32> @llvm.masked.load.v16i32(<16 x i32>*, i32, <16 x i1>, <16 x i32>)
 declare <4 x i32> @llvm.masked.load.v4i32(<4 x i32>*, i32, <4 x i1>, <4 x i32>)
 declare <2 x i32> @llvm.masked.load.v2i32(<2 x i32>*, i32, <2 x i1>, <2 x i32>)
+declare <4 x i64> @llvm.masked.load.v4i64(<4 x i64>*, i32, <4 x i1>, <4 x i64>)
 declare void @llvm.masked.store.v16i32(<16 x i32>, <16 x i32>*, i32, <16 x i1>)
 declare void @llvm.masked.store.v8i32(<8 x i32>, <8 x i32>*, i32, <8 x i1>)
 declare void @llvm.masked.store.v4i32(<4 x i32>, <4 x i32>*, i32, <4 x i1>)
