Commit 619d7dc

[DAGCombiner] recognize shuffle (shuffle X, Mask0), Mask --> splat X
We get the simple cases of this via demanded elements and other folds, but that doesn't work if the values have >1 use, so add a dedicated match for the pattern.

We already have this transform in IR, but it doesn't help the motivating x86 tests (based on PR42024) because the shuffles don't exist until after legalization and other combines have happened. The AArch64 test shows a minimal IR example of the problem.

Differential Revision: https://reviews.llvm.org/D75348
1 parent 624dbfc commit 619d7dc
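To make the fold concrete, here is a small worked example (a standalone sketch, not code from this commit): with an inner shuffle mask of [0, 0, 3, 2] and an outer mask of [1, 0, 1, 0], every demanded lane of the outer result traces back to element 0 of X, so the shuffle pair is equivalent to a single splat of X[0]. The helper below mirrors the mask composition the new combine performs, using plain std::vector masks rather than the SelectionDAG types; the name composeToSplat is made up for illustration.

```cpp
// Illustrative sketch of the two-level mask tracing behind
// "shuffle (shuffle X, Mask0), Mask --> splat X"; -1 marks an undef lane.
#include <cassert>
#include <iostream>
#include <vector>

// Returns the splatted source index of X, or -1 if the composed shuffle is
// not a splat of a single element.
int composeToSplat(const std::vector<int> &InnerMask,
                   const std::vector<int> &OuterMask) {
  assert(InnerMask.size() == OuterMask.size() && "Mask length mismatch");
  int SplatIndex = -1;
  for (int OuterElt : OuterMask) {
    if (OuterElt == -1)
      continue;                         // undef lane stays undef
    int InnerElt = InnerMask[OuterElt]; // peek through the inner shuffle
    if (InnerElt == -1)
      continue;
    if (SplatIndex == -1)
      SplatIndex = InnerElt;            // first demanded source element
    else if (SplatIndex != InnerElt)
      return -1;                        // two different sources: not a splat
  }
  return SplatIndex;
}

int main() {
  // shuf (shuf X, undef, [0,0,3,2]), undef, [1,0,1,0] --> splat of X[0]
  std::cout << composeToSplat({0, 0, 3, 2}, {1, 0, 1, 0}) << '\n'; // prints 0
  // Outer lane 3 reads X[2] while the other lanes read X[0]: not a splat.
  std::cout << composeToSplat({0, 0, 3, 2}, {1, 0, 1, 3}) << '\n'; // prints -1
}
```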

File tree: 4 files changed, +104 −62 lines

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (54 additions, 0 deletions)

@@ -30,6 +30,7 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/Analysis/MemoryLocation.h"
+#include "llvm/Analysis/VectorUtils.h"
 #include "llvm/CodeGen/DAGCombine.h"
 #include "llvm/CodeGen/ISDOpcodes.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"

@@ -19259,6 +19260,56 @@ static SDValue combineShuffleOfSplatVal(ShuffleVectorSDNode *Shuf,
                               NewMask);
 }

+/// Combine shuffle of shuffle of the form:
+/// shuf (shuf X, undef, InnerMask), undef, OuterMask --> splat X
+static SDValue formSplatFromShuffles(ShuffleVectorSDNode *OuterShuf,
+                                     SelectionDAG &DAG) {
+  if (!OuterShuf->getOperand(1).isUndef())
+    return SDValue();
+  auto *InnerShuf = dyn_cast<ShuffleVectorSDNode>(OuterShuf->getOperand(0));
+  if (!InnerShuf || !InnerShuf->getOperand(1).isUndef())
+    return SDValue();
+
+  ArrayRef<int> OuterMask = OuterShuf->getMask();
+  ArrayRef<int> InnerMask = InnerShuf->getMask();
+  unsigned NumElts = OuterMask.size();
+  assert(NumElts == InnerMask.size() && "Mask length mismatch");
+  SmallVector<int, 32> CombinedMask(NumElts, -1);
+  int SplatIndex = -1;
+  for (unsigned i = 0; i != NumElts; ++i) {
+    // Undef lanes remain undef.
+    int OuterMaskElt = OuterMask[i];
+    if (OuterMaskElt == -1)
+      continue;
+
+    // Peek through the shuffle masks to get the underlying source element.
+    int InnerMaskElt = InnerMask[OuterMaskElt];
+    if (InnerMaskElt == -1)
+      continue;
+
+    // Initialize the splatted element.
+    if (SplatIndex == -1)
+      SplatIndex = InnerMaskElt;
+
+    // Non-matching index - this is not a splat.
+    if (SplatIndex != InnerMaskElt)
+      return SDValue();
+
+    CombinedMask[i] = InnerMaskElt;
+  }
+  assert(all_of(CombinedMask, [](int M) { return M == -1; }) ||
+         getSplatIndex(CombinedMask) != -1 && "Expected a splat mask");
+
+  // TODO: The transform may be a win even if the mask is not legal.
+  EVT VT = OuterShuf->getValueType(0);
+  assert(VT == InnerShuf->getValueType(0) && "Expected matching shuffle types");
+  if (!DAG.getTargetLoweringInfo().isShuffleMaskLegal(CombinedMask, VT))
+    return SDValue();
+
+  return DAG.getVectorShuffle(VT, SDLoc(OuterShuf), InnerShuf->getOperand(0),
+                              InnerShuf->getOperand(1), CombinedMask);
+}
+
 /// If the shuffle mask is taking exactly one element from the first vector
 /// operand and passing through all other elements from the second vector
 /// operand, return the index of the mask element that is choosing an element

@@ -19417,6 +19468,9 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
   if (SDValue V = combineShuffleOfSplatVal(SVN, DAG))
     return V;

+  if (SDValue V = formSplatFromShuffles(SVN, DAG))
+    return V;
+
   // If it is a splat, check if the argument vector is another splat or a
   // build_vector.
   if (SVN->isSplat() && SVN->getSplatIndex() < (int)NumElts) {

llvm/test/CodeGen/AArch64/arm64-dup.ll (0 additions, 2 deletions)

@@ -449,8 +449,6 @@ define void @disguised_dup(<4 x float> %x, <4 x float>* %p1, <4 x float>* %p2) {
 ; CHECK-NEXT:    dup.4s v1, v0[0]
 ; CHECK-NEXT:    ext.16b v0, v0, v0, #12
 ; CHECK-NEXT:    ext.16b v0, v0, v1, #8
-; CHECK-NEXT:    zip2.4s v1, v0, v0
-; CHECK-NEXT:    ext.16b v1, v0, v1, #12
 ; CHECK-NEXT:    str q0, [x0]
 ; CHECK-NEXT:    str q1, [x1]
 ; CHECK-NEXT:    ret

llvm/test/CodeGen/X86/vector-reduce-mul.ll (4 additions, 4 deletions)

@@ -811,7 +811,7 @@ define i32 @test_v4i32(<4 x i32> %a0) {
 ; SSE2-LABEL: test_v4i32:
 ; SSE2:       # %bb.0:
 ; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; SSE2-NEXT:    pshufd {{.*#+}} xmm2 = xmm0[3,3,1,1]
+; SSE2-NEXT:    pshufd {{.*#+}} xmm2 = xmm0[3,1,2,3]
 ; SSE2-NEXT:    pshufd {{.*#+}} xmm3 = xmm0[1,1,3,3]
 ; SSE2-NEXT:    pmuludq %xmm2, %xmm3
 ; SSE2-NEXT:    pmuludq %xmm0, %xmm1

@@ -858,7 +858,7 @@ define i32 @test_v8i32(<8 x i32> %a0) {
 ; SSE2-NEXT:    pmuludq %xmm1, %xmm0
 ; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
 ; SSE2-NEXT:    pmuludq %xmm0, %xmm1
-; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm3[2,2,0,0]
+; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm3[2,0,2,2]
 ; SSE2-NEXT:    pmuludq %xmm3, %xmm0
 ; SSE2-NEXT:    pmuludq %xmm1, %xmm0
 ; SSE2-NEXT:    movd %xmm0, %eax

@@ -928,7 +928,7 @@ define i32 @test_v16i32(<16 x i32> %a0) {
 ; SSE2-NEXT:    pmuludq %xmm1, %xmm2
 ; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
 ; SSE2-NEXT:    pmuludq %xmm0, %xmm1
-; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm2[2,2,0,0]
+; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm2[2,0,2,2]
 ; SSE2-NEXT:    pmuludq %xmm2, %xmm0
 ; SSE2-NEXT:    pmuludq %xmm1, %xmm0
 ; SSE2-NEXT:    movd %xmm0, %eax

@@ -1018,7 +1018,7 @@ define i32 @test_v32i32(<32 x i32> %a0) {
 ; SSE2-NEXT:    pmuludq %xmm0, %xmm1
 ; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm1[2,3,0,1]
 ; SSE2-NEXT:    pmuludq %xmm1, %xmm0
-; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm11[2,2,0,0]
+; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm11[2,0,2,2]
 ; SSE2-NEXT:    pmuludq %xmm11, %xmm1
 ; SSE2-NEXT:    pmuludq %xmm0, %xmm1
 ; SSE2-NEXT:    movd %xmm1, %eax

llvm/test/CodeGen/X86/x86-interleaved-access.ll (46 additions, 56 deletions)

@@ -1826,45 +1826,40 @@ define void @splat4_v8i32_load_store(<8 x i32>* %s, <32 x i32>* %d) {
 define void @splat4_v4f64_load_store(<4 x double>* %s, <16 x double>* %d) {
 ; AVX1-LABEL: splat4_v4f64_load_store:
 ; AVX1:       # %bb.0:
-; AVX1-NEXT:    vmovupd (%rdi), %ymm0
-; AVX1-NEXT:    vperm2f128 {{.*#+}} ymm1 = ymm0[0,1,0,1]
-; AVX1-NEXT:    vperm2f128 {{.*#+}} ymm0 = ymm0[2,3,2,3]
-; AVX1-NEXT:    vmovddup {{.*#+}} ymm2 = ymm1[0,0,2,2]
-; AVX1-NEXT:    vmovddup {{.*#+}} ymm3 = ymm0[0,0,2,2]
-; AVX1-NEXT:    vpermilpd {{.*#+}} ymm1 = ymm1[1,1,3,3]
-; AVX1-NEXT:    vpermilpd {{.*#+}} ymm0 = ymm0[1,1,3,3]
-; AVX1-NEXT:    vmovupd %ymm0, 96(%rsi)
-; AVX1-NEXT:    vmovupd %ymm3, 64(%rsi)
-; AVX1-NEXT:    vmovupd %ymm1, 32(%rsi)
-; AVX1-NEXT:    vmovupd %ymm2, (%rsi)
+; AVX1-NEXT:    vbroadcastsd (%rdi), %ymm0
+; AVX1-NEXT:    vbroadcastsd 16(%rdi), %ymm1
+; AVX1-NEXT:    vbroadcastsd 8(%rdi), %ymm2
+; AVX1-NEXT:    vbroadcastsd 24(%rdi), %ymm3
+; AVX1-NEXT:    vmovups %ymm3, 96(%rsi)
+; AVX1-NEXT:    vmovups %ymm1, 64(%rsi)
+; AVX1-NEXT:    vmovups %ymm2, 32(%rsi)
+; AVX1-NEXT:    vmovups %ymm0, (%rsi)
 ; AVX1-NEXT:    vzeroupper
 ; AVX1-NEXT:    retq
 ;
 ; AVX2-LABEL: splat4_v4f64_load_store:
 ; AVX2:       # %bb.0:
-; AVX2-NEXT:    vmovups (%rdi), %ymm0
-; AVX2-NEXT:    vbroadcastsd (%rdi), %ymm1
-; AVX2-NEXT:    vpermpd {{.*#+}} ymm2 = ymm0[2,2,2,2]
-; AVX2-NEXT:    vpermpd {{.*#+}} ymm3 = ymm0[1,1,1,1]
-; AVX2-NEXT:    vpermpd {{.*#+}} ymm0 = ymm0[3,3,3,3]
-; AVX2-NEXT:    vmovups %ymm0, 96(%rsi)
-; AVX2-NEXT:    vmovups %ymm2, 64(%rsi)
-; AVX2-NEXT:    vmovups %ymm3, 32(%rsi)
-; AVX2-NEXT:    vmovups %ymm1, (%rsi)
+; AVX2-NEXT:    vbroadcastsd (%rdi), %ymm0
+; AVX2-NEXT:    vbroadcastsd 16(%rdi), %ymm1
+; AVX2-NEXT:    vbroadcastsd 8(%rdi), %ymm2
+; AVX2-NEXT:    vbroadcastsd 24(%rdi), %ymm3
+; AVX2-NEXT:    vmovups %ymm3, 96(%rsi)
+; AVX2-NEXT:    vmovups %ymm1, 64(%rsi)
+; AVX2-NEXT:    vmovups %ymm2, 32(%rsi)
+; AVX2-NEXT:    vmovups %ymm0, (%rsi)
 ; AVX2-NEXT:    vzeroupper
 ; AVX2-NEXT:    retq
 ;
 ; AVX512-LABEL: splat4_v4f64_load_store:
 ; AVX512:       # %bb.0:
-; AVX512-NEXT:    vmovups (%rdi), %ymm0
-; AVX512-NEXT:    vbroadcastsd (%rdi), %ymm1
-; AVX512-NEXT:    vpermpd {{.*#+}} ymm2 = ymm0[2,2,2,2]
-; AVX512-NEXT:    vpermpd {{.*#+}} ymm3 = ymm0[1,1,1,1]
-; AVX512-NEXT:    vpermpd {{.*#+}} ymm0 = ymm0[3,3,3,3]
+; AVX512-NEXT:    vbroadcastsd (%rdi), %ymm0
+; AVX512-NEXT:    vbroadcastsd 16(%rdi), %ymm1
+; AVX512-NEXT:    vbroadcastsd 8(%rdi), %ymm2
+; AVX512-NEXT:    vbroadcastsd 24(%rdi), %ymm3
+; AVX512-NEXT:    vinsertf64x4 $1, %ymm2, %zmm0, %zmm0
 ; AVX512-NEXT:    vinsertf64x4 $1, %ymm3, %zmm1, %zmm1
-; AVX512-NEXT:    vinsertf64x4 $1, %ymm0, %zmm2, %zmm0
-; AVX512-NEXT:    vmovups %zmm0, 64(%rsi)
-; AVX512-NEXT:    vmovups %zmm1, (%rsi)
+; AVX512-NEXT:    vmovups %zmm1, 64(%rsi)
+; AVX512-NEXT:    vmovups %zmm0, (%rsi)
 ; AVX512-NEXT:    vzeroupper
 ; AVX512-NEXT:    retq
 %x = load <4 x double>, <4 x double>* %s, align 8

@@ -1878,45 +1873,40 @@ define void @splat4_v4f64_load_store(<4 x double>* %s, <16 x double>* %d) {
 define void @splat4_v4i64_load_store(<4 x i64>* %s, <16 x i64>* %d) {
 ; AVX1-LABEL: splat4_v4i64_load_store:
 ; AVX1:       # %bb.0:
-; AVX1-NEXT:    vmovupd (%rdi), %ymm0
-; AVX1-NEXT:    vperm2f128 {{.*#+}} ymm1 = ymm0[0,1,0,1]
-; AVX1-NEXT:    vperm2f128 {{.*#+}} ymm0 = ymm0[2,3,2,3]
-; AVX1-NEXT:    vmovddup {{.*#+}} ymm2 = ymm1[0,0,2,2]
-; AVX1-NEXT:    vmovddup {{.*#+}} ymm3 = ymm0[0,0,2,2]
-; AVX1-NEXT:    vpermilpd {{.*#+}} ymm1 = ymm1[1,1,3,3]
-; AVX1-NEXT:    vpermilpd {{.*#+}} ymm0 = ymm0[1,1,3,3]
-; AVX1-NEXT:    vmovupd %ymm0, 96(%rsi)
-; AVX1-NEXT:    vmovupd %ymm3, 64(%rsi)
-; AVX1-NEXT:    vmovupd %ymm1, 32(%rsi)
-; AVX1-NEXT:    vmovupd %ymm2, (%rsi)
+; AVX1-NEXT:    vbroadcastsd (%rdi), %ymm0
+; AVX1-NEXT:    vbroadcastsd 16(%rdi), %ymm1
+; AVX1-NEXT:    vbroadcastsd 8(%rdi), %ymm2
+; AVX1-NEXT:    vbroadcastsd 24(%rdi), %ymm3
+; AVX1-NEXT:    vmovups %ymm3, 96(%rsi)
+; AVX1-NEXT:    vmovups %ymm1, 64(%rsi)
+; AVX1-NEXT:    vmovups %ymm2, 32(%rsi)
+; AVX1-NEXT:    vmovups %ymm0, (%rsi)
 ; AVX1-NEXT:    vzeroupper
 ; AVX1-NEXT:    retq
 ;
 ; AVX2-LABEL: splat4_v4i64_load_store:
 ; AVX2:       # %bb.0:
-; AVX2-NEXT:    vmovups (%rdi), %ymm0
-; AVX2-NEXT:    vbroadcastsd (%rdi), %ymm1
-; AVX2-NEXT:    vpermpd {{.*#+}} ymm2 = ymm0[2,2,2,2]
-; AVX2-NEXT:    vpermpd {{.*#+}} ymm3 = ymm0[1,1,1,1]
-; AVX2-NEXT:    vpermpd {{.*#+}} ymm0 = ymm0[3,3,3,3]
-; AVX2-NEXT:    vmovups %ymm0, 96(%rsi)
-; AVX2-NEXT:    vmovups %ymm2, 64(%rsi)
-; AVX2-NEXT:    vmovups %ymm3, 32(%rsi)
-; AVX2-NEXT:    vmovups %ymm1, (%rsi)
+; AVX2-NEXT:    vbroadcastsd (%rdi), %ymm0
+; AVX2-NEXT:    vbroadcastsd 16(%rdi), %ymm1
+; AVX2-NEXT:    vbroadcastsd 8(%rdi), %ymm2
+; AVX2-NEXT:    vbroadcastsd 24(%rdi), %ymm3
+; AVX2-NEXT:    vmovups %ymm3, 96(%rsi)
+; AVX2-NEXT:    vmovups %ymm1, 64(%rsi)
+; AVX2-NEXT:    vmovups %ymm2, 32(%rsi)
+; AVX2-NEXT:    vmovups %ymm0, (%rsi)
 ; AVX2-NEXT:    vzeroupper
 ; AVX2-NEXT:    retq
 ;
 ; AVX512-LABEL: splat4_v4i64_load_store:
 ; AVX512:       # %bb.0:
-; AVX512-NEXT:    vmovups (%rdi), %ymm0
-; AVX512-NEXT:    vbroadcastsd (%rdi), %ymm1
-; AVX512-NEXT:    vpermpd {{.*#+}} ymm2 = ymm0[2,2,2,2]
-; AVX512-NEXT:    vpermpd {{.*#+}} ymm3 = ymm0[1,1,1,1]
-; AVX512-NEXT:    vpermpd {{.*#+}} ymm0 = ymm0[3,3,3,3]
+; AVX512-NEXT:    vbroadcastsd (%rdi), %ymm0
+; AVX512-NEXT:    vbroadcastsd 16(%rdi), %ymm1
+; AVX512-NEXT:    vbroadcastsd 8(%rdi), %ymm2
+; AVX512-NEXT:    vbroadcastsd 24(%rdi), %ymm3
+; AVX512-NEXT:    vinsertf64x4 $1, %ymm2, %zmm0, %zmm0
 ; AVX512-NEXT:    vinsertf64x4 $1, %ymm3, %zmm1, %zmm1
-; AVX512-NEXT:    vinsertf64x4 $1, %ymm0, %zmm2, %zmm0
-; AVX512-NEXT:    vmovups %zmm0, 64(%rsi)
-; AVX512-NEXT:    vmovups %zmm1, (%rsi)
+; AVX512-NEXT:    vmovups %zmm1, 64(%rsi)
+; AVX512-NEXT:    vmovups %zmm0, (%rsi)
 ; AVX512-NEXT:    vzeroupper
 ; AVX512-NEXT:    retq
 %x = load <4 x i64>, <4 x i64>* %s, align 8
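
For context on what these splat4 tests compute: judging from the CHECK lines above, each of the four loaded elements is broadcast into its own contiguous group of four destination elements, which is why the improved code is simply four vbroadcastsd loads followed by four unaligned vector stores. The scalar loop below is a rough equivalent inferred from that output (an illustrative sketch, not the test's actual IR):

```cpp
#include <cstddef>

// Hypothetical scalar equivalent of splat4_v4f64_load_store: s points at 4
// doubles, d at 16 doubles; element i of s fills d[4*i .. 4*i+3].
void splat4_v4f64_load_store(const double *s, double *d) {
  for (std::size_t i = 0; i != 4; ++i)   // one group per source element
    for (std::size_t j = 0; j != 4; ++j) // four copies of that element
      d[i * 4 + j] = s[i];
}
```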
