Skip to content

Commit 9d7ee1d

Browse files
committed
[TTI][WebAssembly] Pairwise reduction expansion
WebAssembly doesn't support horizontal operations, nor does it have a way of expressing fast-math or reassoc flags, so runtimes are currently unable to use pairwise operations when generating code from the existing shuffle patterns. This patch allows the backend to select which (arbitrary) shuffle pattern is used per reduction intrinsic. The default behaviour is the same as the existing one: splitting the vector into a top and a bottom half. The other pattern introduced is a pairwise shuffle. WebAssembly enables pairwise reductions for int/fp add/sub.
1 parent b86a9c5 commit 9d7ee1d

File tree

9 files changed

+210
-16
lines changed

9 files changed

+210
-16
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1696,6 +1696,16 @@ class TargetTransformInfo {
16961696
/// into a shuffle sequence.
16971697
bool shouldExpandReduction(const IntrinsicInst *II) const;
16981698

1699+
enum struct ReductionShuffle {
1700+
SplitHalf,
1701+
Pairwise
1702+
};
1703+
1704+
/// \returns The shuffle sequence pattern used to expand the given reduction
1705+
/// intrinsic.
1706+
ReductionShuffle getPreferredExpandedReductionShuffle(
1707+
const IntrinsicInst *II) const;
1708+
16991709
/// \returns the size cost of rematerializing a GlobalValue address relative
17001710
/// to a stack reload.
17011711
unsigned getGISelRematGlobalCost() const;
@@ -2145,6 +2155,8 @@ class TargetTransformInfo::Concept {
21452155
virtual bool preferEpilogueVectorization() const = 0;
21462156

21472157
virtual bool shouldExpandReduction(const IntrinsicInst *II) const = 0;
2158+
virtual ReductionShuffle
2159+
getPreferredExpandedReductionShuffle(const IntrinsicInst *II) const = 0;
21482160
virtual unsigned getGISelRematGlobalCost() const = 0;
21492161
virtual unsigned getMinTripCountTailFoldingThreshold() const = 0;
21502162
virtual bool enableScalableVectorization() const = 0;
@@ -2881,6 +2893,11 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
28812893
return Impl.shouldExpandReduction(II);
28822894
}
28832895

2896+
ReductionShuffle
2897+
getPreferredExpandedReductionShuffle(const IntrinsicInst *II) const override {
2898+
return Impl.getPreferredExpandedReductionShuffle(II);
2899+
}
2900+
28842901
unsigned getGISelRematGlobalCost() const override {
28852902
return Impl.getGISelRematGlobalCost();
28862903
}

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -927,6 +927,11 @@ class TargetTransformInfoImplBase {
927927

928928
bool shouldExpandReduction(const IntrinsicInst *II) const { return true; }
929929

930+
TTI::ReductionShuffle
931+
getPreferredExpandedReductionShuffle(const IntrinsicInst *II) const {
932+
return TTI::ReductionShuffle::SplitHalf;
933+
}
934+
930935
unsigned getGISelRematGlobalCost() const { return 1; }
931936

932937
unsigned getMinTripCountTailFoldingThreshold() const { return 0; }

llvm/include/llvm/Transforms/Utils/LoopUtils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include "llvm/Analysis/IVDescriptors.h"
1717
#include "llvm/Analysis/LoopAccessAnalysis.h"
18+
#include "llvm/Analysis/TargetTransformInfo.h"
1819
#include "llvm/Transforms/Utils/ValueMapper.h"
1920

2021
namespace llvm {
@@ -384,6 +385,7 @@ Value *getOrderedReduction(IRBuilderBase &Builder, Value *Acc, Value *Src,
384385
/// Generates a vector reduction using shufflevectors to reduce the value.
385386
/// Fast-math-flags are propagated using the IRBuilder's setting.
386387
Value *getShuffleReduction(IRBuilderBase &Builder, Value *Src, unsigned Op,
388+
TargetTransformInfo::ReductionShuffle RS,
387389
RecurKind MinMaxKind = RecurKind::None);
388390

389391
/// Create a target reduction of the given vector. The reduction operation

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1309,6 +1309,12 @@ bool TargetTransformInfo::shouldExpandReduction(const IntrinsicInst *II) const {
13091309
return TTIImpl->shouldExpandReduction(II);
13101310
}
13111311

1312+
TargetTransformInfo::ReductionShuffle
1313+
TargetTransformInfo::getPreferredExpandedReductionShuffle(
1314+
const IntrinsicInst *II) const {
1315+
return TTIImpl->getPreferredExpandedReductionShuffle(II);
1316+
}
1317+
13121318
unsigned TargetTransformInfo::getGISelRematGlobalCost() const {
13131319
return TTIImpl->getGISelRematGlobalCost();
13141320
}

llvm/lib/CodeGen/ExpandReductions.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
5959
isa<FPMathOperator>(II) ? II->getFastMathFlags() : FastMathFlags{};
6060
Intrinsic::ID ID = II->getIntrinsicID();
6161
RecurKind RK = getMinMaxReductionRecurKind(ID);
62+
TargetTransformInfo::ReductionShuffle RS =
63+
TTI->getPreferredExpandedReductionShuffle(II);
6264

6365
Value *Rdx = nullptr;
6466
IRBuilder<> Builder(II);
@@ -79,7 +81,7 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
7981
if (!isPowerOf2_32(
8082
cast<FixedVectorType>(Vec->getType())->getNumElements()))
8183
continue;
82-
Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RK);
84+
Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RS, RK);
8385
Rdx = Builder.CreateBinOp((Instruction::BinaryOps)RdxOpcode, Acc, Rdx,
8486
"bin.rdx");
8587
}
@@ -112,7 +114,7 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
112114
break;
113115
}
114116
unsigned RdxOpcode = getArithmeticReductionInstruction(ID);
115-
Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RK);
117+
Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RS, RK);
116118
break;
117119
}
118120
case Intrinsic::vector_reduce_add:
@@ -127,7 +129,7 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
127129
cast<FixedVectorType>(Vec->getType())->getNumElements()))
128130
continue;
129131
unsigned RdxOpcode = getArithmeticReductionInstruction(ID);
130-
Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RK);
132+
Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RS, RK);
131133
break;
132134
}
133135
case Intrinsic::vector_reduce_fmax:
@@ -140,7 +142,7 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
140142
!FMF.noNaNs())
141143
continue;
142144
unsigned RdxOpcode = getArithmeticReductionInstruction(ID);
143-
Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RK);
145+
Rdx = getShuffleReduction(Builder, Vec, RdxOpcode, RS, RK);
144146
break;
145147
}
146148
}

llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,19 @@ WebAssemblyTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
9494
return Cost;
9595
}
9696

97+
TTI::ReductionShuffle WebAssemblyTTIImpl::getPreferredExpandedReductionShuffle(
98+
const IntrinsicInst *II) const {
99+
100+
switch (II->getIntrinsicID()) {
101+
default:
102+
break;
103+
case Intrinsic::vector_reduce_add:
104+
case Intrinsic::vector_reduce_fadd:
105+
return TTI::ReductionShuffle::Pairwise;
106+
}
107+
return TTI::ReductionShuffle::SplitHalf;
108+
}
109+
97110
bool WebAssemblyTTIImpl::areInlineCompatible(const Function *Caller,
98111
const Function *Callee) const {
99112
// Allow inlining only when the Callee has a subset of the Caller's

llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ class WebAssemblyTTIImpl final : public BasicTTIImplBase<WebAssemblyTTIImpl> {
7070
TTI::TargetCostKind CostKind,
7171
unsigned Index, Value *Op0, Value *Op1);
7272

73+
TTI::ReductionShuffle
74+
getPreferredExpandedReductionShuffle(const IntrinsicInst *II) const;
7375
/// @}
7476

7577
bool areInlineCompatible(const Function *Caller,

llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,7 +1077,9 @@ Value *llvm::getOrderedReduction(IRBuilderBase &Builder, Value *Acc, Value *Src,
10771077

10781078
// Helper to generate a log2 shuffle reduction.
10791079
Value *llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src,
1080-
unsigned Op, RecurKind RdxKind) {
1080+
unsigned Op,
1081+
TargetTransformInfo::ReductionShuffle RS,
1082+
RecurKind RdxKind) {
10811083
unsigned VF = cast<FixedVectorType>(Src->getType())->getNumElements();
10821084
// VF is a power of 2 so we can emit the reduction using log2(VF) shuffles
10831085
// and vector ops, reducing the set of values being computed by half each
@@ -1091,18 +1093,9 @@ Value *llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src,
10911093
// will never be relevant here. Note that it would be generally unsound to
10921094
// propagate these from an intrinsic call to the expansion anyways as we/
10931095
// change the order of operations.
1094-
Value *TmpVec = Src;
1095-
SmallVector<int, 32> ShuffleMask(VF);
1096-
for (unsigned i = VF; i != 1; i >>= 1) {
1097-
// Move the upper half of the vector to the lower half.
1098-
for (unsigned j = 0; j != i / 2; ++j)
1099-
ShuffleMask[j] = i / 2 + j;
1100-
1101-
// Fill the rest of the mask with undef.
1102-
std::fill(&ShuffleMask[i / 2], ShuffleMask.end(), -1);
1103-
1096+
auto BuildShuffledOp = [&Builder, &Op, &RdxKind](
1097+
SmallVectorImpl<int> &ShuffleMask, Value*& TmpVec) -> void {
11041098
Value *Shuf = Builder.CreateShuffleVector(TmpVec, ShuffleMask, "rdx.shuf");
1105-
11061099
if (Op != Instruction::ICmp && Op != Instruction::FCmp) {
11071100
TmpVec = Builder.CreateBinOp((Instruction::BinaryOps)Op, TmpVec, Shuf,
11081101
"bin.rdx");
@@ -1111,6 +1104,30 @@ Value *llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src,
11111104
"Invalid min/max");
11121105
TmpVec = createMinMaxOp(Builder, RdxKind, TmpVec, Shuf);
11131106
}
1107+
};
1108+
1109+
Value *TmpVec = Src;
1110+
if (TargetTransformInfo::ReductionShuffle::Pairwise == RS) {
1111+
SmallVector<int, 32> ShuffleMask(VF);
1112+
for (unsigned stride = 1; stride < VF; stride <<= 1) {
1113+
// Initialise the mask with undef.
1114+
std::fill(ShuffleMask.begin(), ShuffleMask.end(), -1);
1115+
for (unsigned j = 0; j < VF; j += stride << 1) {
1116+
ShuffleMask[j] = j + stride;
1117+
}
1118+
BuildShuffledOp(ShuffleMask, TmpVec);
1119+
}
1120+
} else {
1121+
SmallVector<int, 32> ShuffleMask(VF);
1122+
for (unsigned i = VF; i != 1; i >>= 1) {
1123+
// Move the upper half of the vector to the lower half.
1124+
for (unsigned j = 0; j != i / 2; ++j)
1125+
ShuffleMask[j] = i / 2 + j;
1126+
1127+
// Fill the rest of the mask with undef.
1128+
std::fill(&ShuffleMask[i / 2], ShuffleMask.end(), -1);
1129+
BuildShuffledOp(ShuffleMask, TmpVec);
1130+
}
11141131
}
11151132
// The result is in the first element of the vector.
11161133
return Builder.CreateExtractElement(TmpVec, Builder.getInt32(0));
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
2+
; RUN: llc < %s -mtriple=wasm32 -verify-machineinstrs -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+simd128 | FileCheck %s --check-prefix=SIMD128
3+
4+
define i64 @pairwise_add_v2i64(<2 x i64> %arg) {
5+
; SIMD128-LABEL: pairwise_add_v2i64:
6+
; SIMD128: .functype pairwise_add_v2i64 (v128) -> (i64)
7+
; SIMD128-NEXT: # %bb.0:
8+
; SIMD128-NEXT: i8x16.shuffle $push0=, $0, $0, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7
9+
; SIMD128-NEXT: i64x2.add $push1=, $0, $pop0
10+
; SIMD128-NEXT: i64x2.extract_lane $push2=, $pop1, 0
11+
; SIMD128-NEXT: return $pop2
12+
%res = tail call i64 @llvm.vector.reduce.add.i64.v4i64(<2 x i64> %arg)
13+
ret i64 %res
14+
}
15+
16+
define i32 @pairwise_add_v4i32(<4 x i32> %arg) {
17+
; SIMD128-LABEL: pairwise_add_v4i32:
18+
; SIMD128: .functype pairwise_add_v4i32 (v128) -> (i32)
19+
; SIMD128-NEXT: # %bb.0:
20+
; SIMD128-NEXT: i8x16.shuffle $push0=, $0, $0, 4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 0, 1, 2, 3
21+
; SIMD128-NEXT: i32x4.add $push5=, $0, $pop0
22+
; SIMD128-NEXT: local.tee $push4=, $0=, $pop5
23+
; SIMD128-NEXT: i8x16.shuffle $push1=, $0, $0, 8, 9, 10, 11, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3
24+
; SIMD128-NEXT: i32x4.add $push2=, $pop4, $pop1
25+
; SIMD128-NEXT: i32x4.extract_lane $push3=, $pop2, 0
26+
; SIMD128-NEXT: return $pop3
27+
%res = tail call i32 @llvm.vector.reduce.add.i32.v4f32(<4 x i32> %arg)
28+
ret i32 %res
29+
}
30+
31+
define i16 @pairwise_add_v8i16(<8 x i16> %arg) {
32+
; SIMD128-LABEL: pairwise_add_v8i16:
33+
; SIMD128: .functype pairwise_add_v8i16 (v128) -> (i32)
34+
; SIMD128-NEXT: # %bb.0:
35+
; SIMD128-NEXT: i8x16.shuffle $push0=, $0, $0, 2, 3, 0, 1, 6, 7, 0, 1, 10, 11, 0, 1, 14, 15, 0, 1
36+
; SIMD128-NEXT: i16x8.add $push8=, $0, $pop0
37+
; SIMD128-NEXT: local.tee $push7=, $0=, $pop8
38+
; SIMD128-NEXT: i8x16.shuffle $push1=, $0, $0, 4, 5, 0, 1, 0, 1, 0, 1, 12, 13, 0, 1, 0, 1, 0, 1
39+
; SIMD128-NEXT: i16x8.add $push6=, $pop7, $pop1
40+
; SIMD128-NEXT: local.tee $push5=, $0=, $pop6
41+
; SIMD128-NEXT: i8x16.shuffle $push2=, $0, $0, 8, 9, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1
42+
; SIMD128-NEXT: i16x8.add $push3=, $pop5, $pop2
43+
; SIMD128-NEXT: i16x8.extract_lane_u $push4=, $pop3, 0
44+
; SIMD128-NEXT: return $pop4
45+
%res = tail call i16 @llvm.vector.reduce.add.i16.v8i16(<8 x i16> %arg)
46+
ret i16 %res
47+
}
48+
49+
define i8 @pairwise_add_v16i8(<16 x i8> %arg) {
50+
; SIMD128-LABEL: pairwise_add_v16i8:
51+
; SIMD128: .functype pairwise_add_v16i8 (v128) -> (i32)
52+
; SIMD128-NEXT: # %bb.0:
53+
; SIMD128-NEXT: i8x16.shuffle $push0=, $0, $0, 1, 0, 3, 0, 5, 0, 7, 0, 9, 0, 11, 0, 13, 0, 15, 0
54+
; SIMD128-NEXT: i8x16.add $push11=, $0, $pop0
55+
; SIMD128-NEXT: local.tee $push10=, $0=, $pop11
56+
; SIMD128-NEXT: i8x16.shuffle $push1=, $0, $0, 2, 0, 0, 0, 6, 0, 0, 0, 10, 0, 0, 0, 14, 0, 0, 0
57+
; SIMD128-NEXT: i8x16.add $push9=, $pop10, $pop1
58+
; SIMD128-NEXT: local.tee $push8=, $0=, $pop9
59+
; SIMD128-NEXT: i8x16.shuffle $push2=, $0, $0, 4, 0, 0, 0, 0, 0, 0, 0, 12, 0, 0, 0, 0, 0, 0, 0
60+
; SIMD128-NEXT: i8x16.add $push7=, $pop8, $pop2
61+
; SIMD128-NEXT: local.tee $push6=, $0=, $pop7
62+
; SIMD128-NEXT: i8x16.shuffle $push3=, $0, $0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
63+
; SIMD128-NEXT: i8x16.add $push4=, $pop6, $pop3
64+
; SIMD128-NEXT: i8x16.extract_lane_u $push5=, $pop4, 0
65+
; SIMD128-NEXT: return $pop5
66+
%res = tail call i8 @llvm.vector.reduce.add.i8.v16i8(<16 x i8> %arg)
67+
ret i8 %res
68+
}
69+
70+
define double @pairwise_add_v2f64(<2 x double> %arg) {
71+
; SIMD128-LABEL: pairwise_add_v2f64:
72+
; SIMD128: .functype pairwise_add_v2f64 (v128) -> (f64)
73+
; SIMD128-NEXT: # %bb.0:
74+
; SIMD128-NEXT: f64x2.extract_lane $push1=, $0, 0
75+
; SIMD128-NEXT: f64x2.extract_lane $push0=, $0, 1
76+
; SIMD128-NEXT: f64.add $push2=, $pop1, $pop0
77+
; SIMD128-NEXT: return $pop2
78+
%res = tail call double @llvm.vector.reduce.fadd.f64.v2f64(double -0.0, <2 x double> %arg)
79+
ret double%res
80+
}
81+
82+
define double @pairwise_add_v2f64_fast(<2 x double> %arg) {
83+
; SIMD128-LABEL: pairwise_add_v2f64_fast:
84+
; SIMD128: .functype pairwise_add_v2f64_fast (v128) -> (f64)
85+
; SIMD128-NEXT: # %bb.0:
86+
; SIMD128-NEXT: i8x16.shuffle $push0=, $0, $0, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7
87+
; SIMD128-NEXT: f64x2.add $push1=, $0, $pop0
88+
; SIMD128-NEXT: f64x2.extract_lane $push2=, $pop1, 0
89+
; SIMD128-NEXT: return $pop2
90+
%res = tail call fast double @llvm.vector.reduce.fadd.f64.v2f64(double -0.0, <2 x double> %arg)
91+
ret double%res
92+
}
93+
94+
define float @pairwise_add_v4f32(<4 x float> %arg) {
95+
; SIMD128-LABEL: pairwise_add_v4f32:
96+
; SIMD128: .functype pairwise_add_v4f32 (v128) -> (f32)
97+
; SIMD128-NEXT: # %bb.0:
98+
; SIMD128-NEXT: f32x4.extract_lane $push1=, $0, 0
99+
; SIMD128-NEXT: f32x4.extract_lane $push0=, $0, 1
100+
; SIMD128-NEXT: f32.add $push2=, $pop1, $pop0
101+
; SIMD128-NEXT: f32x4.extract_lane $push3=, $0, 2
102+
; SIMD128-NEXT: f32.add $push4=, $pop2, $pop3
103+
; SIMD128-NEXT: f32x4.extract_lane $push5=, $0, 3
104+
; SIMD128-NEXT: f32.add $push6=, $pop4, $pop5
105+
; SIMD128-NEXT: return $pop6
106+
%res = tail call float @llvm.vector.reduce.fadd.f32.v4f32(float -0.0, <4 x float> %arg)
107+
ret float %res
108+
}
109+
110+
define float @pairwise_add_v4f32_fast(<4 x float> %arg) {
111+
; SIMD128-LABEL: pairwise_add_v4f32_fast:
112+
; SIMD128: .functype pairwise_add_v4f32_fast (v128) -> (f32)
113+
; SIMD128-NEXT: # %bb.0:
114+
; SIMD128-NEXT: i8x16.shuffle $push0=, $0, $0, 4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 0, 1, 2, 3
115+
; SIMD128-NEXT: f32x4.add $push5=, $0, $pop0
116+
; SIMD128-NEXT: local.tee $push4=, $0=, $pop5
117+
; SIMD128-NEXT: i8x16.shuffle $push1=, $0, $0, 8, 9, 10, 11, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3
118+
; SIMD128-NEXT: f32x4.add $push2=, $pop4, $pop1
119+
; SIMD128-NEXT: f32x4.extract_lane $push3=, $pop2, 0
120+
; SIMD128-NEXT: return $pop3
121+
%res = tail call fast float @llvm.vector.reduce.fadd.f32.v4f32(float -0.0, <4 x float> %arg)
122+
ret float %res
123+
}
124+
125+
declare i64 @llvm.vector.reduce.add.i64.v4i64(<2 x i64>)
126+
declare i32 @llvm.vector.reduce.add.i32.v4i32(<4 x i32>)
127+
declare i16 @llvm.vector.reduce.add.i16.v8i16(<8 x i16>)
128+
declare i8 @llvm.vector.reduce.add.i8.v16i8(<16 x i8>)
129+
declare double @llvm.vector.reduce.fadd.f64.v2f64(double, <2 x double>)
130+
declare float @llvm.vector.reduce.fadd.f32.v4f32(float, <4 x float>)

0 commit comments

Comments
 (0)