Skip to content

Commit bb62ec2

Browse files
committed
[TTI][WebAssembly] Pairwise reduction expansion
WebAssembly doesn't support horizontal operations, nor does it have a way of expressing fast-math or reassoc flags, so runtimes are currently unable to use pairwise operations when generating code from the existing shuffle patterns. This patch allows the backend to select which (arbitrary) shuffle pattern is used per reduction intrinsic. The default behaviour is the same as the existing one: splitting the vector into a top and a bottom half. The other pattern introduced is a pairwise shuffle. WebAssembly enables pairwise reductions for int/fp add/sub.
1 parent b86a9c5 commit bb62ec2

File tree

9 files changed

+208
-16
lines changed

9 files changed

+208
-16
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1696,6 +1696,13 @@ class TargetTransformInfo {
16961696
/// into a shuffle sequence.
16971697
bool shouldExpandReduction(const IntrinsicInst *II) const;
16981698

1699+
enum struct ReductionShuffle { SplitHalf, Pairwise };
1700+
1701+
/// \returns The shuffle sequence pattern used to expand the given reduction
1702+
/// intrinsic.
1703+
ReductionShuffle
1704+
getPreferredExpandedReductionShuffle(const IntrinsicInst *II) const;
1705+
16991706
/// \returns the size cost of rematerializing a GlobalValue address relative
17001707
/// to a stack reload.
17011708
unsigned getGISelRematGlobalCost() const;
@@ -2145,6 +2152,8 @@ class TargetTransformInfo::Concept {
21452152
virtual bool preferEpilogueVectorization() const = 0;
21462153

21472154
virtual bool shouldExpandReduction(const IntrinsicInst *II) const = 0;
2155+
virtual ReductionShuffle
2156+
getPreferredExpandedReductionShuffle(const IntrinsicInst *II) const = 0;
21482157
virtual unsigned getGISelRematGlobalCost() const = 0;
21492158
virtual unsigned getMinTripCountTailFoldingThreshold() const = 0;
21502159
virtual bool enableScalableVectorization() const = 0;
@@ -2881,6 +2890,11 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
28812890
return Impl.shouldExpandReduction(II);
28822891
}
28832892

2893+
ReductionShuffle
2894+
getPreferredExpandedReductionShuffle(const IntrinsicInst *II) const override {
2895+
return Impl.getPreferredExpandedReductionShuffle(II);
2896+
}
2897+
28842898
unsigned getGISelRematGlobalCost() const override {
28852899
return Impl.getGISelRematGlobalCost();
28862900
}

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -927,6 +927,11 @@ class TargetTransformInfoImplBase {
927927

928928
bool shouldExpandReduction(const IntrinsicInst *II) const { return true; }
929929

930+
TTI::ReductionShuffle
931+
getPreferredExpandedReductionShuffle(const IntrinsicInst *II) const {
932+
return TTI::ReductionShuffle::SplitHalf;
933+
}
934+
930935
unsigned getGISelRematGlobalCost() const { return 1; }
931936

932937
unsigned getMinTripCountTailFoldingThreshold() const { return 0; }

llvm/include/llvm/Transforms/Utils/LoopUtils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include "llvm/Analysis/IVDescriptors.h"
1717
#include "llvm/Analysis/LoopAccessAnalysis.h"
18+
#include "llvm/Analysis/TargetTransformInfo.h"
1819
#include "llvm/Transforms/Utils/ValueMapper.h"
1920

2021
namespace llvm {
@@ -384,6 +385,7 @@ Value *getOrderedReduction(IRBuilderBase &Builder, Value *Acc, Value *Src,
384385
/// Generates a vector reduction using shufflevectors to reduce the value.
385386
/// Fast-math-flags are propagated using the IRBuilder's setting.
386387
Value *getShuffleReduction(IRBuilderBase &Builder, Value *Src, unsigned Op,
388+
TargetTransformInfo::ReductionShuffle RS,
387389
RecurKind MinMaxKind = RecurKind::None);
388390

389391
/// Create a target reduction of the given vector. The reduction operation

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1309,6 +1309,12 @@ bool TargetTransformInfo::shouldExpandReduction(const IntrinsicInst *II) const {
13091309
return TTIImpl->shouldExpandReduction(II);
13101310
}
13111311

1312+
TargetTransformInfo::ReductionShuffle
1313+
TargetTransformInfo::getPreferredExpandedReductionShuffle(
1314+
const IntrinsicInst *II) const {
1315+
return TTIImpl->getPreferredExpandedReductionShuffle(II);
1316+
}
1317+
13121318
unsigned TargetTransformInfo::getGISelRematGlobalCost() const {
13131319
return TTIImpl->getGISelRematGlobalCost();
13141320
}

llvm/lib/CodeGen/ExpandReductions.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
5959
isa<FPMathOperator>(II) ? II->getFastMathFlags() : FastMathFlags{};
6060
Intrinsic::ID ID = II->getIntrinsicID();
6161
RecurKind RK = getMinMaxReductionRecurKind(ID);
62+
TargetTransformInfo::ReductionShuffle RS =
63+
TTI->getPreferredExpandedReductionShuffle(II);
6264

6365
Value *Rdx = nullptr;
6466
IRBuilder<> Builder(II);
@@ -79,7 +81,7 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
7981
if (!isPowerOf2_32(
8082
cast<FixedVectorType>(Vec->getType())->getNumElements()))
8183
continue;
82-
Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RK);
84+
Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RS, RK);
8385
Rdx = Builder.CreateBinOp((Instruction::BinaryOps)RdxOpcode, Acc, Rdx,
8486
"bin.rdx");
8587
}
@@ -112,7 +114,7 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
112114
break;
113115
}
114116
unsigned RdxOpcode = getArithmeticReductionInstruction(ID);
115-
Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RK);
117+
Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RS, RK);
116118
break;
117119
}
118120
case Intrinsic::vector_reduce_add:
@@ -127,7 +129,7 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
127129
cast<FixedVectorType>(Vec->getType())->getNumElements()))
128130
continue;
129131
unsigned RdxOpcode = getArithmeticReductionInstruction(ID);
130-
Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RK);
132+
Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RS, RK);
131133
break;
132134
}
133135
case Intrinsic::vector_reduce_fmax:
@@ -140,7 +142,7 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
140142
!FMF.noNaNs())
141143
continue;
142144
unsigned RdxOpcode = getArithmeticReductionInstruction(ID);
143-
Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RK);
145+
Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RS, RK);
144146
break;
145147
}
146148
}

llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,19 @@ WebAssemblyTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
9494
return Cost;
9595
}
9696

97+
TTI::ReductionShuffle WebAssemblyTTIImpl::getPreferredExpandedReductionShuffle(
    const IntrinsicInst *II) const {
  // WebAssembly prefers the pairwise expansion for integer and floating-point
  // add reductions; every other reduction keeps the default split-half
  // expansion.
  const Intrinsic::ID IID = II->getIntrinsicID();
  if (IID == Intrinsic::vector_reduce_add || IID == Intrinsic::vector_reduce_fadd)
    return TTI::ReductionShuffle::Pairwise;
  return TTI::ReductionShuffle::SplitHalf;
}
109+
97110
bool WebAssemblyTTIImpl::areInlineCompatible(const Function *Caller,
98111
const Function *Callee) const {
99112
// Allow inlining only when the Callee has a subset of the Caller's

llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ class WebAssemblyTTIImpl final : public BasicTTIImplBase<WebAssemblyTTIImpl> {
7070
TTI::TargetCostKind CostKind,
7171
unsigned Index, Value *Op0, Value *Op1);
7272

73+
TTI::ReductionShuffle
74+
getPreferredExpandedReductionShuffle(const IntrinsicInst *II) const;
7375
/// @}
7476

7577
bool areInlineCompatible(const Function *Caller,

llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,7 +1077,9 @@ Value *llvm::getOrderedReduction(IRBuilderBase &Builder, Value *Acc, Value *Src,
10771077

10781078
// Helper to generate a log2 shuffle reduction.
10791079
Value *llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src,
1080-
unsigned Op, RecurKind RdxKind) {
1080+
unsigned Op,
1081+
TargetTransformInfo::ReductionShuffle RS,
1082+
RecurKind RdxKind) {
10811083
unsigned VF = cast<FixedVectorType>(Src->getType())->getNumElements();
10821084
// VF is a power of 2 so we can emit the reduction using log2(VF) shuffles
10831085
// and vector ops, reducing the set of values being computed by half each
@@ -1091,18 +1093,10 @@ Value *llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src,
10911093
// will never be relevant here. Note that it would be generally unsound to
10921094
// propagate these from an intrinsic call to the expansion anyways as we/
10931095
// change the order of operations.
1094-
Value *TmpVec = Src;
1095-
SmallVector<int, 32> ShuffleMask(VF);
1096-
for (unsigned i = VF; i != 1; i >>= 1) {
1097-
// Move the upper half of the vector to the lower half.
1098-
for (unsigned j = 0; j != i / 2; ++j)
1099-
ShuffleMask[j] = i / 2 + j;
1100-
1101-
// Fill the rest of the mask with undef.
1102-
std::fill(&ShuffleMask[i / 2], ShuffleMask.end(), -1);
1103-
1096+
auto BuildShuffledOp = [&Builder, &Op,
1097+
&RdxKind](SmallVectorImpl<int> &ShuffleMask,
1098+
Value *&TmpVec) -> void {
11041099
Value *Shuf = Builder.CreateShuffleVector(TmpVec, ShuffleMask, "rdx.shuf");
1105-
11061100
if (Op != Instruction::ICmp && Op != Instruction::FCmp) {
11071101
TmpVec = Builder.CreateBinOp((Instruction::BinaryOps)Op, TmpVec, Shuf,
11081102
"bin.rdx");
@@ -1111,6 +1105,30 @@ Value *llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src,
11111105
"Invalid min/max");
11121106
TmpVec = createMinMaxOp(Builder, RdxKind, TmpVec, Shuf);
11131107
}
1108+
};
1109+
1110+
Value *TmpVec = Src;
1111+
if (TargetTransformInfo::ReductionShuffle::Pairwise == RS) {
1112+
SmallVector<int, 32> ShuffleMask(VF);
1113+
for (unsigned stride = 1; stride < VF; stride <<= 1) {
1114+
// Initialise the mask with undef.
1115+
std::fill(ShuffleMask.begin(), ShuffleMask.end(), -1);
1116+
for (unsigned j = 0; j < VF; j += stride << 1) {
1117+
ShuffleMask[j] = j + stride;
1118+
}
1119+
BuildShuffledOp(ShuffleMask, TmpVec);
1120+
}
1121+
} else {
1122+
SmallVector<int, 32> ShuffleMask(VF);
1123+
for (unsigned i = VF; i != 1; i >>= 1) {
1124+
// Move the upper half of the vector to the lower half.
1125+
for (unsigned j = 0; j != i / 2; ++j)
1126+
ShuffleMask[j] = i / 2 + j;
1127+
1128+
// Fill the rest of the mask with undef.
1129+
std::fill(&ShuffleMask[i / 2], ShuffleMask.end(), -1);
1130+
BuildShuffledOp(ShuffleMask, TmpVec);
1131+
}
11141132
}
11151133
// The result is in the first element of the vector.
11161134
return Builder.CreateExtractElement(TmpVec, Builder.getInt32(0));
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
; RUN: llc < %s -mtriple=wasm32 -verify-machineinstrs -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+simd128 | FileCheck %s --check-prefix=SIMD128

; Check that add reductions (and fast fadd reductions) are expanded with the
; pairwise shuffle pattern on WebAssembly, while strict fadd reductions stay
; as ordered scalar extract/add sequences.

define i64 @pairwise_add_v2i64(<2 x i64> %arg) {
; SIMD128-LABEL: pairwise_add_v2i64:
; SIMD128: .functype pairwise_add_v2i64 (v128) -> (i64)
; SIMD128-NEXT: # %bb.0:
; SIMD128-NEXT: i8x16.shuffle $push0=, $0, $0, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7
; SIMD128-NEXT: i64x2.add $push1=, $0, $pop0
; SIMD128-NEXT: i64x2.extract_lane $push2=, $pop1, 0
; SIMD128-NEXT: return $pop2
  %res = tail call i64 @llvm.vector.reduce.add.v2i64(<2 x i64> %arg)
  ret i64 %res
}

define i32 @pairwise_add_v4i32(<4 x i32> %arg) {
; SIMD128-LABEL: pairwise_add_v4i32:
; SIMD128: .functype pairwise_add_v4i32 (v128) -> (i32)
; SIMD128-NEXT: # %bb.0:
; SIMD128-NEXT: i8x16.shuffle $push0=, $0, $0, 4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 0, 1, 2, 3
; SIMD128-NEXT: i32x4.add $push5=, $0, $pop0
; SIMD128-NEXT: local.tee $push4=, $0=, $pop5
; SIMD128-NEXT: i8x16.shuffle $push1=, $0, $0, 8, 9, 10, 11, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3
; SIMD128-NEXT: i32x4.add $push2=, $pop4, $pop1
; SIMD128-NEXT: i32x4.extract_lane $push3=, $pop2, 0
; SIMD128-NEXT: return $pop3
  %res = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %arg)
  ret i32 %res
}

define i16 @pairwise_add_v8i16(<8 x i16> %arg) {
; SIMD128-LABEL: pairwise_add_v8i16:
; SIMD128: .functype pairwise_add_v8i16 (v128) -> (i32)
; SIMD128-NEXT: # %bb.0:
; SIMD128-NEXT: i8x16.shuffle $push0=, $0, $0, 2, 3, 0, 1, 6, 7, 0, 1, 10, 11, 0, 1, 14, 15, 0, 1
; SIMD128-NEXT: i16x8.add $push8=, $0, $pop0
; SIMD128-NEXT: local.tee $push7=, $0=, $pop8
; SIMD128-NEXT: i8x16.shuffle $push1=, $0, $0, 4, 5, 0, 1, 0, 1, 0, 1, 12, 13, 0, 1, 0, 1, 0, 1
; SIMD128-NEXT: i16x8.add $push6=, $pop7, $pop1
; SIMD128-NEXT: local.tee $push5=, $0=, $pop6
; SIMD128-NEXT: i8x16.shuffle $push2=, $0, $0, 8, 9, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1
; SIMD128-NEXT: i16x8.add $push3=, $pop5, $pop2
; SIMD128-NEXT: i16x8.extract_lane_u $push4=, $pop3, 0
; SIMD128-NEXT: return $pop4
  %res = tail call i16 @llvm.vector.reduce.add.v8i16(<8 x i16> %arg)
  ret i16 %res
}

define i8 @pairwise_add_v16i8(<16 x i8> %arg) {
; SIMD128-LABEL: pairwise_add_v16i8:
; SIMD128: .functype pairwise_add_v16i8 (v128) -> (i32)
; SIMD128-NEXT: # %bb.0:
; SIMD128-NEXT: i8x16.shuffle $push0=, $0, $0, 1, 0, 3, 0, 5, 0, 7, 0, 9, 0, 11, 0, 13, 0, 15, 0
; SIMD128-NEXT: i8x16.add $push11=, $0, $pop0
; SIMD128-NEXT: local.tee $push10=, $0=, $pop11
; SIMD128-NEXT: i8x16.shuffle $push1=, $0, $0, 2, 0, 0, 0, 6, 0, 0, 0, 10, 0, 0, 0, 14, 0, 0, 0
; SIMD128-NEXT: i8x16.add $push9=, $pop10, $pop1
; SIMD128-NEXT: local.tee $push8=, $0=, $pop9
; SIMD128-NEXT: i8x16.shuffle $push2=, $0, $0, 4, 0, 0, 0, 0, 0, 0, 0, 12, 0, 0, 0, 0, 0, 0, 0
; SIMD128-NEXT: i8x16.add $push7=, $pop8, $pop2
; SIMD128-NEXT: local.tee $push6=, $0=, $pop7
; SIMD128-NEXT: i8x16.shuffle $push3=, $0, $0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
; SIMD128-NEXT: i8x16.add $push4=, $pop6, $pop3
; SIMD128-NEXT: i8x16.extract_lane_u $push5=, $pop4, 0
; SIMD128-NEXT: return $pop5
  %res = tail call i8 @llvm.vector.reduce.add.v16i8(<16 x i8> %arg)
  ret i8 %res
}

define double @pairwise_add_v2f64(<2 x double> %arg) {
; SIMD128-LABEL: pairwise_add_v2f64:
; SIMD128: .functype pairwise_add_v2f64 (v128) -> (f64)
; SIMD128-NEXT: # %bb.0:
; SIMD128-NEXT: f64x2.extract_lane $push1=, $0, 0
; SIMD128-NEXT: f64x2.extract_lane $push0=, $0, 1
; SIMD128-NEXT: f64.add $push2=, $pop1, $pop0
; SIMD128-NEXT: return $pop2
  %res = tail call double @llvm.vector.reduce.fadd.v2f64(double -0.0, <2 x double> %arg)
  ret double %res
}

define double @pairwise_add_v2f64_fast(<2 x double> %arg) {
; SIMD128-LABEL: pairwise_add_v2f64_fast:
; SIMD128: .functype pairwise_add_v2f64_fast (v128) -> (f64)
; SIMD128-NEXT: # %bb.0:
; SIMD128-NEXT: i8x16.shuffle $push0=, $0, $0, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7
; SIMD128-NEXT: f64x2.add $push1=, $0, $pop0
; SIMD128-NEXT: f64x2.extract_lane $push2=, $pop1, 0
; SIMD128-NEXT: return $pop2
  %res = tail call fast double @llvm.vector.reduce.fadd.v2f64(double -0.0, <2 x double> %arg)
  ret double %res
}

define float @pairwise_add_v4f32(<4 x float> %arg) {
; SIMD128-LABEL: pairwise_add_v4f32:
; SIMD128: .functype pairwise_add_v4f32 (v128) -> (f32)
; SIMD128-NEXT: # %bb.0:
; SIMD128-NEXT: f32x4.extract_lane $push1=, $0, 0
; SIMD128-NEXT: f32x4.extract_lane $push0=, $0, 1
; SIMD128-NEXT: f32.add $push2=, $pop1, $pop0
; SIMD128-NEXT: f32x4.extract_lane $push3=, $0, 2
; SIMD128-NEXT: f32.add $push4=, $pop2, $pop3
; SIMD128-NEXT: f32x4.extract_lane $push5=, $0, 3
; SIMD128-NEXT: f32.add $push6=, $pop4, $pop5
; SIMD128-NEXT: return $pop6
  %res = tail call float @llvm.vector.reduce.fadd.v4f32(float -0.0, <4 x float> %arg)
  ret float %res
}

define float @pairwise_add_v4f32_fast(<4 x float> %arg) {
; SIMD128-LABEL: pairwise_add_v4f32_fast:
; SIMD128: .functype pairwise_add_v4f32_fast (v128) -> (f32)
; SIMD128-NEXT: # %bb.0:
; SIMD128-NEXT: i8x16.shuffle $push0=, $0, $0, 4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 0, 1, 2, 3
; SIMD128-NEXT: f32x4.add $push5=, $0, $pop0
; SIMD128-NEXT: local.tee $push4=, $0=, $pop5
; SIMD128-NEXT: i8x16.shuffle $push1=, $0, $0, 8, 9, 10, 11, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3
; SIMD128-NEXT: f32x4.add $push2=, $pop4, $pop1
; SIMD128-NEXT: f32x4.extract_lane $push3=, $pop2, 0
; SIMD128-NEXT: return $pop3
  %res = tail call fast float @llvm.vector.reduce.fadd.v4f32(float -0.0, <4 x float> %arg)
  ret float %res
}

; Canonical manglings: vector.reduce intrinsics are mangled only on the vector
; operand type (the original declarations used stale double-type manglings that
; did not match the call sites, e.g. .i64.v4i64 called with <2 x i64>).
declare i64 @llvm.vector.reduce.add.v2i64(<2 x i64>)
declare i32 @llvm.vector.reduce.add.v4i32(<4 x i32>)
declare i16 @llvm.vector.reduce.add.v8i16(<8 x i16>)
declare i8 @llvm.vector.reduce.add.v16i8(<16 x i8>)
declare double @llvm.vector.reduce.fadd.v2f64(double, <2 x double>)
declare float @llvm.vector.reduce.fadd.v4f32(float, <4 x float>)

0 commit comments

Comments
 (0)