Skip to content

Commit df2de13

Browse files
authored
[WebAssembly] Autovec support for dot (#123207)
Enable the use of partial.reduce.add that we can lower to dot or a tree of (add (extmul_low_u, extmul_high_u)) for the unsigned case. We support both v8i16 and v16i8 inputs.
1 parent f7f3dfc commit df2de13

File tree

7 files changed

+340
-1
lines changed

7 files changed

+340
-1
lines changed

llvm/lib/Target/WebAssembly/WebAssemblyISD.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ HANDLE_NODETYPE(Wrapper)
2626
HANDLE_NODETYPE(WrapperREL)
2727
HANDLE_NODETYPE(BR_IF)
2828
HANDLE_NODETYPE(BR_TABLE)
29+
HANDLE_NODETYPE(DOT)
2930
HANDLE_NODETYPE(SHUFFLE)
3031
HANDLE_NODETYPE(SWIZZLE)
3132
HANDLE_NODETYPE(VEC_SHL)

llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "llvm/IR/DiagnosticInfo.h"
3030
#include "llvm/IR/DiagnosticPrinter.h"
3131
#include "llvm/IR/Function.h"
32+
#include "llvm/IR/IntrinsicInst.h"
3233
#include "llvm/IR/Intrinsics.h"
3334
#include "llvm/IR/IntrinsicsWebAssembly.h"
3435
#include "llvm/Support/ErrorHandling.h"
@@ -177,6 +178,10 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
177178

178179
// SIMD-specific configuration
179180
if (Subtarget->hasSIMD128()) {
181+
182+
// Combine partial.reduce.add before legalization gets confused.
183+
setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);
184+
180185
// Combine vector mask reductions into alltrue/anytrue
181186
setTargetDAGCombine(ISD::SETCC);
182187

@@ -406,6 +411,35 @@ MVT WebAssemblyTargetLowering::getPointerMemTy(const DataLayout &DL,
406411
return TargetLowering::getPointerMemTy(DL, AS);
407412
}
408413

414+
bool WebAssemblyTargetLowering::shouldExpandPartialReductionIntrinsic(
415+
const IntrinsicInst *I) const {
416+
if (I->getIntrinsicID() != Intrinsic::experimental_vector_partial_reduce_add)
417+
return true;
418+
419+
EVT VT = EVT::getEVT(I->getType());
420+
auto Op1 = I->getOperand(1);
421+
422+
if (auto *InputInst = dyn_cast<Instruction>(Op1)) {
423+
if (InstructionOpcodeToISD(InputInst->getOpcode()) != ISD::MUL)
424+
return true;
425+
426+
if (isa<Instruction>(InputInst->getOperand(0)) &&
427+
isa<Instruction>(InputInst->getOperand(1))) {
428+
// dot only supports signed inputs but also support lowering unsigned.
429+
if (cast<Instruction>(InputInst->getOperand(0))->getOpcode() !=
430+
cast<Instruction>(InputInst->getOperand(1))->getOpcode())
431+
return true;
432+
433+
EVT Op1VT = EVT::getEVT(Op1->getType());
434+
if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
435+
((VT.getVectorElementCount() * 2 == Op1VT.getVectorElementCount()) ||
436+
(VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount())))
437+
return false;
438+
}
439+
}
440+
return true;
441+
}
442+
409443
TargetLowering::AtomicExpansionKind
410444
WebAssemblyTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
411445
// We have wasm instructions for these
@@ -2030,6 +2064,94 @@ SDValue WebAssemblyTargetLowering::LowerVASTART(SDValue Op,
20302064
MachinePointerInfo(SV));
20312065
}
20322066

2067+
// Try to lower partial.reduce.add to a dot or fallback to a sequence with
2068+
// extmul and adds.
2069+
SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) {
2070+
assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN);
2071+
if (N->getConstantOperandVal(0) !=
2072+
Intrinsic::experimental_vector_partial_reduce_add)
2073+
return SDValue();
2074+
2075+
assert(N->getValueType(0) == MVT::v4i32 && "can only support v4i32");
2076+
SDLoc DL(N);
2077+
SDValue Mul = N->getOperand(2);
2078+
assert(Mul->getOpcode() == ISD::MUL && "expected mul input");
2079+
2080+
SDValue ExtendLHS = Mul->getOperand(0);
2081+
SDValue ExtendRHS = Mul->getOperand(1);
2082+
assert((ISD::isExtOpcode(ExtendLHS.getOpcode()) &&
2083+
ISD::isExtOpcode(ExtendRHS.getOpcode())) &&
2084+
"expected widening mul");
2085+
assert(ExtendLHS.getOpcode() == ExtendRHS.getOpcode() &&
2086+
"expected mul to use the same extend for both operands");
2087+
2088+
SDValue ExtendInLHS = ExtendLHS->getOperand(0);
2089+
SDValue ExtendInRHS = ExtendRHS->getOperand(0);
2090+
bool IsSigned = ExtendLHS->getOpcode() == ISD::SIGN_EXTEND;
2091+
2092+
if (ExtendInLHS->getValueType(0) == MVT::v8i16) {
2093+
if (IsSigned) {
2094+
// i32x4.dot_i16x8_s
2095+
SDValue Dot = DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32,
2096+
ExtendInLHS, ExtendInRHS);
2097+
return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Dot);
2098+
}
2099+
2100+
unsigned LowOpc = WebAssemblyISD::EXTEND_LOW_U;
2101+
unsigned HighOpc = WebAssemblyISD::EXTEND_HIGH_U;
2102+
2103+
// (add (add (extmul_low_sx lhs, rhs), (extmul_high_sx lhs, rhs)))
2104+
SDValue LowLHS = DAG.getNode(LowOpc, DL, MVT::v4i32, ExtendInLHS);
2105+
SDValue LowRHS = DAG.getNode(LowOpc, DL, MVT::v4i32, ExtendInRHS);
2106+
SDValue HighLHS = DAG.getNode(HighOpc, DL, MVT::v4i32, ExtendInLHS);
2107+
SDValue HighRHS = DAG.getNode(HighOpc, DL, MVT::v4i32, ExtendInRHS);
2108+
2109+
SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v4i32, LowLHS, LowRHS);
2110+
SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v4i32, HighLHS, HighRHS);
2111+
SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, MulLow, MulHigh);
2112+
return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
2113+
} else {
2114+
assert(ExtendInLHS->getValueType(0) == MVT::v16i8 &&
2115+
"expected v16i8 input types");
2116+
// Lower to a wider tree, using twice the operations compared to above.
2117+
if (IsSigned) {
2118+
// Use two dots
2119+
unsigned LowOpc = WebAssemblyISD::EXTEND_LOW_S;
2120+
unsigned HighOpc = WebAssemblyISD::EXTEND_HIGH_S;
2121+
SDValue LowLHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInLHS);
2122+
SDValue LowRHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInRHS);
2123+
SDValue HighLHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInLHS);
2124+
SDValue HighRHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInRHS);
2125+
SDValue DotLHS =
2126+
DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, LowLHS, LowRHS);
2127+
SDValue DotRHS =
2128+
DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, HighLHS, HighRHS);
2129+
SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, DotLHS, DotRHS);
2130+
return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
2131+
}
2132+
2133+
unsigned LowOpc = WebAssemblyISD::EXTEND_LOW_U;
2134+
unsigned HighOpc = WebAssemblyISD::EXTEND_HIGH_U;
2135+
SDValue LowLHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInLHS);
2136+
SDValue LowRHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInRHS);
2137+
SDValue HighLHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInLHS);
2138+
SDValue HighRHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInRHS);
2139+
2140+
SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
2141+
SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS);
2142+
2143+
SDValue LowLow = DAG.getNode(LowOpc, DL, MVT::v4i32, MulLow);
2144+
SDValue LowHigh = DAG.getNode(LowOpc, DL, MVT::v4i32, MulHigh);
2145+
SDValue HighLow = DAG.getNode(HighOpc, DL, MVT::v4i32, MulLow);
2146+
SDValue HighHigh = DAG.getNode(HighOpc, DL, MVT::v4i32, MulHigh);
2147+
2148+
SDValue AddLow = DAG.getNode(ISD::ADD, DL, MVT::v4i32, LowLow, HighLow);
2149+
SDValue AddHigh = DAG.getNode(ISD::ADD, DL, MVT::v4i32, LowHigh, HighHigh);
2150+
SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh);
2151+
return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
2152+
}
2153+
}
2154+
20332155
SDValue WebAssemblyTargetLowering::LowerIntrinsic(SDValue Op,
20342156
SelectionDAG &DAG) const {
20352157
MachineFunction &MF = DAG.getMachineFunction();
@@ -3126,5 +3248,7 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
31263248
return performVectorTruncZeroCombine(N, DCI);
31273249
case ISD::TRUNCATE:
31283250
return performTruncateCombine(N, DCI);
3251+
case ISD::INTRINSIC_WO_CHAIN:
3252+
return performLowerPartialReduction(N, DCI.DAG);
31293253
}
31303254
}

llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ class WebAssemblyTargetLowering final : public TargetLowering {
4545
/// right decision when generating code for different targets.
4646
const WebAssemblySubtarget *Subtarget;
4747

48+
bool
49+
shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const override;
4850
AtomicExpansionKind shouldExpandAtomicRMWInIR(AtomicRMWInst *) const override;
4951
bool shouldScalarizeBinop(SDValue VecOp) const override;
5052
FastISel *createFastISel(FunctionLoweringInfo &FuncInfo,

llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,11 +1147,15 @@ def : Pat<(wasm_shr_u
11471147
}
11481148

11491149
// Widening dot product: i32x4.dot_i16x8_s
1150+
def dot_t : SDTypeProfile<1, 2, [SDTCisVT<0, v4i32>, SDTCisVT<1, v8i16>, SDTCisVT<2, v8i16>]>;
1151+
def wasm_dot : SDNode<"WebAssemblyISD::DOT", dot_t>;
11501152
let isCommutable = 1 in
11511153
defm DOT : SIMD_I<(outs V128:$dst), (ins V128:$lhs, V128:$rhs), (outs), (ins),
11521154
[(set V128:$dst, (int_wasm_dot V128:$lhs, V128:$rhs))],
11531155
"i32x4.dot_i16x8_s\t$dst, $lhs, $rhs", "i32x4.dot_i16x8_s",
11541156
186>;
1157+
def : Pat<(wasm_dot V128:$lhs, V128:$rhs),
1158+
(DOT $lhs, $rhs)>;
11551159

11561160
// Extending multiplication: extmul_{low,high}_P, extmul_high
11571161
def extend_t : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisVec<1>]>;

llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,53 @@ WebAssemblyTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
196196
return Cost;
197197
}
198198

199+
/// Cost of a partial reduction with the given shape.
///
/// Only add-of-mul reductions with fixed VF, SIMD128 enabled, identical input
/// types, identical extend kinds and an i32 accumulator are costed; everything
/// else returns an invalid cost. Supported shapes, in units of TCC_Basic:
///   - i16 x 8:  1 when sign-extended (single dot), 2 otherwise.
///   - i8 x 16:  2 when sign-extended, 4 otherwise.
InstructionCost WebAssemblyTTIImpl::getPartialReductionCost(
    unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
    ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
    TTI::PartialReductionExtendKind OpBExtend,
    std::optional<unsigned> BinOp) const {
  const InstructionCost Unsupported = InstructionCost::getInvalid();
  if (!VF.isFixed() || !ST->hasSIMD128())
    return Unsupported;

  // Possible options:
  // - i16x8.extadd_pairwise_i8x16_sx
  // - i32x4.extadd_pairwise_i16x8_sx
  // - i32x4.dot_i16x8_s
  // Only try to support dot, for now.
  const bool IsAddOfMul = Opcode == Instruction::Add && BinOp &&
                          *BinOp == Instruction::Mul;
  if (!IsAddOfMul || InputTypeA != InputTypeB || OpAExtend != OpBExtend)
    return Unsupported;

  const EVT InputVT = EVT::getEVT(InputTypeA);
  const EVT AccumVT = EVT::getEVT(AccumType);

  // TODO: Add i64 accumulator.
  if (AccumVT != MVT::i32)
    return Unsupported;

  const InstructionCost Unit(TTI::TCC_Basic);
  const bool IsSigned = OpAExtend == TTI::PR_SignExtend;
  const auto Lanes = VF.getFixedValue();

  // Signed inputs can lower to dot; unsigned needs the extend/mul sequence.
  if (InputVT == MVT::i16 && Lanes == 8)
    return IsSigned ? Unit : Unit * 2;

  // i8 inputs need an extra widening step, doubling the lowered sequence.
  if (InputVT == MVT::i8 && Lanes == 16)
    return IsSigned ? Unit * 2 : Unit * 4;

  return Unsupported;
}
245+
199246
TTI::ReductionShuffle WebAssemblyTTIImpl::getPreferredExpandedReductionShuffle(
200247
const IntrinsicInst *II) const {
201248

llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,12 @@ class WebAssemblyTTIImpl final : public BasicTTIImplBase<WebAssemblyTTIImpl> {
7878
InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
7979
TTI::TargetCostKind CostKind,
8080
unsigned Index, Value *Op0, Value *Op1);
81-
81+
InstructionCost
82+
getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
83+
Type *AccumType, ElementCount VF,
84+
TTI::PartialReductionExtendKind OpAExtend,
85+
TTI::PartialReductionExtendKind OpBExtend,
86+
std::optional<unsigned> BinOp = std::nullopt) const;
8287
TTI::ReductionShuffle
8388
getPreferredExpandedReductionShuffle(const IntrinsicInst *II) const;
8489

0 commit comments

Comments
 (0)